主要思路:
对于唐诗生成任务,我们定义两个特殊标记 "S" 和 "E",分别表示一首诗的开始和结束。

示例数据中大约有40000多首唐诗。
首先进行数据预处理:将唐诗加载到内存,生成对应的word2idx、idx2word,以及把所有唐诗按顺序拼接而成的字索引序列。
Dataset_Dataloader.py:数据预处理与数据集
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
-
-
def deal_tangshi(path="poems.txt"):
    """Load the poem corpus and build the character vocabulary.

    Each line of the file is split on ":"; only lines that split into
    exactly two parts are kept, and the second part is used as the poem
    text. Every kept poem is wrapped as "S" + text + "E" so that the model
    can learn where a poem starts and ends.

    Args:
        path: Location of the UTF-8 corpus file. Defaults to "poems.txt",
            which is the file the original script read.

    Returns:
        A tuple ``(word2idx, idx2word, tangshis, word2idx_count, tangshi_ids)``:
        the char -> id mapping ("S" is 0, "E" is 1, other characters are
        numbered in order of first appearance), the inverse id -> char
        mapping, the list of wrapped poem strings, the vocabulary size, and
        the ids of all poems concatenated in file order.
    """
    with open(path, "r", encoding="utf-8") as fr:
        lines = fr.read().strip().split("\n")

    tangshis = []
    for line in lines:
        splits = line.split(":")
        # Skip malformed lines (no separator, or more than one).
        if len(splits) != 2:
            continue
        tangshis.append("S" + splits[1] + "E")

    word2idx = {"S": 0, "E": 1}
    tangshi_ids = []

    # Single pass: assign ids on first sight and emit the id sequence.
    for tangshi in tangshis:
        for word in tangshi:
            if word not in word2idx:
                word2idx[word] = len(word2idx)
            tangshi_ids.append(word2idx[word])

    idx2word = {idx: w for w, idx in word2idx.items()}
    word2idx_count = len(word2idx)

    return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids
-
-
- word2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()
-
-
class TangShiDataset(Dataset):
    """Fixed-length next-character training windows over the id sequence.

    Sample ``i`` is the pair ``(x, y)`` where ``x`` is the ``i``-th
    non-overlapping run of ``num_chars`` ids and ``y`` is the same run
    shifted one position to the right (the next-character targets).
    """

    def __init__(self, tangshi_ids, num_chars):
        """
        Args:
            tangshi_ids: Sequence of character ids for the whole corpus.
            num_chars: Number of ids in each training window.
        """
        # Corpus data
        self.tangshi_ids = tangshi_ids
        # Window length
        self.num_chars = num_chars
        # Total number of ids in the corpus
        self.word_count = len(self.tangshi_ids)
        # Number of windows. One id is reserved because every window needs a
        # target that extends one position past its last input id.
        self.number = max(self.word_count - 1, 0) // self.num_chars

    def __len__(self):
        return self.number

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for window ``idx``.

        Out-of-range indices are clamped to the first / last window rather
        than raising.
        """
        # Clamp the window index into [0, self.number - 1].
        window = max(min(idx, self.number - 1), 0)
        # Windows are laid end to end, so the whole corpus is covered
        # instead of only its first few hundred characters.
        start = window * self.num_chars

        x = self.tangshi_ids[start: start + self.num_chars]
        y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]

        return torch.tensor(x), torch.tensor(y)
-
-
def __test_Dataset():
    """Smoke check: build a dataset of 8-id windows and print its first sample."""
    sample_inputs, sample_targets = TangShiDataset(tangshi_ids, 8)[0]

    print(sample_inputs, sample_targets)
-
-
- if __name__ == '__main__':
- # deal_tangshi()
- __test_Dataset()
TangShiModel.py:唐诗的模型
- import torch
- import torch.nn as nn
- from Dataset_Dataloader import *
- import torch.nn.functional as F
-
-
class TangShiRNN(nn.Module):
    """Character-level RNN language model: embedding -> RNN -> linear.

    Args:
        vocab_size: Number of distinct character ids.
        embedding_dim: Size of each character embedding (default 128).
        hidden_size: Size of the RNN hidden state (default 128).
        dropout: Dropout probability applied to the embeddings and to the
            RNN outputs during training (default 0.2).
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=128, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout_p = dropout
        # Embedding layer
        self.ebd = nn.Embedding(vocab_size, embedding_dim)
        # Single-layer recurrent network (sequence-first input)
        self.rnn = nn.RNN(embedding_dim, hidden_size, 1)
        # Output projection to vocabulary logits
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, hidden):
        """Run the model on a batch of id sequences.

        Args:
            inputs: Integer tensor of ids, batch-first (it is transposed to
                sequence-first before entering the RNN).
            hidden: Initial hidden state for the RNN.

        Returns:
            ``(output, hidden)``: the vocabulary logits for every position
            (sequence-first, with size-1 dimensions squeezed away) and the
            final hidden state.
        """
        embed = self.ebd(inputs)

        # Regularisation. ``training=self.training`` makes model.eval()
        # actually switch dropout off; the functional form defaults to on.
        embed = F.dropout(embed, p=self.dropout_p, training=self.training)

        output, hidden = self.rnn(embed.transpose(0, 1), hidden)

        # Regularisation on the RNN outputs (the result must be fed to the
        # output layer, otherwise this dropout has no effect).
        output = F.dropout(output, p=self.dropout_p, training=self.training)

        # NOTE(review): squeeze() drops every size-1 dimension, so a batch of
        # one or a sequence of one changes the rank of the result. Kept as-is
        # because existing callers rely on the current return shape.
        output = self.out(output.squeeze())

        return output, hidden

    def init_hidden(self, batch_size=64):
        """Return an all-zero initial hidden state for ``batch_size`` sequences.

        The tensor is created on the same device as the model's parameters.
        """
        return torch.zeros(1, batch_size, self.hidden_size, device=self.out.weight.device)
main.py:
- import time
-
- import torch
-
- from Dataset_Dataloader import *
- from TangShiModel import *
- import torch.optim as optim
- from tqdm import tqdm
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
def train():
    """Train the poem model and save a checkpoint after every epoch.

    Uses 128-id windows, batches of 64, Adam with lr=1e-3 and 100 epochs.
    After each epoch the mean per-batch loss, the next-character accuracy
    and the elapsed time are printed, and the weights are written to
    ``./modules/tangshi_module_<epoch>.bin``.
    """
    import os

    dataset = TangShiDataset(tangshi_ids, 128)
    epochs = 100
    model = TangShiRNN(word2idx_count).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # The loader reshuffles on every iteration, so one instance is enough.
    # drop_last keeps every batch at exactly 64, matching init_hidden().
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

    # torch.save does not create missing directories.
    os.makedirs("./modules", exist_ok=True)

    model.train()

    for idx in range(epochs):
        start_time = time.time()
        total_loss = 0.0
        total_batches = 0
        total_correct = 0
        total_correct_num = 0

        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            # Fresh hidden state for every batch (batches are independent).
            hidden = model.init_hidden().to(device)
            # Forward pass
            output, hidden = model(x, hidden)
            # Loss: permute the logits so the class dimension comes second,
            # as CrossEntropyLoss expects; y is indexed (batch, position).
            loss = criterion(output.permute(1, 2, 0), y)
            # Reset gradients
            optimizer.zero_grad()
            # Backpropagate
            loss.backward()
            # Update parameters
            optimizer.step()

            # criterion already averages over the batch, so the epoch
            # average is the mean of these values over the batch count.
            total_loss += loss.item()
            total_batches += 1
            total_correct_num += y.shape[0] * y.shape[1]
            total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()

        # max(..., 1) avoids a division by zero when the corpus is too small
        # to produce a single full batch.
        print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %
              (idx + 1, total_loss / max(total_batches, 1),
               total_correct / max(total_correct_num, 1), time.time() - start_time))

        torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")
-
-
- if __name__ == '__main__':
- train()
predict.py:
- import torch
- import torch.nn as nn
- from Dataset_Dataloader import *
- from TangShiModel import *
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
def predict(max_len=128):
    """Generate a poem greedily from a user-supplied first character.

    Loads ``./modules/tangshi_module_100.bin`` on the CPU, primes the RNN
    with the start marker "S", then feeds the character read from stdin and
    keeps feeding back the most likely next character until the end marker
    "E" is predicted or ``max_len`` characters have been produced. The
    resulting list (beginning with "S") is printed.

    Args:
        max_len: Upper bound on the number of generated characters, so
            that generation terminates even if "E" is never predicted.

    Raises:
        ValueError: If the input is not exactly one character from the
            training vocabulary.
    """
    model = TangShiRNN(word2idx_count)
    model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))

    model.eval()

    hidden = torch.zeros(1, 1, 128)

    start_word = input("输入第一个字:")
    if len(start_word) != 1:
        raise ValueError("expected exactly one character, got %r" % start_word)
    if start_word not in word2idx:
        raise ValueError("character %r is not in the vocabulary" % start_word)

    tangshi_strs = ["S"]

    with torch.no_grad():
        # Prime the hidden state with the start marker. The prediction made
        # here is discarded on purpose: the first character comes from the
        # user, not from the model.
        _, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)

        for _ in range(max_len):
            tangshi_strs.append(start_word)
            outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)
            top_i = torch.argmax(outputs, dim=-1)

            if top_i.item() == word2idx["E"]:
                break

            start_word = idx2word[top_i.item()]

    print(tangshi_strs)
-
-
- if __name__ == '__main__':
- predict()
完整代码如下:
https://github.com/STZZ-1992/tangshi-generator.git
https://github.com/STZZ-1992/tangshi-generator.git