租车网站建设,德州网站怎样建设,汽配网站开发,汉服网站设计目的1、数据集的保存形式#xff1a;一行一行的。
比如说预测两个值的加法#xff1a;abc#xff0c;那么传进Dataset的形式应该是 a1,b1,c1 a2,b2,c2 ... an,bn,cn 2、代码
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, D…1、数据集的保存形式一行一行的。
比如说预测两个值的加法abc那么传进Dataset的形式应该是 a1,b1,c1 a2,b2,c2 ... an,bn,cn 2、代码
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset# 创建数据
np.random.seed(2024) # 保证后续使用random函数时产生固定的随机数
data_rand np.random.rand(10, 2)
datas np.insert(data_rand, 2, data_rand.sum(axis1), axis1)
print(\ndatas.shape, datas.shape)
print(datas\n, datas)train_data datas[:int(len(datas) * 0.9)]
test_data datas[int(len(datas) * 0.9):]debug_flag False # False,Trueclass PreDataSet(Dataset):def __init__(self, _data):self.x_data torch.Tensor(_data[:, :-1])self.y_data torch.Tensor(_data[:, -1])if debug_flag:print(self.x_data.shape, self.x_data.shape)print(self.y_data.shape, self.y_data.shape)self.n_getitem 0 # 记录进入__getitem__的次数self.n_len 0 # 记录进入__len__的次数def __getitem__(self, index):self.n_getitem self.n_getitem 1if debug_flag:print(index, index, n_getitem, self.n_getitem)print(x_data[index].shape, self.x_data[index].shape)print(y_data[index].shape, self.y_data[index].shape)return self.x_data[index], self.y_data[index]def __len__(self):self.n_len self.n_len 1if debug_flag:print(len(self.x_data), len(self.x_data), n_len, self.n_len)return len(self.x_data)train_dataset PreDataSet(train_data)
train_dataloader DataLoader(train_dataset, batch_size1, shuffleFalse)# 2、输出看结果
for x, y in train_dataloader:print(\nx, x)print(y, y)if debug_flag:print(x.shape, x.shape)print(y.shape, y.shape)
3、运行结果
D:\SoftProgram\JetBrains\anaconda3_202303\python.exe E:\program\python\DKASCProject\10DistributedPV\tst_dataloader_end.py datas.shape (10, 3)
datas[[0.58801452 0.69910875 1.28712327][0.18815196 0.04380856 0.23196052][0.20501895 0.10606287 0.31108183][0.72724014 0.67940052 1.40664067][0.4738457 0.44829582 0.92214153][0.01910695 0.75259834 0.77170529][0.60244854 0.96177758 1.56422611][0.66436865 0.60662962 1.27099827][0.44915131 0.22535416 0.67450548][0.6701743 0.73576659 1.40594089]]x tensor([[0.5880, 0.6991]])
y tensor([1.2871])x tensor([[0.1882, 0.0438]])
y tensor([0.2320])x tensor([[0.2050, 0.1061]])
y tensor([0.3111])x tensor([[0.7272, 0.6794]])
y tensor([1.4066])x tensor([[0.4738, 0.4483]])
y tensor([0.9221])x tensor([[0.0191, 0.7526]])
y tensor([0.7717])x tensor([[0.6024, 0.9618]])
y tensor([1.5642])x tensor([[0.6644, 0.6066]])
y tensor([1.2710])x tensor([[0.4492, 0.2254]])
y tensor([0.6745])进程已结束退出代码为 0参考B站视频
【2、数据集加载Dataset和DataLoader】