Classify Leaves — A Full Record
Table of Contents

Miscellany
Main Text
1. Downloading and Processing the Data
2. Dataset Overview
3. Building a Custom Dataset
4. Initializing the Network
5. Training

Miscellany

Let's try the methods covered so far in combination, on the leaves dataset.
Main Text

1. Downloading and Processing the Data

Download the dataset from the official page: Classify Leaves | Kaggle. After unzipping, there are an image folder, a submission example, a test set, and a training set:

images: 27,153 leaf images
test.csv: 8,800 entries
train.csv: 18,353 entries

2. Dataset Overview

Take a look at the training set, the test set, and the classes:
# Imports
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Image

train_data = pd.read_csv(r'D:\apycharmblackhorse\leaves\train.csv')
test_data = pd.read_csv(r'D:\apycharmblackhorse/leaves/test.csv')

# Read every training image path into an array
train_images = train_data.iloc[:, 0].values
print('Training set size:', len(train_images))
n_train = len(train_images)

test_images = test_data.iloc[:, 0].values
print('Test set size:', len(test_images))
n_test = len(test_images)

# One-hot encode the labels, then take each row's argmax as the class index;
# the index order matches the one-hot columns, i.e. the sorted class names below
train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
# print(len(train_labels), train_labels)

# Record all class names, in sorted order
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print('Number of classes:', len(train_labels_header))
classes = len(train_labels_header)
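As a quick sanity check (my addition, not in the original post), each encoded index can be mapped back through train_labels_header to recover the original label string, since get_dummies sorts its columns the same way on both calls:

# Sanity check (illustrative): decoding an index through the sorted column
# names should reproduce the original label string for that row
for i in random.sample(range(n_train), 3):
    assert train_labels_header[train_labels[i]] == train_data.iloc[i, 1]
print('label encoding round-trips correctly')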
3. Building a Custom Dataset

Subclass torch.utils.data.Dataset to define a custom leaf-classification dataset:

# Subclass torch.utils.data.Dataset to define a custom leaf-classification dataset
class leaves_dataset(torch.utils.data.Dataset):
    # root: data directory, images: image paths, labels: image labels, transform: data augmentation
    def __init__(self, root, images, labels, transform):
        super(leaves_dataset, self).__init__()
        self.root = root
        self.images = images
        if labels is None:
            self.labels = None
        else:
            self.labels = labels
        self.transform = transform

    # Fetch the sample at the given index
    def __getitem__(self, index):
        image_path = self.root + self.images[index]
        image = Image.open(image_path)
        # Preprocess / augment
        image = self.transform(image)
        if self.labels is None:
            return image
        label = torch.tensor(self.labels[index])
        return image, label

    # Number of samples in the dataset
    def __len__(self):
        return self.images.shape[0]
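A minimal usage sketch (my addition; it reuses the path from above and a bare-bones transform) to confirm that a single sample comes back as a 3×224×224 tensor plus a scalar label:

# Illustrative only: fetch one preprocessed sample from the custom dataset
demo_transform = transforms.Compose([transforms.CenterCrop(224),
                                     transforms.ToTensor()])
demo_set = leaves_dataset(r'D:\apycharmblackhorse\leaves\\',
                          train_images, train_labels, demo_transform)
img, lbl = demo_set[0]
print(img.shape, lbl)  # expected: torch.Size([3, 224, 224]) and a scalar tensor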
Build the data loading and preprocessing:

def load_data(images, labels, batch_size, train):
    aug = []
    # ImageNet normalization statistics
    normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
    if train:
        aug = [torchvision.transforms.CenterCrop(224),
               transforms.RandomHorizontalFlip(),
               transforms.RandomVerticalFlip(),
               transforms.ColorJitter(brightness=0.5, contrast=0.5,
                                      saturation=0.5, hue=0.5),
               transforms.ToTensor(),
               normalize]
    else:
        aug = [torchvision.transforms.Resize([256, 256]),
               torchvision.transforms.CenterCrop(224),
               transforms.ToTensor(),
               normalize]
    transform = transforms.Compose(aug)
    dataset = leaves_dataset(r'D:\apycharmblackhorse\leaves\\', images, labels,
                             transform=transform)
    kind = 'training' if train else 'test'
    print('Loaded', len(dataset), kind, 'samples')
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                       num_workers=0, shuffle=train)

train_iter = load_data(train_images, train_labels, 512, train=True)
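To verify the loader end to end (my addition), pull one batch and check its shape; with batch_size=512 and the augmentation above, each image batch should be 512×3×224×224:

# Illustrative check: one augmented batch from the training loader
X, y = next(iter(train_iter))
print(X.shape, y.shape)  # expected: torch.Size([512, 3, 224, 224]) torch.Size([512])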
4. Initializing the Network

Initialize the network from the official pretrained model and change the number of output classes:

# Initialize the network from ImageNet-pretrained weights
net = torchvision.models.resnet18(pretrained=True)
# Replace the final fully connected layer to match our class count
net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc
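A quick shape check (illustrative, my addition): a dummy 224×224 input through the modified network should yield one logit per class:

# Illustrative: confirm the new head produces `classes` logits
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, classes])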
5. Training

Define the data iterators, the optimizer, and the other hyperparameters, then train:

# With param_group=True, the output-layer parameters use ten times the learning rate
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,
                      param_group=True):
    # Hold out part of the training set for validation
    train_slices = random.sample(list(range(n_train)), 15000)
    test_slices = list(set(range(n_train)) - set(train_slices))
    train_iter = load_data(train_images[train_slices],
                           train_labels[train_slices], batch_size, train=True)
    test_iter = load_data(train_images[test_slices],
                          train_labels[test_slices], batch_size, train=False)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction='none')
    if param_group:
        # Keep every other layer at the base learning rate; the last layer gets 10x
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ['fc.weight', 'fc.bias']]
        trainer = torch.optim.Adam([{'params': params_1x},
                                    {'params': net.fc.parameters(),
                                     'lr': learning_rate * 10}],
                                   lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                                   weight_decay=0.001)
    try:
        d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                       devices)
    except Exception as e:
        print(e)
# Fine-tune the pretrained parameters with a smaller learning rate
train_fine_tuning(net, 1e-3)

My poor little machine runs slowly. Earlier, without pretraining, accuracy only reached about 0.3 after 5 epochs; with the pretrained weights it reached 0.6. That said, for a targeted task like leaf classification I suspect training from scratch would ultimately be the better choice, but I don't have the compute to try it here. That is roughly how the attempts went.
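As a natural next step (my addition, not in the original post), the 8,800 test images can be predicted and written out in the format of the provided submission example. This is a minimal sketch that reuses load_data with labels=None and maps predicted indices back through train_labels_header; the column names and output filename are assumptions:

# Illustrative sketch: predict the test set and write a submission file
# (filename, column names, and device handling are assumptions)
device = d2l.try_gpu()
net.to(device)
net.eval()
test_iter = load_data(test_images, None, 512, train=False)
preds = []
with torch.no_grad():
    for X in test_iter:
        y_hat = net(X.to(device)).argmax(dim=1).cpu().numpy()
        preds.extend(train_labels_header[y_hat])
pd.DataFrame({'image': test_images, 'label': preds}).to_csv(
    'submission.csv', index=False)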
CIFAR-10

1. Dataset
2. To be continued