【视觉 Transformer】超详细解读 CaiT 模型

CaiT

paper:https://arxiv.org/abs/2103.17239

浅谈 CaiT

Hi guy!我们又见面了,这次来解析一篇来自 FaceBook AI 的一篇视觉 Transformer 的相关工作 CaiT

Transformer 在视觉领域可谓风生水起,各大视觉相关榜单都被刷爆了,自从 Google 的 ViT: An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale 开始,很多研究者展开了基于 ViT 的改进工作,比较著名的是 DeiT、Swin Transformer、PVT 等视觉 Transformer 改进,以及一些新架构 MLP-Mixer、ResMLP、gMLP 等。毫无疑问,今年是视觉 Transformer 大爆炸的一年,这在 CVPR、ICCV 等视觉相关顶会可见一斑。

ViT 尽管取得了很高的精度,但这离不开大规模数据训练(JFT-300M),而且这样的数据集不开源,很多工作致力于解决 ViT 的数据问题和计算开销问题。数据问题是指 ViT 这样的视觉 Transformer 缺少一定的归纳偏置,需要更多的数据(相比 ConvNet 网络)来训练,否则很容易在小数据集上过拟合,参数量越大的模型越明显。计算开销问题是指 ViT 中的 MHSA 计算量与 Token 数平方相关(O(n2)),尽管提升 Token 数可以获得更好的表征从而得到更高的精度(这也是最简单直接的办法),但是其计算复杂度随着 Token 数增加呈二次方发展,这将会给模型带来庞大的计算量。

针对上述问题后续很多工作给出了改进,比如 DeiT 在蒸馏中引入 distillation token,LV-ViT 通过 token labeling 技巧辅助训练, 以及最近的 BEiT、MAE、MaskFeat 等自监督训练,这些都很好解决了视觉 Transformer 的数据问题,它们旨在让 ViT 仅在 ImageNet-1K 下就能获得具有竞争力的性能。而在计算量问题上,有诸如 Swin Transformer、CvT 等通过改进 MHSA(Multi-Head Self Attention)来降低计算量,也有 PVT/PVT v2 这样通过将 ViT 的直筒式结构改成金字塔结构以此来降低计算量。

相比上述工作,CaiT 则是思考如何加深网络

根据以前的经验,增加模型的深度可以使得网络学习更复杂的表征,比如 ResNet 从18 层到 152 层,随着层数的增加其精度逐渐提高

但是在 Transformer 中,当我们扩展架构时,模型变得越来越难训练,其中深度是不稳定的主要来源之一,例如 DeiT-S 在不调整超参数情况下不能正确收敛到 18 层以上,尽管结合一些调参技巧如线性调整 drop rate,DeiT-S 依然在36层达到饱和(实验均在 ImageNet 1K 下进行)

为了解决深度问题,CaiT 提出了两个改进,一个是 LayerScale,一个是 Class-Attention

LayerScale 在每个残差块的输出上添加一个可学习的对角矩阵,该矩阵被初始化为接近0。在每个残差块之后添加这个简单的层可以提高训练的动态性,使我们能够训练更深层次的大容量 Transformer,如下所示 $$ x_{l}^{'}=x_l+diag(lambda_1,…,lambda _d)times SA(Norm(x_l)) $$

$$ x_{l+1}=x_{l}^{'}+diag(lambda_{1}^{'},…,lambda_{d}^{'})times FFN(Norm(x_{l}^{'})) $$

Class-Attention 是一个类似于 Encode/Decode 的结构,和 Self-Attention 不同,Class-Attention 更注重从处理过的 patches token 中提取信息,相比 SA 主要是 Q(query)的自变量 z 变成 xclass,而 K(keys)、V(value)则保持不变,如下所示 $$ Q=W_q x_{class}+b_q $$

$$ K=W_kz+b_k $$

$$ V=W_vz+b_v $$

其中 $$ z=[x_{class},x_{patches}] $$

流程解析

我们先看一下论文给出的结构图,如下所示

左边是 ViT 网络结构,CLS(classes token)与 Patch Embedding 一起被送进网络,最后输出 CLS 做分类

右边则是 CaiT 网络结构,相对于左边的 ViT 结构而言,最直观的变化是 CLA 被放入网络更深的层。

下面我们详细过一遍 CaiT 的流程,CaiT 流程可以分为三个部分,如下图所示

首先是 patch embedding 操作,将输入划分为不同的 patches。Patch Embedding 是视觉 Transformer 常见的操作,这里不做过多的解释,相关代码如下所示

class PatchEmbed(nn.Layer):
""" 2D Image to Patch Embedding
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2D(in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else Identity()

def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC
x = self.norm(x)
return x

<< · Back Index ·>>

发表回复

相关推荐

【30名】2023年溧陽市教育系統面向社會公開招聘幼兒園備案制教師公告

為更好地選拔優秀人才,充實教師隊伍,優化人員結構,根據《中共溧陽市委溧陽市人民政府關於印發的通知》(溧委

· 5分钟前

【北交就业】计算机学院毕业就业单位及去向

北京交通大学研究生毕业能去哪?就业待遇如何?签约比例怎样?

· 15分钟前

當日事必須當日畢

人性本身是放縱、散漫的,表現就是對目標的堅持、時間的控制等做得不到位,事情不能按時完成。如果拖延已開始影響工作的質量...

· 18分钟前

《晃過上帝》:不忘初心,夢想必將照亮現實

《晃過上帝之重返街頭》是《晃過上帝》系列的完結篇,一次真正理想照亮現實的圓滿收官。這部系列電影說明瞭這樣一個道理,夢...

· 25分钟前

10位清华毕业的85后县委书记!

清华大学堪称中国的顶尖学府,能有幸走进清华园读书无疑都是学霸级人物,因此他们的毕业去向一直备受社会关注。值得注意的是 ...

· 31分钟前