paper:https://arxiv.org/abs/2103.17239
Hi guy!我们又见面了,这次来解析一篇来自 FaceBook AI 的一篇视觉 Transformer 的相关工作 CaiT
Transformer 在视觉领域可谓风生水起,各大视觉相关榜单都被刷爆了,自从 Google 的 ViT: An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale 开始,很多研究者展开了基于 ViT 的改进工作,比较著名的是 DeiT、Swin Transformer、PVT 等视觉 Transformer 改进,以及一些新架构 MLP-Mixer、ResMLP、gMLP 等。毫无疑问,今年是视觉 Transformer 大爆炸的一年,这在 CVPR、ICCV 等视觉相关顶会可见一斑。
ViT 尽管取得了很高的精度,但这离不开大规模数据训练(JFT-300M),而且这样的数据集不开源,很多工作致力于解决 ViT 的数据问题和计算开销问题。数据问题是指 ViT 这样的视觉 Transformer 缺少一定的归纳偏置,需要更多的数据(相比 ConvNet 网络)来训练,否则很容易在小数据集上过拟合,参数量越大的模型越明显。计算开销问题是指 ViT 中的 MHSA 计算量与 Token 数平方相关(O(n2)),尽管提升 Token 数可以获得更好的表征从而得到更高的精度(这也是最简单直接的办法),但是其计算复杂度随着 Token 数增加呈二次方发展,这将会给模型带来庞大的计算量。
针对上述问题后续很多工作给出了改进,比如 DeiT 在蒸馏中引入 distillation token,LV-ViT 通过 token labeling 技巧辅助训练, 以及最近的 BEiT、MAE、MaskFeat 等自监督训练,这些都很好解决了视觉 Transformer 的数据问题,它们旨在让 ViT 仅在 ImageNet-1K 下就能获得具有竞争力的性能。而在计算量问题上,有诸如 Swin Transformer、CvT 等通过改进 MHSA(Multi-Head Self Attention)来降低计算量,也有 PVT/PVT v2 这样通过将 ViT 的直筒式结构改成金字塔结构以此来降低计算量。
相比上述工作,CaiT 则是思考如何加深网络
根据以前的经验,增加模型的深度可以使得网络学习更复杂的表征,比如 ResNet 从18 层到 152 层,随着层数的增加其精度逐渐提高
但是在 Transformer 中,当我们扩展架构时,模型变得越来越难训练,其中深度是不稳定的主要来源之一,例如 DeiT-S 在不调整超参数情况下不能正确收敛到 18 层以上,尽管结合一些调参技巧如线性调整 drop rate,DeiT-S 依然在36层达到饱和(实验均在 ImageNet 1K 下进行)
为了解决深度问题,CaiT 提出了两个改进,一个是 LayerScale,一个是 Class-Attention
LayerScale 在每个残差块的输出上添加一个可学习的对角矩阵,该矩阵被初始化为接近0。在每个残差块之后添加这个简单的层可以提高训练的动态性,使我们能够训练更深层次的大容量 Transformer,如下所示 $$ x_{l}^{'}=x_l+diag(lambda_1,…,lambda _d)times SA(Norm(x_l)) $$
$$ x_{l+1}=x_{l}^{'}+diag(lambda_{1}^{'},…,lambda_{d}^{'})times FFN(Norm(x_{l}^{'})) $$
Class-Attention 是一个类似于 Encode/Decode 的结构,和 Self-Attention 不同,Class-Attention 更注重从处理过的 patches token 中提取信息,相比 SA 主要是 Q(query)的自变量 z 变成 xclass,而 K(keys)、V(value)则保持不变,如下所示 $$ Q=W_q x_{class}+b_q $$
$$ K=W_kz+b_k $$
$$ V=W_vz+b_v $$
其中 $$ z=[x_{class},x_{patches}] $$
我们先看一下论文给出的结构图,如下所示
左边是 ViT 网络结构,CLS(classes token)与 Patch Embedding 一起被送进网络,最后输出 CLS 做分类
右边则是 CaiT 网络结构,相对于左边的 ViT 结构而言,最直观的变化是 CLA 被放入网络更深的层。
下面我们详细过一遍 CaiT 的流程,CaiT 流程可以分为三个部分,如下图所示
首先是 patch embedding 操作,将输入划分为不同的 patches。Patch Embedding 是视觉 Transformer 常见的操作,这里不做过多的解释,相关代码如下所示
class PatchEmbed(nn.Layer):
""" 2D Image to Patch Embedding
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2D(in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC
x = self.norm(x)
return x
<< · Back Index ·>>