利用RNN生成中文语句实例


本实例使用2014人民日报上的一些新闻(约28万条,60M),使用PyTorch提供的nn.rnn模型,根据提示字符串预测给定长度的语句。

数据下载

提取码:6l9e

3.3.1 模型架构

图1-13 RNN实例模型架构图

3.3.2 导入需要的模块

3.3.3 定义预处理函数

使用torch.utils.data生成可迭代的数据集。

3.3.4 定义模型

根据图1-13构建模型,以Embedding为输入层,隐含节点数为256,共两层。为何使用Embedding层作为输入层,而不使用One-hot编码作为输入层?Embedding输入层,除可以有效压缩维度空间外(与以one-hot编码为输入层),更重要的是Embedding层在整个迭代过程中参与学习。

3.3.5 定义训练模型函数

为便于管理,这里参数传入采用argparse方式。

3.3.6 设置参数

这里参数设置运行环境为jupyter notebook ,如果在命令行运行需要做一些改动。

3.3.7 运行模型

3.3.8 练习

1.把Emedding层改为One-hot层
2.目前学习率为固定,把固定改为动态(如与迭代次数相关关联),查看损失值的变化
3.修改学习率及迭代次数等,比较损失值的变化。
4.使用LSTM或GRU模型
5.使用GPT或BERT模型