常州上位机培训

常州机器视觉培训

常州机器人培训

江苏和讯自动化设备有限公司欢迎您!
  • 和讯PLC,电工培训中心优势,名师团队一对一教学.
热门课程
联系方式
  • 常州和讯自动化培训中心
  • 常州市新北区太湖东路府琛大厦2号楼307-1室,307-2室(常州万达广场对面)
  • 电话:0519-85602926
  • 手机:15861139266 13401342299
当前位置:网站首页 > 技术资料 技术资料
CoTLayer与CoTNet-50的Pytorch实现,常州上位机培训,常州机器视觉培训,常州工业机器人培训
日期:2023-5-17 10:18:58人气:  标签:常州上位机培训 常州机器视觉培训 常州工业机器人培训

CoTLayer与CoTNet-50的Pytorch实现

CoTLayer部分

从上面的结构图中我们可以看到,CoT模块包含三个部分,我们需要先构建这三个基本模块,比较简单,用最基本的卷积操作可以搞定。


keys_embedding

Values_embedding

Attention_embedding

self.key_embed = nn.Sequential(

    # 通过K*K的卷积提取邻近上下文信息,视作输入X的静态上下文表达

    nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=1, stride=1, bias=False),

    nn.BatchNorm2d(dim),

    nn.ReLU()

)

self.value_embed = nn.Sequential(

    nn.Conv2d(dim, dim, kernel_size = 1, stride=1, bias=False),  # 1*1的卷积进行Value的编码

    nn.BatchNorm2d(dim)

)


factor = 4

self.attention_embed = nn.Sequential(  # 通过连续两个1*1的卷积计算注意力矩阵

    nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),  # 输入concat后的特征矩阵 Channel = 2*C

    nn.BatchNorm2d(2 * dim // factor),

    nn.ReLU(),

    nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1, stride=1)  # out: H * W * (K*K*C)

)

之后就是重写forward方法,这才是CoTLayer的关键。


首先得到Key和Value的编码


bs, c, h, w = x.shape

k1 = self.key_embed(x)  # shape:bs,c,h,w  提取静态上下文信息得到key

v = self.value_embed(x).view(bs, c, -1)  # shape:bs,c,h*w  得到value编码

使用torch.cat操作将key与输入x在channel纬度进行拼接,并得到注意力矩阵


y = torch.cat([k1, x], dim=1)  # shape:bs,2c,h,w  Key与Query在channel维度上进行拼接进行拼接

att = self.attention_embed(y)  # shape:bs,c*k*k,h,w  计算注意力矩阵

为了进行之后的静动态上下文信息的融合,需要把注意力矩阵进行reshape


att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)

att = att.mean(2, keepdim=False).view(bs, c, -1)  # shape:bs,c,h*w  求平均降低维度

k2 = F.softmax(att, dim=-1) * v  # 对每一个Channel进行softmax后

k2 = k2.view(bs, c, h, w)

最后return k1 + k2所谓此结构的输出


CoTNet替换ResNet-50

根据原文中的信息,我们只需要将Bottleneck中3*3的卷积替换为CoTLayer即可,但是具体实现起来还是会有一些问题,主要是涉及到图片的大小调整。



Figure4: CoTNet-50结构图

我们运行上面实现的CoTLayer会发现输入特征和输出特征的大小是不会改变的,但是ResNet-50会逐渐从224减小到1。为了解决此问题,我们需要加入额外的downsample操作。


if stride > 1:

    self.avd = nn.AvgPool2d(3, 2, padding=1)

else:

    self.avd = None

其他部分和ResNet-50结构相同,至此我们就实现了对CoTNet-50的复现,我将复现的代码放到了下面的链接中,可以自取之后替换自己项目中的ResNet-50查看效果。


本文网址:

相关信息:
版权所有 CopyRight 2006-2017 江苏和讯自动化设备有限公司 电话:0519-85602926 地址:常州市新北区太湖东路府琛大厦2号楼307-1室,307-2室
ICP备14016686号-2 技术支持:常州鹤翔网络
本站关键词:常州电工培训 常州电工证 常州变频器培训 常州触摸屏培训 网站地图 网站标签
在线与我们取得联系