Step 1: Prepare the data
Five classes of Chinese herbal medicine data: self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]. There are 900 images in total, and each class is stored in its own folder.
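To make the one-folder-per-class convention concrete, here is a minimal sketch of turning that layout into a label mapping (the data_root path and the class_indices.json file name are illustrative; this is not the project's own read_split_data helper):

import os
import json

data_root = r"G:\demo\data\ChineseMedicine"  # one sub-folder per herb class
classes = sorted(d for d in os.listdir(data_root)
                 if os.path.isdir(os.path.join(data_root, d)))
class_indices = {name: idx for idx, name in enumerate(classes)}

# persist the mapping so the prediction/GUI script can reuse it
with open("class_indices.json", "w", encoding="utf-8") as f:
    json.dump(class_indices, f, ensure_ascii=False, indent=2)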
Step 2: Build the model
This article uses an EfficientNetV2 network, whose key ideas are as follows. The network is designed with a combination of training-aware neural architecture search and model scaling. Building on EfficientNetV1, it introduces the Fused-MBConv block into the search space, and adds a progressive learning strategy with an adaptive regularization-strength adjustment mechanism, which makes training faster; it also pays further attention to the model's inference speed as well as its training speed.
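As a rough illustration of the progressive learning idea, image size and regularization strength can be raised together as training advances; the linear schedule below is a hypothetical sketch, not the paper's exact recipe:

# Hypothetical progressive-learning schedule: small images with weak
# regularization early on, large images with strong regularization later.
def progressive_schedule(epoch, total_epochs=100,
                         min_size=128, max_size=300,
                         min_dropout=0.1, max_dropout=0.3):
    """Linearly interpolate image size and dropout with training progress."""
    t = epoch / max(total_epochs - 1, 1)
    img_size = int(min_size + t * (max_size - min_size))
    dropout = min_dropout + t * (max_dropout - min_dropout)
    return img_size, dropout

for epoch in (0, 49, 99):
    print(epoch, progressive_schedule(epoch))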
Compared with EfficientNetV1, the main differences are:
- Besides the MBConv block, V2 also uses the Fused-MBConv block (a simplified sketch follows this list).
- V2 uses smaller expansion ratios (in V1 they are almost all 6), which reduces memory-access overhead.
- V2 prefers the smaller 3x3 kernel_size (V1 uses many 5x5 kernels). Because the receptive field of a 3x3 kernel is smaller than that of a 5x5 one, more layers have to be stacked to enlarge the receptive field.
- The last stride-1 stage of V1 is removed.
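For reference, a simplified Fused-MBConv block might look like the sketch below (hypothetical code: it omits the SE module, stochastic depth, and the expand_ratio=1 special case of the full implementation):

import torch
import torch.nn as nn

class FusedMBConv(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, expand_ratio=4):
        super().__init__()
        mid_ch = in_ch * expand_ratio
        # a single 3x3 conv replaces the 1x1 expansion + 3x3 depthwise pair of MBConv
        self.fused = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.SiLU(inplace=True),
        )
        # 1x1 projection back to out_ch, with no activation (linear bottleneck)
        self.project = nn.Sequential(
            nn.Conv2d(mid_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        # residual shortcut only when input and output shapes match
        self.use_res = (stride == 1 and in_ch == out_ch)

    def forward(self, x):
        out = self.project(self.fused(x))
        return x + out if self.use_res else out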
Step 3: Training code
(1) The loss function is the cross-entropy loss; a minimal usage sketch is shown below.
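As a quick shape check on dummy tensors (purely illustrative), PyTorch's nn.CrossEntropyLoss consumes raw logits and integer class indices:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(4, 5)            # batch of 4, raw scores for 5 classes
targets = torch.tensor([0, 2, 4, 1])  # ground-truth class indices
loss = loss_fn(logits, targets)       # softmax + negative log-likelihood in one op
print(loss.item())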
(2) The training code:
import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pre-trained weights if they exist
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            # keep only entries whose shapes match the current model (the head differs)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the classification head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # root directory of the dataset
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\ChineseMedicine")

    # download model weights
    # link: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  password: 5gu1
    parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

Step 4: Compute the accuracy
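The evaluate helper in the training loop already reports the overall validation accuracy each epoch; to break the accuracy down per herb class, a helper along the following lines could be used (a hypothetical sketch, not the project's own code):

import torch

@torch.no_grad()
def per_class_accuracy(model, data_loader, device, num_classes=5):
    """Return a tensor of per-class accuracies over the given loader."""
    model.eval()
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    for images, labels in data_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for c in range(num_classes):
            mask = labels == c
            total[c] += mask.sum()
            correct[c] += (preds[mask] == c).sum()
    return correct / total.clamp(min=1)  # clamp avoids division by zero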
Step 5: Build the GUI (a minimal illustrative sketch appears at the end of this post)

Step 6: Contents of the whole project
The project includes the training code, the trained model, the training process, the data, and the GUI code. Code download link (opens in a new window): source code of the PyTorch-based deep-learning EfficientNetV2 Chinese herbal medicine recognition and classification system. If you have any questions, send a private message or leave a comment; every question will be answered.
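Finally, as a rough illustration of what the Step 5 GUI involves, here is a minimal hypothetical Tkinter sketch (the GUI code shipped with the project is more complete; the weights path and window layout here are placeholders):

import tkinter as tk
from tkinter import filedialog

import torch
from PIL import Image
from torchvision import transforms

from model import efficientnetv2_s as create_model

class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]

# same normalization as the validation transform in the training code
tf = transforms.Compose([transforms.Resize(384),
                         transforms.CenterCrop(384),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

model = create_model(num_classes=len(class_indict))
# placeholder checkpoint name; use whichever epoch's weights you saved
model.load_state_dict(torch.load("./weights/model-99.pth", map_location="cpu"))
model.eval()

def predict():
    path = filedialog.askopenfilename()
    if not path:
        return
    img = tf(Image.open(path).convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        probs = torch.softmax(model(img), dim=1)[0]
    idx = int(probs.argmax())
    result_var.set("{} ({:.1%})".format(class_indict[idx], probs[idx].item()))

root = tk.Tk()
root.title("Chinese herbal medicine classifier")
result_var = tk.StringVar(value="choose an image...")
tk.Button(root, text="Open image", command=predict).pack(padx=20, pady=10)
tk.Label(root, textvariable=result_var).pack(padx=20, pady=10)
root.mainloop()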