• 深度学习之生成唐诗案例(PyTorch 版)


    主要思路:

    对于唐诗生成,我们定义 "S" 和 "E" 两个特殊字符,分别作为每首诗的开始标记和结束标记。

    示例数据集中的唐诗大约有 40000 多首。

    首先进行数据预处理:将唐诗加载到内存,生成对应的 word2idx、idx2word,以及把所有唐诗按顺序拼接得到的字 id 序列(tangshi_ids)。训练时再把该序列切分成固定长度的片段作为输入,目标序列是输入整体向后移动一个字,即让模型学习预测下一个字。

    Dataset_Dataloader.py:数据预处理与数据集
    1. import torch
    2. import torch.nn as nn
    3. from torch.utils.data import Dataset, DataLoader
    def deal_tangshi(path="poems.txt", encoding="utf-8"):
        """Load the poem corpus and build the character vocabulary.

        Each line of the file is expected to look like ``title:content``;
        lines that do not split into exactly two parts on ":" are skipped.
        Every kept poem is wrapped as ``"S" + content + "E"`` so the model can
        learn where a poem starts and ends.

        Args:
            path: Corpus file to read (defaults to ``poems.txt`` in the CWD).
            encoding: Text encoding of the corpus file.

        Returns:
            ``(word2idx, idx2word, tangshis, word2idx_count, tangshi_ids)``:
            char -> id map, id -> char map, the wrapped poems, the vocabulary
            size, and the ids of all poems concatenated in order.
        """
        with open(path, "r", encoding=encoding) as fr:
            lines = fr.read().strip().split("\n")

        tangshis = []
        for line in lines:
            splits = line.split(":")
            if len(splits) != 2:
                continue
            tangshis.append("S" + splits[1] + "E")

        # "S"/"E" get the fixed ids 0/1; every other character is numbered in
        # order of first appearance.
        word2idx = {"S": 0, "E": 1}
        tangshi_ids = []
        for tangshi in tangshis:
            for word in tangshi:
                # setdefault assigns the next free id only to unseen characters,
                # so vocabulary and id sequence are built in a single pass.
                tangshi_ids.append(word2idx.setdefault(word, len(word2idx)))
        idx2word = {idx: w for w, idx in word2idx.items()}
        word2idx_count = len(word2idx)
        return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids
    # Build the corpus and vocabulary once, at import time; the training and
    # prediction scripts pull these names in via `from Dataset_Dataloader import *`.
    (word2idx, idx2word, tangshis,
     word2idx_count, tangshi_ids) = deal_tangshi()
    class TangShiDataset(Dataset):
        """Split a flat id sequence into fixed-length (input, target) pairs.

        Sample ``i`` covers the non-overlapping window
        ``[i * num_chars, (i + 1) * num_chars)`` of the corpus; its target is the
        same window shifted right by one position (next-character prediction).
        """

        def __init__(self, tangshi_ids, num_chars):
            # Corpus as one flat sequence of character ids
            self.tangshi_ids = tangshi_ids
            # Length of every input/target window
            self.num_chars = num_chars
            # Total number of ids in the corpus
            self.word_count = len(self.tangshi_ids)
            # Number of windows. The "- 1" reserves room for the one-step-shifted
            # target so every sample is complete; max() keeps an empty corpus at 0.
            self.number = max(self.word_count - 1, 0) // self.num_chars

        def __len__(self):
            return self.number

        def __getitem__(self, idx):
            # Accept negative indices like a regular sequence, and raise for
            # anything out of range so that plain iteration terminates.
            i = idx + self.number if idx < 0 else idx
            if not 0 <= i < self.number:
                raise IndexError(f"index {idx} out of range for {self.number} samples")
            # Windows are laid end to end, so sample i starts at i * num_chars.
            start = i * self.num_chars
            x = self.tangshi_ids[start: start + self.num_chars]
            y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]
            return torch.tensor(x), torch.tensor(y)
    def __test_Dataset():
        """Print the first (input, target) pair of a dataset with 8-char windows."""
        first_x, first_y = TangShiDataset(tangshi_ids, 8)[0]
        print(first_x, first_y)


    if __name__ == "__main__":
        # Running this file directly performs the dataset smoke test.
        __test_Dataset()
    TangShiModel.py:唐诗的模型
    
    1. import torch
    2. import torch.nn as nn
    3. from Dataset_Dataloader import *
    4. import torch.nn.functional as F
    class TangShiRNN(nn.Module):
        """Character-level language model: Embedding -> 1-layer RNN -> Linear.

        Args:
            vocab_size: Number of distinct characters (size of the output layer).
            embed_dim: Embedding width (default 128).
            hidden_size: RNN hidden width (default 128).
            dropout: Dropout probability applied to the embeddings and to the
                RNN outputs; only active in training mode.
        """

        def __init__(self, vocab_size, embed_dim=128, hidden_size=128, dropout=0.2):
            super().__init__()
            # Word-embedding layer
            self.ebd = nn.Embedding(vocab_size, embed_dim)
            # Single-layer recurrent layer (default layout: (seq, batch, feature))
            self.rnn = nn.RNN(embed_dim, hidden_size, 1)
            # Projection from hidden state to vocabulary logits
            self.out = nn.Linear(hidden_size, vocab_size)
            self.dropout_p = dropout

        def forward(self, inputs, hidden):
            """Run the model over a batch of id sequences.

            Args:
                inputs: Long tensor of character ids indexed (batch, seq); it is
                    transposed to (seq, batch) before the RNN.
                hidden: Initial hidden state, shape (1, batch, hidden_size).

            Returns:
                ``(logits, hidden)``. Logits are (seq, batch, vocab_size) with any
                size-1 dimensions squeezed away.
            """
            embed = self.ebd(inputs)
            # Regularization; passing `training` lets model.eval() disable it.
            embed = F.dropout(embed, p=self.dropout_p, training=self.training)
            output, hidden = self.rnn(embed.transpose(0, 1), hidden)
            # Regularization on the RNN outputs before the projection.
            output = F.dropout(output, p=self.dropout_p, training=self.training)
            output = self.out(output.squeeze())
            return output, hidden

        def init_hidden(self, batch_size=64):
            """Return a zero hidden state for ``batch_size`` sequences on the model's device."""
            return torch.zeros(1, batch_size, self.rnn.hidden_size,
                               device=self.out.weight.device)

    main.py:模型训练

    1. import time
    2. import torch
    3. from Dataset_Dataloader import *
    4. from TangShiModel import *
    5. import torch.optim as optim
    6. from tqdm import tqdm
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def train(epochs=100, save_dir="./modules"):
        """Train TangShiRNN on the poem corpus and checkpoint after every epoch.

        Args:
            epochs: Number of passes over the dataset.
            save_dir: Directory for ``tangshi_module_<epoch>.bin`` checkpoints;
                created if it does not exist.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)
        dataset = TangShiDataset(tangshi_ids, 128)
        # model.init_hidden() (no argument) builds a hidden state for 64
        # sequences, so every batch must hold exactly 64 samples: drop_last=True.
        # shuffle=True reshuffles on every new pass, so one loader serves all epochs.
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
        model = TangShiRNN(word2idx_count).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        for idx in range(epochs):
            model.train()
            start_time = time.time()
            total_loss = 0.0
            total_batches = 0
            total_correct = 0
            total_correct_num = 0
            for x, y in tqdm(dataloader):
                x = x.to(device)
                y = y.to(device)
                # Fresh hidden state per batch: batches are shuffled, so no state
                # carries over from one batch to the next.
                hidden = model.init_hidden().to(device)
                output, hidden = model(x, hidden)
                # output is (seq, batch, vocab); CrossEntropyLoss wants the class
                # dimension second, i.e. (batch, vocab, seq) against y (batch, seq).
                loss = criterion(output.permute(1, 2, 0), y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # `loss` is already the mean over all tokens in the batch, so the
                # epoch average is taken over batches, not samples.
                total_loss += loss.item()
                total_batches += 1
                total_correct_num += y.numel()
                total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()
            print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %
                  (idx + 1, total_loss / max(total_batches, 1),
                   total_correct / max(total_correct_num, 1), time.time() - start_time))
            torch.save(model.state_dict(), f"{save_dir}/tangshi_module_{idx + 1}.bin")


    if __name__ == '__main__':
        train()

    predict.py:加载模型并生成唐诗

    1. import torch
    2. import torch.nn as nn
    3. from Dataset_Dataloader import *
    4. from TangShiModel import *
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def predict(model_path="./modules/tangshi_module_100.bin", max_len=200):
        """Interactively generate a poem with greedy decoding.

        The model is primed with the start token "S", then fed the character the
        user typed, and afterwards its own most likely next character, until it
        predicts "E" or ``max_len`` characters have been emitted.

        Args:
            model_path: Checkpoint written by ``train()``.
            max_len: Upper bound on emitted characters; greedy decoding of an RNN
                can cycle forever without ever producing "E".
        """
        model = TangShiRNN(word2idx_count)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()

        start_word = input("输入第一个字:")
        # Only characters seen in the training corpus have an id.
        while start_word not in word2idx:
            print("该字不在词表中,请重新输入。")
            start_word = input("输入第一个字:")

        tangshi_strs = ["S"]
        with torch.no_grad():
            hidden = torch.zeros(1, 1, model.rnn.hidden_size)
            # Prime the hidden state with the start token; its prediction is not
            # used because the user supplies the first character.
            _, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)
            current = start_word
            for _ in range(max_len):
                tangshi_strs.append(current)
                outputs, hidden = model(torch.tensor([[word2idx[current]]], dtype=torch.long), hidden)
                top_i = torch.argmax(outputs, dim=-1).item()
                if top_i == word2idx["E"]:
                    break
                current = idx2word[top_i]
        print(tangshi_strs)


    if __name__ == '__main__':
        predict()

    完整代码如下:

    https://github.com/STZZ-1992/tangshi-generator.git

  • 相关阅读:
    正则表达式符号含义
    Web前端——立体相册的制作
    一些测试知识
    javascript高级篇之原型和原型链
    知识增广的预训练语言模型K-BERT:将知识图谱作为训练语料
    matlab绘制hsv色轮图
    SpringBoot @TransactionalEventListener&JavaMailSender使用
    velocity 调用hutool 实现首字母大写
    【ajax】withCredentials
    【YOLO系列】YOLOv4
  • 原文地址:https://blog.csdn.net/wtl1992/article/details/134527010