dw如何用表格做网站,wordpress商城建站,杭州网站制作平台,慈溪专业做网站公司文章目录 1. 创建数据集1.1. 直接继承Dataset类1.2. 使用TensorDataset类 2. 数据集的划分3. 加载数据集4. 将数据转移到GPU 1. 创建数据集
主要是将数据集读入内存#xff0c;并用Dataset类封装。
1.1. 直接继承Dataset类
必须要重写__getitem__方法#xff0c;用于根据索… 文章目录 1. 创建数据集1.1. 直接继承Dataset类1.2. 使用TensorDataset类 2. 数据集的划分3. 加载数据集4. 将数据转移到GPU 1. 创建数据集
主要是将数据集读入内存并用Dataset类封装。
1.1. 直接继承Dataset类
必须要重写__getitem__方法用于根据索引获得相应样本数据。必要时还可以重写__len__方法用于返回数据集的大小。
from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):定义波士顿房价数据集def __init__(self):self.data np.load(../dataset/boston_housing/boston_housing.npz)def __getitem__(self, index):return self.data[x][index], self.data[y][index]def __len__(self):return self.data[x].shape[0]1.2. 使用TensorDataset类
将多个张量组合成一个数据集要保证所有张量的第一个维度相等保证每批样本数据格式相同。
import torch
from torch.utils.data import TensorDatasetdata np.load(../dataset/boston_housing/boston_housing.npz)
X torch.tensor(data[x])
y torch.tensor(data[y])
dataset TensorDataset(X, y)2. 数据集的划分
数据集可以划分为训练集、验证集和测试集。
训练集用于模型拟合的数据样本集合。验证集通常被用来调整模型的参数以找出效果最佳的模型。测试集用于训练好的模型性能评估的数据样本集合。
from torch.utils.data import random_splittrain_size int(0.8 * len(dataset))
test_size len(dataset) - train_size
train_dataset, test_dataset random_split(dataset, [train_size, test_size])3. 加载数据集
使用DataLoader类将Dataset封装的数据集分成批次并进行迭代以便于模型训练。DataLoader常用参数如下
dataset 要加载的数据集。batch_size 每个数据批次中包含的样本数。默认为1。shuffle 是否打乱数据集。默认为False。num_workers 使用几个进程来加载数据。默认为0即在主进程中加载数据。drop_last 当数据集样本数不能被batch_size整除时是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader DataLoader(dataset, batch_size16, shuffleTrue)
4. 将数据转移到GPU
一般在要运算时才将数据转移到GPU有以下两种方法
var.to(device)var.cuda()
import torchdevice torch.device(cuda if torch.cuda.is_available() else cpu)
for X,y in dataloader:# 将数据转移到GPUX X.to(device)y y.to(device)# 也可以X X.cuda()y y.cuda()