西安企业建站机构那里有,甘肃省建设厅执业注册中心网站,北京 网站建设咨询顾问公司,关键词排名点击软件Cswin提出了上图中使用交叉形状局部attention#xff0c;为了解决VIT模型中局部自注意力感受野进一步增长受限的问题#xff0c;同时提出了局部增强位置编码模块#xff0c;超越了Swin等模型#xff0c;在多个任务上效果SOTA#xff08;当时的SOTA#xff0c;已经被SG Fo… Cswin提出了上图中使用交叉形状局部attention为了解决VIT模型中局部自注意力感受野进一步增长受限的问题同时提出了局部增强位置编码模块超越了Swin等模型在多个任务上效果SOTA当时的SOTA已经被SG Former超越感兴趣的可以看看SG Former。 论文地址https://arxiv.org/abs/2107.00652 代码地址https://github.com/microsoft/CSWin-Transformer 模型整体结构如上所示由token embeeding layer和4个stageblock所堆叠而成每个stage block后面都会接入一个conv层用来对featuremap进行下采样。和典型的R50设计类似每次下采样后会增加dim的数量一是为了提升感受野二是为了增加特征性。
研究动机
基于global attention的transformer效果虽然好但是计算复杂度与特征图大小平方(HW的情况)成正比。基于local attention的transformer的会限制每个token的感受野的交互减缓感受野的增长需要堆叠大量的block来实现全局自注意力。
解决办法
提出了Cross-Shaped Window self-attention机制对注意力头进行分组并行计算水平和竖直方向的self-attention可以在更小的计算量条件下获得更好的效果。提出了Locally-enhanced Positional Encoding(LePE), 可以更好的处理局部位置信息并且支持任意形状的输入。
1.1 Convolutional Token Embedding 用convolution来做embedding为了减少计算量本文直接采用了7x7的卷积核stride为4的卷积来直接对输入进行embedding之后再对最后一维进行layernorm。
self.stage1_conv_embed nn.Sequential(nn.Conv2d(in_chans, embed_dim, 7, 4, 2),Rearrange(b c h w - b (h w) c, himg_size // 4, wimg_size // 4),nn.LayerNorm(embed_dim)
)
1.2 Cross-Shaped Window Self-Attention 具体来讲假设原始的Feature Map为为了计算它在横向上的自注意力它首先被拆分成个横条的数据(实际代码先进行竖列处理)其中是横条的宽度。在这4个不同的Stage中取不同的值实验结果表明[1,2,7,7]这组值在速度和精度上取得了比较好的均衡。 对于每个条状特征使用Transformer可以得到它的特征最后将这个特征拼接到一起便得到了这个head的输入。假设它属于第个head那么横向自注意力的计算方式为 纵向自注意力V-Attention 和H-Attention的计算方式类似不同的是它是取的宽度为的竖条。
最终这个block的输出表示为 CSWin self-attention计算复杂度分析 对于高分辨率输入HW早期大于C后期小于C因此早期sw小后期大。即调整sw可以有效地扩大后期每个token的attention区域。为了使224×224输入的中间特征图大小可被sw整除默认将4个阶段的sw设置为1、2、7、7。
def img2windows(img, H_sp, W_sp):img: B C H WB, C, H, W img.shapeimg_reshape img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]img_perm img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) # [N*56*1 56 32] [N*56*1 56 32] / [N*14*1 56 64] [N*14*1 56 64] / [N*2*1 98 128] [N*2*1 98 128] / [N*1*1 49 512]return img_permdef windows2img(img_splits_hw, H_sp, W_sp, H, W):img_splits_hw: B H W CB int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))img img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) # [N*56*1 56 32]-[N 1 56 56 1 32] [N*56*1 56 32]-[N 56 1 1 56 32] / [N*14*1 56 64]-[N 1 14 28 2 64] [N*14*1 56 64]-[N 14 1 2 28 64] / [N*2*1 98 128]-[N 1 2 14 7 128] [N*2*1 98 128]-[N 2 1 7 14 128] / [N*1*1 49 512]-[N 1 1 7 7 512]img img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # [N 56 56 32] [N 28 28 64] [N 14 14 128] [N 7 7 512]return imgclass LePEAttention(nn.Module):def __init__(self, dim, resolution, idx, split_size7, dim_outNone, num_heads8, attn_drop0., proj_drop0.,qk_scaleNone):super().__init__()self.dim dimself.dim_out dim_out or dimself.resolution resolutionself.split_size split_sizeself.num_heads num_headshead_dim dim // num_heads# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weightsself.scale qk_scale or head_dim ** -0.5if idx -1:H_sp, W_sp self.resolution, self.resolutionelif idx 0:H_sp, W_sp self.resolution, self.split_sizeelif idx 1:W_sp, H_sp self.resolution, self.split_sizeelse:print(ERROR MODE, idx)exit(0)self.H_sp H_spself.W_sp W_spstride 1self.get_v nn.Conv2d(dim, dim, kernel_size3, stride1, padding1, groupsdim)self.attn_drop nn.Dropout(attn_drop)def im2cswin(self, x):B, N, C x.shapeH W int(np.sqrt(N))x x.transpose(-2, -1).contiguous().view(B, C, H, W) # [B, N, C] - [B, C, N] - [B, C, H, W]x img2windows(x, self.H_sp, self.W_sp) # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]x x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1,3).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]return xdef get_lepe(self, x, func):B, N, C x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]H W int(np.sqrt(N))x x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]H_sp, W_sp self.H_sp, self.W_spx x.view(B, C, H // H_sp, H_sp, W // W_sp,W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]x x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,W_sp) ### B, C, H, W # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]lepe func(x) ### B, C, H, W # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]lepe lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]x x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]return x, lepedef forward(self, qkv):x: B L Cq, k, v qkv[0], qkv[1], qkv[2] # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]### Img2WindowH W self.resolution # 56 28 14 7B, L, C q.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]assert L H * W, flatten img_tokens has wrong sizeq self.im2cswin(q) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]k self.im2cswin(k) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]v, lepe self.get_lepe(v, self.get_v)q q * self.scaleattn (q k.transpose(-2, -1)) # B head N C B head C N -- B head N Nattn nn.functional.softmax(attn, dim-1, dtypeattn.dtype)attn self.attn_drop(attn)x (attn v) lepex x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp,C) # B head N N B head N C # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]### Window2Imgx windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) # B H W Creturn x # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
代码部分其实和Swin类似如果理解了swin的分窗机制再加上head分组基本上就能很快理解论文中思想。
1.3 Locally-Enhanced Positional Encoding(LePE) 因为Transformer是输入顺序无关的因此需要向其中加入位置编码。上图左边为ViT模型的PE使用的绝对位置编码或者是条件位置编码只在embedding的时候与token一起进入transformer中间的是SwinCrossFormer等模型的PE使用相对位置编码偏差通过引入token图的权重来和attention一起计算灵活度更好相对APE效果更好。 本文所提出的LePE相比于RPE更加直接将位置信息施加到线性投影中同时注意到RPE以head方式引入偏差而LepE是per-channel bias这可能显示出更强大的潜力来充当位置嵌入。也就是直接将位置编码添加加到了Value向量上假设位置编码为它的添加方式是通过将位置编码和相乘完成的。然后通过一个short-cut将添加了位置编码的和通过自注意力加权的单位加到一起公式如下: 这里作者基于一个假设对于一个输入元素他附近的元素提供最重要的位置信息。所以对V做一个深度卷积加到softmax之后的结果上。公式为: 这样LePE可以友好地应用于将任意输入分辨率作为输入的下游任务。 def get_lepe(self, x, func):# func - self.get_v nn.Conv2d(dim, dim, kernel_size3, stride1, padding1,groupsdim)B, N, C x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]H W int(np.sqrt(N))x x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]H_sp, W_sp self.H_sp, self.W_spx x.view(B, C, H // H_sp, H_sp, W // W_sp,W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]x x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,W_sp) ### B, C, H, W # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]lepe func(x) ### B, C, H, W # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]lepe lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]x x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]return x, lepe
1.4 CSWin Transformer Block CSWin Transformer Block的结构如图所示它最显著的特点是添加了两个shortcut并使用LN对特征做归一化.
网络结构配置 其中为第 个Transformer block的输出或各stage的卷积层。 CSwin的block有两个部分一个是做LayerNorm和Cross-shaped window self-attention并接一个shortcut另一个则是做LayerNorm和MLP相比于Swin和Twins来说block的计算量大大的降低了(swin,twins则是有两个attention两个MLP堆叠一个block)。
class CSWinBlock(nn.Module):def __init__(self, dim, reso, num_heads,split_size7, mlp_ratio4., qkv_biasFalse, qk_scaleNone,drop0., attn_drop0., drop_path0.,act_layernn.GELU, norm_layernn.LayerNorm,last_stageFalse):super().__init__()self.dim dimself.num_heads num_headsself.patches_resolution resoself.split_size split_sizeself.mlp_ratio mlp_ratioself.qkv nn.Linear(dim, dim * 3, biasqkv_bias)self.norm1 norm_layer(dim)if self.patches_resolution split_size:last_stage Trueif last_stage:self.branch_num 1else:self.branch_num 2self.proj nn.Linear(dim, dim)self.proj_drop nn.Dropout(drop)if last_stage:self.attns nn.ModuleList([LePEAttention(dim, resolutionself.patches_resolution, idx -1,split_sizesplit_size, num_headsnum_heads, dim_outdim,qk_scaleqk_scale, attn_dropattn_drop, proj_dropdrop)for i in range(self.branch_num)])else:self.attns nn.ModuleList([LePEAttention(dim//2, resolutionself.patches_resolution, idx i,split_sizesplit_size, num_headsnum_heads//2, dim_outdim//2,qk_scaleqk_scale, attn_dropattn_drop, proj_dropdrop)for i in range(self.branch_num)])mlp_hidden_dim int(dim * mlp_ratio)self.drop_path DropPath(drop_path) if drop_path 0. else nn.Identity()self.mlp Mlp(in_featuresdim, hidden_featuresmlp_hidden_dim, out_featuresdim, act_layeract_layer, dropdrop)self.norm2 norm_layer(dim)def forward(self, x):x: B, H*W, CH W self.patches_resolution # 56B, L, C x.shape # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]assert L H * W, flatten img_tokens has wrong sizeimg self.norm1(x)qkv self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # [3 N 3136 64] [3 N 784 128] [3 N 196 256] [3 N 49 512]if self.branch_num 2:x1 self.attns[0](qkv[:,:,:,:C//2]) # qkv[3 N 3136 32]-x1[N 3136 32] qkv[3 N 784 128]-x1[N 784 64] qkv[3 N 196 256]-x1[N 196 128]x2 self.attns[1](qkv[:,:,:,C//2:]) # qkv[3 N 3136 32]-x2[N 3136 32] qkv[3 N 784 128]-x1[N 784 64] qkv[3 N 196 256]-x1[N 196 128]attened_x torch.cat([x1,x2], dim2)else:attened_x self.attns[0](qkv) # [3 N 49 512]-[N 49 512]attened_x self.proj(attened_x)x x self.drop_path(attened_x)x x self.drop_path(self.mlp(self.norm2(x)))return x # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
在相似网络参数和计算量的模型中cswin在分类任务和各类下游任务中都做到了SOTA 检测 分割