主要思路：
对于唐诗生成来说，我们定义 S 和 E 两个特殊字符，分别作为每首诗的开始和结束标记。示例的唐诗大约有 40000 多首。
首先进行数据预处理：将唐诗加载到内存，生成对应的 word2idx、idx2word，以及所有唐诗按顺序拼接而成的字序列。…
主要思路
对于唐诗生成来说，我们定义 S 和 E 两个特殊字符，分别作为每首诗的开始和结束标记。示例的唐诗大约有 40000 多首。
首先进行数据预处理：将唐诗加载到内存，生成对应的 word2idx、idx2word，以及所有唐诗按顺序拼接而成的字序列。
Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def deal_tangshi():
    """Load ``poems.txt`` and build the character vocabulary and id corpus.

    Each line is expected to look like ``title:content``; lines that do not
    split into exactly two parts on ``":"`` are skipped. Every poem body is
    wrapped as ``"S" + content + "E"`` so the model sees explicit start/end
    markers.

    Returns:
        ``(word2idx, idx2word, tangshis, word2idx_count, tangshi_ids)``:
        char -> id mapping (``"S"`` is 0, ``"E"`` is 1), id -> char mapping,
        the wrapped poem strings, the vocabulary size, and the ids of all poems
        concatenated in file order.
    """
    with open("poems.txt", "r", encoding="utf-8") as fr:
        # splitlines() also strips "\r", so CRLF files do not add "\r" to the
        # vocabulary at the end of every poem.
        lines = fr.read().strip().splitlines()

    tangshis = []
    for line in lines:
        splits = line.split(":")
        if len(splits) != 2:
            continue
        tangshis.append("S" + splits[1] + "E")

    word2idx = {"S": 0, "E": 1}
    for tangshi in tangshis:
        for word in tangshi:
            # Ids are contiguous, so the next free id is the current size.
            word2idx.setdefault(word, len(word2idx))
    word2idx_count = len(word2idx)
    idx2word = {idx: w for w, idx in word2idx.items()}

    tangshi_ids = [word2idx[w] for tangshi in tangshis for w in tangshi]
    return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids


# Loaded once at import time; the training and prediction scripts pick these
# names up via ``from Dataset_Dataloader import *``.
word2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()


class TangShiDataset(Dataset):
    """Cut a flat id corpus into fixed-length next-character samples.

    Sample ``i`` is the window of ``num_chars`` ids starting at
    ``i * num_chars`` (input ``x``) together with the same window shifted
    right by one id (target ``y``).

    Args:
        tangshi_ids: All poem character ids concatenated in order.
        num_chars: Sequence length of every sample.
    """

    def __init__(self, tangshi_ids, num_chars):
        # The whole corpus as one flat sequence of character ids.
        self.tangshi_ids = tangshi_ids
        # Characters per sample (sequence length).
        self.num_chars = num_chars
        # Total number of characters in the corpus.
        self.word_count = len(self.tangshi_ids)
        # Each window needs num_chars inputs plus one extra id for the shifted
        # target, so only (word_count - 1) // num_chars full windows fit.
        self.number = max((self.word_count - 1) // self.num_chars, 0)

    def __len__(self):
        return self.number

    def __getitem__(self, idx):
        # Samples tile the corpus: sample idx starts at idx * num_chars. Using
        # idx itself as the offset would confine every sample to the first
        # len(self) characters. Out-of-range indices are clamped to the
        # first/last window whose shifted target still fits.
        last_start = max(self.word_count - self.num_chars - 1, 0)
        start = min(max(idx, 0) * self.num_chars, last_start)
        x = self.tangshi_ids[start: start + self.num_chars]
        y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]
        return torch.tensor(x), torch.tensor(y)


def __test_Dataset():
    """Print the first 8-character sample of the corpus as a smoke test."""
    dataset = TangShiDataset(tangshi_ids, 8)
    x, y = dataset[0]
    print(x, y)


if __name__ == "__main__":
    __test_Dataset()

# ---- TangShiModel.py: the Tang-poetry model ----
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as F


class TangShiRNN(nn.Module):
    """Character-level RNN language model for Tang poems.

    Embeds character ids, runs a single-layer ``nn.RNN`` over the sequence and
    projects every hidden state back to vocabulary logits.

    Args:
        vocab_size: Number of distinct characters (output classes).
        embed_dim: Embedding size (default 128).
        hidden_size: RNN hidden size (default 128).
        dropout: Dropout probability applied to the embeddings and the RNN
            outputs, only while the module is in training mode (default 0.2).
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_size=128, dropout=0.2):
        super().__init__()
        # Character embedding layer.
        self.ebd = nn.Embedding(vocab_size, embed_dim)
        # Single-layer recurrent layer (not batch_first: (seq, batch, feature)).
        self.rnn = nn.RNN(embed_dim, hidden_size, 1)
        # Projection from hidden state to vocabulary logits.
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout_p = dropout

    def forward(self, inputs, hidden):
        """Run the model over a batch of id sequences.

        Args:
            inputs: Long tensor of character ids, shape (batch, seq).
            hidden: Initial hidden state, shape (1, batch, hidden_size).

        Returns:
            ``(logits, hidden)``. ``logits`` is (seq, batch, vocab_size) with
            every size-1 dimension squeezed out (a single-step, single-sample
            call yields shape (vocab_size,)); ``hidden`` is the final state.
        """
        embed = self.ebd(inputs)
        # F.dropout defaults to training=True; tie it to the module mode so
        # model.eval() gives deterministic inference.
        embed = F.dropout(embed, p=self.dropout_p, training=self.training)
        # nn.RNN is not batch_first, so swap to (seq, batch, embed_dim).
        output, hidden = self.rnn(embed.transpose(0, 1), hidden)
        output = F.dropout(output, p=self.dropout_p, training=self.training)
        output = self.out(output.squeeze())
        return output, hidden

    def init_hidden(self, batch_size=64, device=None):
        """Return a zero hidden state of shape (1, batch_size, hidden_size)."""
        return torch.zeros(1, batch_size, self.rnn.hidden_size, device=device)

# ---- main.py ----
import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train():
    """Train TangShiRNN on the Tang-poem corpus, checkpointing every epoch.

    Uses 128-character windows from ``TangShiDataset``, shuffled batches of 64
    and Adam (lr=1e-3) for 100 epochs. After each epoch prints the mean batch
    loss and the per-character accuracy, then saves the weights to
    ``./modules/tangshi_module_<epoch>.bin``.
    """
    import os

    dataset = TangShiDataset(tangshi_ids, 128)
    epochs = 100
    model = TangShiRNN(word2idx_count).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # torch.save does not create missing directories.
    os.makedirs("./modules", exist_ok=True)
    # shuffle=True reshuffles on every pass, so one loader serves all epochs.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

    for idx in range(epochs):
        model.train()
        start_time = time.time()
        total_loss = 0.0
        total_batches = 0
        total_correct = 0
        total_correct_num = 0
        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            # Fresh zero hidden state per batch: shuffled windows are unrelated.
            hidden = model.init_hidden().to(device)
            output, hidden = model(x, hidden)
            # output is (seq, batch, vocab); CrossEntropyLoss wants
            # (batch, classes, seq) against (batch, seq) targets.
            loss = criterion(output.permute(1, 2, 0), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is already averaged over the batch, so average per batch
            # (not per sample) when reporting.
            total_loss += loss.item()
            total_batches += 1
            total_correct_num += y.numel()
            total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()
        print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds"
              % (idx + 1,
                 total_loss / max(total_batches, 1),
                 total_correct / max(total_correct_num, 1),
                 time.time() - start_time))
        torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")


if __name__ == "__main__":
    train()

# ---- predict.py ----
import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict(max_len=200):
    """Interactively generate a poem from a user-supplied first character.

    Loads the epoch-100 checkpoint onto the CPU, primes the RNN with the start
    marker ``"S"``, then feeds the user's character and greedily appends the
    arg-max prediction until the model emits ``"E"`` or ``max_len``
    characters have been appended. Prints the resulting character list
    (including the leading ``"S"``).

    Args:
        max_len: Maximum number of characters appended after ``"S"``. Greedy
            decoding can cycle without ever predicting ``"E"``, so the loop
            must be bounded.
    """
    model = TangShiRNN(word2idx_count)
    model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device("cpu")))
    model.eval()
    hidden = torch.zeros(1, 1, 128)
    start_word = input("输入第一个字:").strip()
    if start_word not in word2idx:
        print(f"unknown character: {start_word!r}")
        return

    tangshi_strs = ["S"]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # Prime the hidden state with the start marker; its prediction is
        # unused because the user supplies the first character.
        _, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)
        for _ in range(max_len):
            tangshi_strs.append(start_word)
            outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)
            top_i = torch.argmax(outputs, dim=-1)
            if top_i.item() == word2idx["E"]:
                break
            print(top_i)
            start_word = idx2word[top_i.item()]
    print(tangshi_strs)


if __name__ == "__main__":
    predict()

# The complete code is available at:
https://github.com/STZZ-1992/tangshi-generator.git