
Swin Transformer in practice: timm usage, Mixup, Cutout, and scoring, all in one place

Abstract

This example takes a subset of the Plant Seedlings dataset, which has 12 classes, and demonstrates how to use the timm implementation of the Swin Transformer image classification model for a classification task, as well as how to compute score statistics on the validation set. Multi-GPU parallel training is implemented too. From this article you will learn:

1. How to call models, losses, and Mixup from timm.
2. How to build an ImageNet-style dataset.
3. How to use the Cutout data augmentation.
4. How to use the Mixup data augmentation.
5. How to train and validate on multiple GPUs.
6. How to adjust the learning rate with cosine annealing.
7. How to evaluate the model with classification_report.
8. Two ways of writing the prediction step.

Swin Transformer overview

Object detection pushed to 58.7 AP! Instance segmentation pushed to 51.1 Mask AP! Semantic segmentation reaches 53.5 mIoU on ADE20K! This year, Microsoft Research Asia's Swin Transformer has again gone into CNN-beating mode, with large gains in both speed and accuracy. This article walks you through Swin Transformer image classification.

Resource roundup

Some Bilibili videos by well-known authors, and the ClimbingVision community (links elided in the original). There is plenty of material on Swin Transformer, so I won't list it all here; in my view the best way to understand this model is the source code plus the paper.

Data augmentation: Cutout and Mixup

To improve the score I added the two augmentations Cutout and Mixup to the code. They require torchtoolbox, installed with:

pip install torchtoolbox

Cutout is implemented as a transform:

```python
from torchtoolbox.transform import Cutout

# data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
```

Mixup needs the import `from timm.data.mixup import Mixup`; define the Mixup function and the SoftTargetCrossEntropy loss:

```python
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)

criterion_train = SoftTargetCrossEntropy()
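```

To see why training will use SoftTargetCrossEntropy rather than the ordinary CrossEntropyLoss, it helps to look at what Mixup actually returns. Below is a minimal sketch on a made-up random batch; prob is raised to 1.0 here (the article itself uses 0.1) purely so the demo always mixes:

```python
import torch
from timm.data.mixup import Mixup

# toy batch: 8 random "images" with integer labels (hypothetical data)
x = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 12, (8,))

demo_mixup = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
                   prob=1.0, switch_prob=0.5, mode='batch',
                   label_smoothing=0.1, num_classes=12)
mixed_x, mixed_y = demo_mixup(x, y)

print(mixed_x.shape)  # torch.Size([8, 3, 224, 224]): images blended in pairs
print(mixed_y.shape)  # torch.Size([8, 12]): labels became soft one-hot vectors
print(mixed_y[0])     # probability mass split between two classes, plus smoothing
```

Because the targets come back as soft probability vectors instead of class indices, the training loss must accept soft targets, which is exactly what SoftTargetCrossEntropy does.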

Project structure

```
Swin_demo
├─data
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─mean_std.py
├─makedata.py
├─train.py
└─test.py
```

mean_std.py computes the mean and std values; makedata.py generates the dataset.

Computing mean and std

To make the model converge faster, we need to compute the mean and std values. Create mean_std.py and insert this code:

```python
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms


def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))
```

Running result:

```
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
```

Write this result down, it is needed later!

Generating the dataset

The image classification dataset we have organized has this structure:

```
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
```

The default loading convention in pytorch and keras is the ImageNet dataset format:

```
data
├─val
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
└─train
   ├─Black-grass
   ├─Charlock
   ├─Cleavers
   ├─Common Chickweed
   ├─Common wheat
   ├─Fat Hen
   ├─Loose Silky-bent
   ├─Maize
   ├─Scentless Mayweed
   ├─Shepherds Purse
   ├─Small-flowered Cranesbill
   └─Sugar beet
```

Create the format conversion script makedata.py and insert this code:

```python
import glob
import os
import shutil

image_list = glob.glob('data1/*/*.png')
print(image_list)
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    # os.rmdir(file_dir)
    shutil.rmtree(file_dir)  # delete the folder, then recreate it
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split

trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)
for file in trainval_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
```

Training

With the steps above done, we can start writing the training script. Create train.py.

Importing the libraries used by the project

```python
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.models.swin_transformer import swin_small_patch4_window7_224
from torchtoolbox.transform import Cutout
```

Setting global parameters

Set the learning rate, BatchSize, epochs and other parameters, and check whether a GPU is available, falling back to the CPU if not. A GPU is strongly recommended; the CPU is far too slow.

```python
# set global parameters
model_lr = 1e-4
BATCH_SIZE = 4
EPOCHS = 1000
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
```

Image preprocessing and augmentation

The data handling is fairly simple: Cutout is added, Resize and normalization are applied, and the Mixup function is defined.

```python
# data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761],
                         std=[0.24228974, 0.24347611, 0.2530049])
])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761],
                         std=[0.24228974, 0.24347611, 0.2530049])
])
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
```

Reading the data

Use pytorch's default way of reading data, then print dataset_train.class_to_idx; it is needed at prediction time.

```python
# read the data
dataset_train = datasets.ImageFolder('data/train', transform=transform)
dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
print(dataset_train.class_to_idx)
```

```python
# load the data
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
```

The class_to_idx result:

```
{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
```

Setting up the model

Set the loss functions: SoftTargetCrossEntropy for training and nn.CrossEntropyLoss() for validation. Set the model to swin_small_patch4_window7_224, with pretrained set to True and num_classes set to 12. Check the number of available GPUs; if it is greater than 1, wrap the model with torch.nn.DataParallel to enable multi-GPU training. Set the optimizer to Adam, and choose cosine annealing as the learning-rate schedule.

```python
# instantiate the model and move it to the GPU
criterion_train = SoftTargetCrossEntropy()
criterion_val = torch.nn.CrossEntropyLoss()
model_ft = swin_small_patch4_window7_224(pretrained=True)
print(model_ft)
num_ftrs = model_ft.head.in_features
model_ft.head = nn.Linear(num_ftrs, 12)  # replace the classification head with a 12-class one
model_ft.to(DEVICE)
print(model_ft)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model_ft = torch.nn.DataParallel(model_ft)
print(model_ft)
```
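As a side note, timm's model entry points also accept num_classes directly, so the manual head replacement above can usually be collapsed into the constructor call. A hedged equivalent (behavior may differ across timm versions, so treat this as a sketch):

```python
# builds the model with a freshly initialized 12-class head in one step
model_ft = swin_small_patch4_window7_224(pretrained=True, num_classes=12)
```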

```python
# pick the simple, brute-force Adam optimizer with a lowered learning rate
optimizer = optim.Adam(model_ft.parameters(), lr=model_lr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
```
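To get a feel for what this schedule does, you can step a throwaway copy of it and watch the learning rate. The sketch below uses a dummy parameter only so an optimizer can be constructed:

```python
import torch
import torch.optim as optim

# dummy parameter so we can build an optimizer just to probe the schedule
dummy = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(dummy, lr=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=20, eta_min=1e-9)

for epoch in range(40):
    if epoch % 5 == 0:
        print('epoch {:2d}: lr={:.2e}'.format(epoch, opt.param_groups[0]['lr']))
    opt.step()    # call optimizer.step() first to avoid the scheduler-order warning
    sched.step()
# the lr falls from 1e-4 down to ~1e-9 over the first 20 epochs, then climbs
# back up: one half of a cosine wave per T_max epochs
```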

Defining the training and validation functions

Define the training function and the validation function; after each epoch completes, use classification_report to compute a detailed score breakdown.

```python
# the training loop
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        output = model(samples)  # feed the mixed samples, not the raw data
        loss = criterion_train(output, targets)
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


ACC = 0
```

```python
# the validation loop
def val(model, device, test_loader):
    global ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []
    with torch.no_grad():
        for data, target in test_loader:
            for t in target:
                val_list.append(t.data.item())
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion_val(output, target)
            _, pred = torch.max(output.data, 1)
            for p in pred:
                pred_list.append(p.data.item())
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        if acc > ACC:
            # `epoch` is read from the global training loop below
            torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
            ACC = acc
    return val_list, pred_list
```
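The val_list/pred_list pair returned above is exactly what sklearn's classification_report consumes. A toy sketch with made-up labels shows the kind of table it will print after every epoch:

```python
from sklearn.metrics import classification_report

# hypothetical ground truth and predictions for a 3-class toy problem
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(classification_report(y_true, y_pred, target_names=['cat', 'dog', 'fox']))
# prints per-class precision / recall / f1-score / support, plus accuracy
# and macro / weighted averages
```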

```python
# training
for epoch in range(1, EPOCHS + 1):
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    cosine_schedule.step()
    val_list, pred_list = val(model_ft, DEVICE, test_loader)
    print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))
```

Running results: (screenshot omitted)

Testing

I will introduce two common ways of testing. The first is the general-purpose one: load the data yourself and run the prediction by hand. The test set is stored in a directory like the one in the figure (screenshot omitted).

Step 1: define the classes. Their order must match the class order used in training; do not change the order.
Step 2: define the transforms; use the same transforms as the validation set, without any data augmentation.
Step 3: load the model and move it to DEVICE.
Step 4: read the images and predict their classes. Note that images are read with PIL's Image, not cv2, since the transforms do not support cv2.

```python
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common wheat',
           'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed', 'Shepherds Purse',
           'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761],
                         std=[0.24228974, 0.24347611, 0.2530049])
])

DEVICE = ("cuda:0" if _available() else "cpu")model = ("")()(DEVICE)

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))
```

Running results: (screenshot omitted)

The second way: read the images with a custom Dataset. SeedlingData is the author's own Dataset class; a minimal sketch of it is given at the end of this post.

```python
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData  # module path reconstructed; adjust to where your Dataset lives
from torch.autograd import Variable
```

```python
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common wheat',
           'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed', 'Shepherds Purse',
           'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761],
                         std=[0.24228974, 0.24347611, 0.2530049])
])

DEVICE = ("cuda:0" if _available() else "cpu")model = ("")()(DEVICE)

dataset_test = SeedlingData('data/test/', transform_test, test=True)
print(len(dataset_test))

# the label corresponding to each folder
for index in range(len(dataset_test)):
    item = dataset_test[index]
    img, label = item
    img.unsqueeze_(0)
    data = Variable(img).to(DEVICE)
    output = model(data)
    _, pred = torch.max(output.data, 1)
    print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))
```

Running results: (screenshot omitted)
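The second script imports a custom SeedlingData class that this post does not reproduce. For reference, here is a minimal sketch of what such a Dataset could look like; the module path, the test-flag semantics, and the folder layout are my assumptions, not necessarily the author's exact code:

```python
import os
from PIL import Image
from torch.utils.data import Dataset


class SeedlingData(Dataset):
    """Minimal stand-in: with test=True, every file under `root` is one unlabeled sample."""

    def __init__(self, root, transforms=None, test=False):
        self.test = test
        self.transforms = transforms
        if test:
            # flat folder of test images; labels are dummy placeholders
            self.imgs = [os.path.join(root, f) for f in os.listdir(root)]
        else:
            # training layout: root/<class_name>/<image>.png
            class_names = sorted(os.listdir(root))
            self.class_to_idx = {c: i for i, c in enumerate(class_names)}
            self.imgs = [(os.path.join(root, c, f), self.class_to_idx[c])
                         for c in class_names
                         for f in os.listdir(os.path.join(root, c))]

    def __getitem__(self, index):
        if self.test:
            path, label = self.imgs[index], -1
        else:
            path, label = self.imgs[index]
        img = Image.open(path).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        return len(self.imgs)
```

The prediction loop above prints dataset_test.imgs[index], which is why this sketch stores the file paths in an imgs attribute.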
