哪里有免费的网站推广软件,爱洛阳网,科技感网站设计,做视频网站视频短片使用torchvision集成的efficientnet-v2-s模型#xff0c;调用torchvision库中的Oxford IIIT Pet数据集#xff0c;对模型进行训练。 若有修改要求#xff0c;可以修改以下部分#xff1a;
train_dataset OxfordIIITPet(root./data, splittrainval, downloadTrue, transfo…使用torchvision集成的efficientnet-v2-s模型调用torchvision库中的Oxford IIIT Pet数据集对模型进行训练。 若有修改要求可以修改以下部分
train_dataset OxfordIIITPet(root./data, splittrainval, downloadTrue, transformtransform_train)
test_dataset OxfordIIITPet(root./data, splittest, downloadTrue, transformtransform_test)
#常见数据集可以直接加载若是自己的数据集就自己写个dataset/dataloadermodel.classifier[1] nn.Linear(model.classifier[1].in_features, 37)
#37为数据集类别数修改为自己对应的scheduler ReduceLROnPlateau(optimizer, max, patience3, factor0.1, verboseTrue)
#学习率处可以自己调整可玩性较高训练截图
其实十轮左右就稳定在90以上了跑了三十轮记得修改保存路径我这里是用kaggle跑的。 代码如下
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import OxfordIIITPet
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm# 数据预处理 数据增强
transform_train transforms.Compose([transforms.Resize((256, 256)), # 增大图片预处理尺寸transforms.RandomCrop((224, 224)), # 随机裁剪到模型输入尺寸transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), # 颜色抖动transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])
])
transform_test transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset OxfordIIITPet(root./data, splittrainval, downloadTrue, transformtransform_train)
test_dataset OxfordIIITPet(root./data, splittest, downloadTrue, transformtransform_test)train_loader DataLoader(train_dataset, batch_size32, shuffleTrue)
test_loader DataLoader(test_dataset, batch_size32, shuffleFalse)# 模型定义
model efficientnet_v2_s(pretrainedTrue)
model.classifier[1] nn.Linear(model.classifier[1].in_features, 37)# 设置设备
device torch.device(cuda if torch.cuda.is_available() else cpu)
model.to(device)# 定义损失函数和优化器
criterion nn.CrossEntropyLoss()
optimizer optim.SGD(model.parameters(), lr0.01, momentum0.9, weight_decay0.001)
scheduler ReduceLROnPlateau(optimizer, max, patience3, factor0.1, verboseTrue)# 训练模型
def train_model(num_epochs):model.train()best_accuracy 0for epoch in range(num_epochs):model.train()running_loss 0.0for inputs, labels in tqdm(train_loader, descfTraining Epoch {epoch 1}):inputs, labels inputs.to(device), labels.to(device)optimizer.zero_grad()outputs model(inputs)loss criterion(outputs, labels)loss.backward()optimizer.step()running_loss loss.item()# 每个 epoch 后测试accuracy test_model()scheduler.step(accuracy)# 如果当前模型表现更好保存模型if accuracy best_accuracy:best_accuracy accuracytorch.save(model.state_dict(), /kaggle/working/best_oxford_pets_efficientnetv2.pth)print(fNew best model saved with accuracy: {best_accuracy:.2f}%)def test_model():model.eval()correct 0total 0with torch.no_grad():for inputs, labels in tqdm(test_loader, descTesting):inputs, labels inputs.to(device), labels.to(device)outputs model(inputs)_, predicted torch.max(outputs.data, 1)total labels.size(0)correct (predicted labels).sum().item()# 调试输出if total 50: # 只打印前50个样本的信息print(fPredicted: {predicted[:10]}, Labels: {labels[:10]})accuracy 100 * correct / totalprint(fTesting Accuracy: {accuracy:.2f}%)return accuracytrain_model(30)