如何查看网站的建设者,海口网站建设的开发方案,网站监测浏览器类型,网页美工设计说明YOLOV8改进#xff1a;如何增加注意力模块#xff1f;#xff08;以CBAM模块为例#xff09;前言YOLOV8nn文件夹modules.pytask.pymodels文件夹总结前言
因为毕设用到了YOLO#xff0c;鉴于最近V8刚出#xff0c;因此考虑将注意力机制加入到v8中。
YOLOV8
代码地址如何增加注意力模块以CBAM模块为例前言YOLOV8nn文件夹modules.pytask.pymodels文件夹总结前言
因为毕设用到了YOLO鉴于最近V8刚出因此考虑将注意力机制加入到v8中。
YOLOV8
代码地址YOLOV8官方代码
使用pip安装或者clone到本地在此不多赘述了。下面以使用pip安装ultralytics包为例介绍。 进入ultralytics文件夹
nn文件夹
再进入nn文件夹。
-- modules.py在里面存放着各种常用的模块如ConvDWConvConvTransposeTransformerLayerBottleneck等
-- tasks.py: 在里面导入了modules中的基本模块组建model根据不同的下游任务组建不同的model。modules.py
在该文件中我们可以写入自己的注意力模块或者使用V8已经提供的CBAM模块见代码的CBAM类 通道注意力模型: 通道维度不变压缩空间维度。该模块关注输入图片中有意义的信息。
1假设输入的数据大小是(b,c,w,h)
2通过自适应平均池化使得输出的大小变为(b,c,1,1)
3通过2d卷积和sigmod激活函数后大小是(b,c,1,1)
4将上一步输出的结果和输入的数据相乘输出数据大小是(b,c,w,h)。class ChannelAttention(nn.Module):# Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdetdef __init__(self, channels: int) - None:super().__init__()self.pool nn.AdaptiveAvgPool2d(1)self.fc nn.Conv2d(channels, channels, 1, 1, 0, biasTrue)self.act nn.Sigmoid()def forward(self, x: torch.Tensor) - torch.Tensor:return x * self.act(self.fc(self.pool(x)))
空间注意力模块空间维度不变压缩通道维度。该模块关注的是目标的位置信息。
1 假设输入的数据x是(b,c,w,h)并进行两路处理。
2其中一路在通道维度上进行求平均值得到的大小是(b,1,w,h)另外一路也在通道维度上进行求最大值得到的大小是(b,1,w,h)。
3 然后对上述步骤的两路输出进行连接输出的大小是(b,2,w,h)
4经过一个二维卷积网络把输出通道变为1输出大小是(b,1,w,h)
4将上一步输出的结果和输入的数据x相乘最终输出数据大小是(b,c,w,h)。class SpatialAttention(nn.Module):# Spatial-attention moduledef __init__(self, kernel_size7):super().__init__()assert kernel_size in (3, 7), kernel size must be 3 or 7padding 3 if kernel_size 7 else 1self.cv1 nn.Conv2d(2, 1, kernel_size, paddingpadding, biasFalse)self.act nn.Sigmoid()def forward(self, x):return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdimTrue), torch.max(x, 1, keepdimTrue)[0]], 1)))class CBAM(nn.Module):# Convolutional Block Attention Moduledef __init__(self, c1, kernel_size7): # ch_in, kernelssuper().__init__()self.channel_attention ChannelAttention(c1)self.spatial_attention SpatialAttention(kernel_size)def forward(self, x):return self.spatial_attention(self.channel_attention(x))如果使用V8的CBAM模块则不需要更改modules.py的内容。如果使用自己的注意力模块只需要在该文件后面添加对应的代码即可。
task.py
在该文件中通过import modules.py文件中的模块来构建模型。 在文件开头导入需要的模块可以看到modules中的很多模块在v8中并没有用到。我们在最后添加对应的CBAM模块。
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,GhostBottleneck, GhostConv, Segment, CBAM)之后修改对应的parse_model方法对应428行 添加分支elif m is CBAM:具体代码如下
def parse_model(d, ch, verboseTrue): # model_dict, input_channels(3)# Parse a YOLO model.yaml dictionaryif verbose:LOGGER.info(f\n{:3}{from:20}{n:3}{params:10} {module:45}{arguments:30})nc, gd, gw, act d[nc], d[depth_multiple], d[width_multiple], d.get(activation)if act:Conv.default_act eval(act) # redefine default activation, i.e. Conv.default_act nn.SiLU()if verbose:LOGGER.info(f{colorstr(activation:)} {act}) # printch [ch]layers, save, c2 [], [], ch[-1] # layers, savelist, ch outfor i, (f, n, m, args) in enumerate(d[backbone] d[head]): # from, number, module, argsm eval(m) if isinstance(m, str) else m # eval stringsfor j, a in enumerate(args):# TODO: re-implement with eval() removal if possible# args[j] (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else awith contextlib.suppress(NameError):args[j] eval(a) if isinstance(a, str) else a # eval stringsn n_ max(round(n * gd), 1) if n 1 else n # depth gainif m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):c1, c2 ch[f], args[0]if c2 ! nc: # if c2 not equal to number of classes (i.e. for Classify() output)c2 make_divisible(c2 * gw, 8)args [c1, c2, *args[1:]]if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x):args.insert(2, n) # number of repeatsn 1elif m is nn.BatchNorm2d:args [ch[f]]elif m is Concat:c2 sum(ch[x] for x in f)elif m in (Detect, Segment):args.append([ch[x] for x in f])if m is Segment:args[2] make_divisible(args[2] * gw, 8)elif m is CBAM:ch[f]:上一层的args[0]:第0个参数c1:输入通道数c2:输出通道数c1, c2 ch[f], args[0]# print(ch[f]:,ch[f])# print(args[0]:,args[0])# print(args:,args)# print(c1:,c1)# print(c2:,c2)if c2 ! nc: # if c2 not equal to number of classes (i.e. for Classify() output)c2 make_divisible(c2 * gw, 8)args [c1,*args[1:]]else:c2 ch[f]m_ nn.Sequential(*(m(*args) for _ in range(n))) if n 1 else m(*args) # modulet str(m)[8:-2].replace(__main__., ) # module typem.np sum(x.numel() for x in m_.parameters()) # number paramsm_.i, m_.f, m_.type i, f, t # attach index, from index, typeif verbose:LOGGER.info(f{i:3}{str(f):20}{n_:3}{m.np:10.0f} {t:45}{str(args):30}) # printsave.extend(x % i for x in ([f] if isinstance(f, int) else f) if x ! -1) # append to savelistlayers.append(m_)if i 0:ch []ch.append(c2)return nn.Sequential(*layers), sorted(save)注意传入的参数为上一层输出要注意CBAM模块的参数和传入参数的对应。读者可以自行print比较。
models文件夹
返回上一级目录进入models文件夹。 可以看到该文件夹中还有v5、v3对应的模型配置文件所以也可以使用该包进行v5和v3的训练。 进入v8文件夹 打开对应的yolov8.yaml如下所示。该文件是V8对应的配置文件里面包括了类别数模型大小n,s,m,l,xbackbone和head。
# Ultralytics YOLO , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. modelyolov8n.yaml will call yolov8.yaml with scale n# [depth, width, max_channels]n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, nearest]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, nearest]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)我们复制一份以yolov8x为例并改名为myyolo.yaml
# Ultralytics YOLO , GPL-3.0 license# Parameters
nc: 80 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.25 # scales convolution channels# YOLOv8.0x backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 3, CBAM, [128,7]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [512, 3, 2]] # 7-P5/32- [-1, 3, C2f, [512, True]]- [-1, 1, SPPF, [512, 5]] # 9- [-1, 3, CBAM, [512,7]]# YOLOv8.0x head
head:- [-1, 1, nn.Upsample, [None, 2, nearest]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, nearest]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [512]] # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
我们在SPPF模块后添加一层CBAM模块参数为[512,7]7为SpatialAttention对应的卷积核大小值可为3或7其他会报错。 添加完后使用对应的yaml配置文件训练即可。
yolo taskdetect modetrain modelmyyolo.yaml datadatasets/data/MOT20Det/VOC2007/mot20.yaml batch32 epochs80 imgsz640 workers16 device\0,1,2,3\值得注意的是如果添加了多层CBAM模块可能会导致各个模块对应的层数改变因此需要同时修改head中各个layer from对应的层数。
初始YOLOV8X默认的层数如下
# 默认
# 0 -1 1 2320 ultralytics.nn.modules.Conv [3, 80, 3, 2]
# 1 -1 1 115520 ultralytics.nn.modules.Conv [80, 160, 3, 2]
# 2 -1 3 436800 ultralytics.nn.modules.C2f [160, 160, 3, True]
# 3 -1 1 461440 ultralytics.nn.modules.Conv [160, 320, 3, 2]
# 4 -1 6 3281920 ultralytics.nn.modules.C2f [320, 320, 6, True]
# 5 -1 1 1844480 ultralytics.nn.modules.Conv [320, 640, 3, 2]
# 6 -1 6 13117440 ultralytics.nn.modules.C2f [640, 640, 6, True]
# 7 -1 1 3687680 ultralytics.nn.modules.Conv [640, 640, 3, 2]
# 8 -1 3 6969600 ultralytics.nn.modules.C2f [640, 640, 3, True]
# 9 -1 1 1025920 ultralytics.nn.modules.SPPF [640, 640, 5]
# 10 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, nearest]
# 11 [-1, 6] 1 0 ultralytics.nn.modules.Concat [1]
# 12 -1 3 7379200 ultralytics.nn.modules.C2f [1280, 640, 3]
# 13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, nearest]
# 14 [-1, 4] 1 0 ultralytics.nn.modules.Concat [1]
# 15 -1 3 1948800 ultralytics.nn.modules.C2f [960, 320, 3]
# 16 -1 1 922240 ultralytics.nn.modules.Conv [320, 320, 3, 2]
# 17 [-1, 12] 1 0 ultralytics.nn.modules.Concat [1]
# 18 -1 3 7174400 ultralytics.nn.modules.C2f [960, 640, 3]
# 19 -1 1 3687680 ultralytics.nn.modules.Conv [640, 640, 3, 2]
# 20 [-1, 9] 1 0 ultralytics.nn.modules.Concat [1]
# 21 -1 3 7379200 ultralytics.nn.modules.C2f [1280, 640, 3]
# 22 [15, 18, 21] 1 8795008 ultralytics.nn.modules.Detect [80, [320, 640, 640]] 增加对应的模块后之后的层数的layer1因此需要适当更改不然会报concat维度不匹配的错误如下
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 32 for tensor number 1 in the list.总结
添加注意力模块只需要3步 1、在对应的modules.py中添加需要的模块 2、在task.py中引入modules.py中的模块并进行适当的参数匹配 3、修改对应的models文件夹中的yaml文件并注意层数问题。 之后就可以进行正常训练了 文章转载自: http://www.morning.rkjb.cn.gov.cn.rkjb.cn http://www.morning.gwxsk.cn.gov.cn.gwxsk.cn http://www.morning.rfyk.cn.gov.cn.rfyk.cn http://www.morning.pjbhk.cn.gov.cn.pjbhk.cn http://www.morning.gthc.cn.gov.cn.gthc.cn http://www.morning.smxyw.cn.gov.cn.smxyw.cn http://www.morning.gcjhh.cn.gov.cn.gcjhh.cn http://www.morning.bdqpl.cn.gov.cn.bdqpl.cn http://www.morning.wnxqf.cn.gov.cn.wnxqf.cn http://www.morning.zkqjz.cn.gov.cn.zkqjz.cn http://www.morning.cfrz.cn.gov.cn.cfrz.cn http://www.morning.qmbpy.cn.gov.cn.qmbpy.cn http://www.morning.jfxth.cn.gov.cn.jfxth.cn http://www.morning.mdtfh.cn.gov.cn.mdtfh.cn http://www.morning.nyqnk.cn.gov.cn.nyqnk.cn http://www.morning.c7617.cn.gov.cn.c7617.cn http://www.morning.thlr.cn.gov.cn.thlr.cn http://www.morning.thwcg.cn.gov.cn.thwcg.cn http://www.morning.ctsjq.cn.gov.cn.ctsjq.cn http://www.morning.ksgjn.cn.gov.cn.ksgjn.cn http://www.morning.tlfyb.cn.gov.cn.tlfyb.cn http://www.morning.jqhrk.cn.gov.cn.jqhrk.cn http://www.morning.trjp.cn.gov.cn.trjp.cn http://www.morning.ksjmt.cn.gov.cn.ksjmt.cn http://www.morning.wtlyr.cn.gov.cn.wtlyr.cn http://www.morning.qbrs.cn.gov.cn.qbrs.cn http://www.morning.kxwsn.cn.gov.cn.kxwsn.cn http://www.morning.zhffz.cn.gov.cn.zhffz.cn http://www.morning.qkwxp.cn.gov.cn.qkwxp.cn http://www.morning.pbtrx.cn.gov.cn.pbtrx.cn http://www.morning.haolipu.com.gov.cn.haolipu.com http://www.morning.crfyr.cn.gov.cn.crfyr.cn http://www.morning.kndst.cn.gov.cn.kndst.cn http://www.morning.qxlgt.cn.gov.cn.qxlgt.cn http://www.morning.sqqds.cn.gov.cn.sqqds.cn http://www.morning.sffkm.cn.gov.cn.sffkm.cn http://www.morning.jygsq.cn.gov.cn.jygsq.cn http://www.morning.wdply.cn.gov.cn.wdply.cn http://www.morning.qscsy.cn.gov.cn.qscsy.cn http://www.morning.blxlf.cn.gov.cn.blxlf.cn http://www.morning.sfgtp.cn.gov.cn.sfgtp.cn http://www.morning.nyqxy.cn.gov.cn.nyqxy.cn http://www.morning.ghssm.cn.gov.cn.ghssm.cn http://www.morning.ymfzd.cn.gov.cn.ymfzd.cn http://www.morning.wfmqc.cn.gov.cn.wfmqc.cn http://www.morning.lkfsk.cn.gov.cn.lkfsk.cn http://www.morning.kjfqf.cn.gov.cn.kjfqf.cn http://www.morning.pttrs.cn.gov.cn.pttrs.cn http://www.morning.ryfpx.cn.gov.cn.ryfpx.cn http://www.morning.jfnbh.cn.gov.cn.jfnbh.cn http://www.morning.dpmkn.cn.gov.cn.dpmkn.cn http://www.morning.bwrbm.cn.gov.cn.bwrbm.cn http://www.morning.qgmbx.cn.gov.cn.qgmbx.cn http://www.morning.nxwk.cn.gov.cn.nxwk.cn http://www.morning.qrzqd.cn.gov.cn.qrzqd.cn http://www.morning.ndtzy.cn.gov.cn.ndtzy.cn http://www.morning.cnwpb.cn.gov.cn.cnwpb.cn http://www.morning.skqfx.cn.gov.cn.skqfx.cn http://www.morning.mlckd.cn.gov.cn.mlckd.cn http://www.morning.qqnh.cn.gov.cn.qqnh.cn http://www.morning.thmlt.cn.gov.cn.thmlt.cn http://www.morning.kqblk.cn.gov.cn.kqblk.cn http://www.morning.wqmyh.cn.gov.cn.wqmyh.cn http://www.morning.tsrg.cn.gov.cn.tsrg.cn http://www.morning.gwhjy.cn.gov.cn.gwhjy.cn http://www.morning.dhmll.cn.gov.cn.dhmll.cn http://www.morning.pfggj.cn.gov.cn.pfggj.cn http://www.morning.zxqyd.cn.gov.cn.zxqyd.cn http://www.morning.stmkm.cn.gov.cn.stmkm.cn http://www.morning.yznsx.cn.gov.cn.yznsx.cn http://www.morning.tlrxt.cn.gov.cn.tlrxt.cn http://www.morning.smwlr.cn.gov.cn.smwlr.cn http://www.morning.nlkjq.cn.gov.cn.nlkjq.cn http://www.morning.psxwc.cn.gov.cn.psxwc.cn http://www.morning.hyhzt.cn.gov.cn.hyhzt.cn http://www.morning.ykwgl.cn.gov.cn.ykwgl.cn http://www.morning.qgbfx.cn.gov.cn.qgbfx.cn http://www.morning.rjmb.cn.gov.cn.rjmb.cn http://www.morning.rywr.cn.gov.cn.rywr.cn http://www.morning.pmghz.cn.gov.cn.pmghz.cn