[Paper Notes] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Abstract
1 Introduction
2 Related Work
3 Method
3.1 Vision Transformer (ViT)

The ViT architecture is explained below alongside the code; see models_vit.py, line 211. The VisionTransformer class is as follows:
```python
class VisionTransformer(nn.Module):
    ...
```
Original text:
An overview of the model is depicted in Figure 1. The standard Transformer receives as input a 1D sequence of token embeddings. To handle 2D images, we reshape the image \(\mathbf{x} \in \mathbb{R}^{H \times W \times C}\) into a sequence of flattened 2D patches \(\mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\), where \((H, W)\) is the resolution of the original image, \(C\) is the number of channels, \((P, P)\) is the resolution of each image patch, and \(N = \frac{H}{P} \times \frac{W}{P} = \frac{HW}{P^2}\) is the resulting number of patches, which also serves as the effective input sequence length for the Transformer. The Transformer uses constant latent vector size \(D\) through all of its layers, so we flatten the patches and map to \(D\) dimensions with a trainable linear projection (Eq. 1). We refer to the output of this projection as the patch embeddings.
According to the paper, given an input image of shape \(\mathbf{x} \in \mathbb{R}^{H \times W \times C}\), ViT reshapes it into \(\mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\), where \((H, W)\) is the resolution of the original image, \(C\) is the number of channels, \((P, P)\) is the resolution of each patch, and \(N = \frac{H}{P} \times \frac{W}{P} = \frac{HW}{P^2}\) is the number of patches. This step involves no network operation: the total number of values \(H \times W \times C = N \times (P^2 \cdot C)\) is unchanged, so a plain reshape is enough.
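As a sanity check, here is a minimal JAX sketch of this patch extraction under assumed sizes (H = W = 224, P = 16, C = 3); the axis split and transpose only regroup pixels so that each row of \(\mathbf{x}_p\) is one flattened \(P \times P\) patch:

```python
import jax.numpy as jnp

# Illustrative sizes: H = W = 224, P = 16, C = 3, so N = 196 and P*P*C = 768.
H, W, C, P = 224, 224, 3, 16
N = (H // P) * (W // P)

x = jnp.ones((H, W, C))                      # one input image
# Split H and W into (H/P, P) and (W/P, P), bring the two patch-grid axes together,
# then flatten each P x P x C patch into one row of length P*P*C.
x_p = (x.reshape(H // P, P, W // P, P, C)
         .transpose(0, 2, 1, 3, 4)
         .reshape(N, P * P * C))
print(x_p.shape)                             # (196, 768); H*W*C == N*(P*P*C), nothing is lost
```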
The code that follows involves a batch dimension, so for consistency we add a batch_size dimension to the input, i.e. \(\mathbf{x} \in \mathbb{R}^{n \times H \times W \times C}\), where \(n\) is the batch_size.

After the reshape, \(\mathbf{x}_p\) has shape \(\mathbb{R}^{n \times N \times (P^2 \cdot C)}\). From the Transformer's point of view, the \(n\) images become \(n\) sentences, each containing \(N\) words, and each word is represented by a vector of length \(P^2 \cdot C\). However, ViT does not feed \(\mathbf{x}_p\) into the Transformer directly; it first applies a linear projection \(\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}\) that maps \(\mathbf{x}_p \in \mathbb{R}^{n \times N \times (P^2 \cdot C)}\) to \(\mathbb{R}^{n \times N \times D}\), so that each word is represented by a vector of length \(D\). In addition, since this paper uses ViT for supervised classification, a learnable classification token \(\mathbf{x}_\text{class}\) is prepended to every sentence. The final Transformer input is \(\mathbf{z}_0 = [\mathbf{x}_\text{class}; \mathbf{x}_p^1\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos} \in \mathbb{R}^{n \times (N + 1) \times D}\), where each \(\mathbf{x}_p^i\mathbf{E}\) has shape \(\mathbb{R}^{n \times 1 \times D}\) and \(\mathbf{E}_{pos} \in \mathbb{R}^{(N + 1) \times D}\) is the position embedding.
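Eq. 1 can be written out with plain JAX arrays; here \(\mathbf{E}\), \(\mathbf{E}_{pos}\) and \(\mathbf{x}_\text{class}\) are random or zero placeholders standing in for trained parameters, purely to check the shapes:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes: batch n = 2, N = 196 patches of 16x16x3 pixels, embedding dim D = 1024.
n, N, P, C, D = 2, 196, 16, 3, 1024
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

x_p     = jax.random.normal(k1, (n, N, P * P * C))   # flattened patches
E       = jax.random.normal(k2, (P * P * C, D))      # linear projection (placeholder for a trained parameter)
E_pos   = jax.random.normal(k3, (N + 1, D))          # position embeddings (placeholder)
x_class = jnp.zeros((n, 1, D))                       # classification token, one copy per image

z0 = jnp.concatenate([x_class, x_p @ E], axis=1) + E_pos   # E_pos broadcasts over the batch
print(z0.shape)                                      # (2, 197, 1024), i.e. (n, N + 1, D)
```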
In the implementation, however, ViT's "linear projection" is a convolutional layer, corresponding to lines 49-55 of the code above:
```python
class VisionTransformer(nn.Module):
    ...
```
The official ViT repository is implemented in JAX/Flax. In this convolution, `features` is the number of kernels, `kernel_size` is the kernel size, `strides` is the stride, `padding` is the padding mode, and `name` is the layer name. `self.hidden_size` is the \(D\) from above, and `self.patches.size` is the \(P\) from above. Because the kernel size equals the stride, passing the input image through this convolution maps every patch to a \(D\)-dimensional vector, with no overlap between patches, so the output has shape \(\mathbb{R}^{n \times \frac{H}{P} \times \frac{W}{P} \times D}\).
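A minimal Flax sketch of this patch-embedding convolution (module and argument values here are illustrative, not the repository's exact code); with the kernel size equal to the stride, each \(P \times P\) patch is mapped independently to a \(D\)-dimensional vector:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class PatchEmbedding(nn.Module):
    """The 'linear projection' as a PxP convolution with stride P (illustrative sketch)."""
    hidden_size: int   # D
    patch_size: int    # P

    @nn.compact
    def __call__(self, x):                                  # x: (n, H, W, C)
        return nn.Conv(features=self.hidden_size,
                       kernel_size=(self.patch_size, self.patch_size),
                       strides=(self.patch_size, self.patch_size),
                       padding='VALID',
                       name='embedding')(x)                 # (n, H/P, W/P, D)

mod = PatchEmbedding(hidden_size=768, patch_size=16)
x = jnp.ones((2, 224, 224, 3))
params = mod.init(jax.random.PRNGKey(0), x)
print(mod.apply(params, x).shape)                           # (2, 14, 14, 768)
```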
After the linear projection, ViT reshapes the sequence and prepends the classification token \(\mathbf{x}_\text{class}\), corresponding to lines 59-70 of the code above:
```python
class VisionTransformer(nn.Module):
    ...
```
In `n, h, w, c = x.shape`, `n` is the batch_size, `h` is the feature-map height \(\frac{H}{P}\), `w` is the feature-map width \(\frac{W}{P}\), and `c` is the number of feature-map channels \(D\). After the reshape, the batch_size is unchanged and the feature-map height \(\frac{H}{P}\) and width \(\frac{W}{P}\) are multiplied together to give \(N\), the number of patches. At this point the reshape and the linear projection are complete, and the resulting sequence has shape \(\mathbb{R}^{n \times N \times D}\).
Next, the classification token \(\mathbf{x}_\text{class}\) is prepended to the sequence. `cls = self.param('cls', nn.initializers.zeros, (1, 1, c))` creates a classification token of shape \(\mathbb{R}^{1 \times 1 \times D}\); `cls = jnp.tile(cls, [n, 1, 1])` repeats it \(n\) times along dimension 0, giving a token of shape \(\mathbb{R}^{n \times 1 \times D}\); finally, `x = jnp.concatenate([cls, x], axis=1)` concatenates the \(\mathbb{R}^{n \times 1 \times D}\) token with the \(\mathbb{R}^{n \times N \times D}\) sequence along dimension 1, producing the input sequence of shape \(\mathbb{R}^{n \times (N + 1) \times D}\).
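These two operations can be isolated into a small Flax module for clarity (the module name is made up; the parameter creation and array calls mirror the lines quoted above):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TokenizeWithCls(nn.Module):
    """Flatten the conv feature map into tokens and prepend a learnable class token (sketch)."""

    @nn.compact
    def __call__(self, x):                               # x: (n, H/P, W/P, D)
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])                # (n, N, D)
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])                   # (n, 1, D), one copy per image
        return jnp.concatenate([cls, x], axis=1)         # (n, N + 1, D)

mod = TokenizeWithCls()
x = jnp.ones((2, 14, 14, 768))
params = mod.init(jax.random.PRNGKey(0), x)
print(mod.apply(params, x).shape)                        # (2, 197, 768)
```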
At this point the position embedding has not yet been added; in this ViT implementation it is added when the `encoder` is called. The line `x = self.encoder(name='Transformer', **self.transformer)(x, train=train)` invokes `encoder`, which is an instance of the `Encoder` class, defined at line 159 of models_vit.py:
```python
class Encoder(nn.Module):
    ...
```
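For orientation, here is a minimal pre-norm Transformer encoder in the spirit of this class; it is a sketch rather than the repository's code, dropout is omitted, and the position embedding is inlined instead of being factored into `AddPositionEmbs`:

```python
import flax.linen as nn

class Encoder(nn.Module):
    """Sketch of a pre-norm Transformer encoder; not the repository's exact code."""
    num_layers: int
    num_heads: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x, train=False):                  # x: (n, N + 1, D); train kept only to mirror the call above
        # Learnable position embeddings (the repo wraps this step in AddPositionEmbs).
        pos = self.param('pos_embedding',
                         nn.initializers.normal(stddev=0.02),
                         (1, x.shape[1], x.shape[2]))
        x = x + pos
        for _ in range(self.num_layers):
            # Self-attention sub-block: pre-LayerNorm, then a residual connection.
            y = nn.LayerNorm()(x)
            y = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(y, y)
            x = x + y
            # MLP sub-block (Dense -> GELU -> Dense), again pre-norm plus residual.
            y = nn.LayerNorm()(x)
            y = nn.Dense(self.mlp_dim)(y)
            y = nn.gelu(y)
            y = nn.Dense(x.shape[-1])(y)
            x = x + y
        return nn.LayerNorm()(x)                         # shape unchanged: (n, N + 1, D)
```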
Line 33 of the code above also contains an `AddPositionEmbs` module; this is the step where the position embedding is added. The `AddPositionEmbs` class is defined at line 37 of models_vit.py:
```python
class AddPositionEmbs(nn.Module):
    ...
```
Before the sequence enters the Transformer encoder, the `AddPositionEmbs` class adds position embeddings to it. Note that ViT uses learnable position embeddings, not the fixed sinusoidal position encodings of the original Transformer. After the position embeddings are added, the \(\mathbb{R}^{n \times (N + 1) \times D}\) data is fed into the Transformer encoder, which produces the output sequence.
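A minimal sketch of such a learnable position embedding in Flax (the parameter name and initializer are chosen for illustration; the repository's exact code may differ):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class AddPositionEmbs(nn.Module):
    """Add a learnable position embedding to every token (illustrative sketch)."""

    @nn.compact
    def __call__(self, x):                               # x: (n, N + 1, D)
        pos_emb = self.param('pos_embedding',
                             nn.initializers.normal(stddev=0.02),
                             (1, x.shape[1], x.shape[2]))
        return x + pos_emb                               # broadcast over the batch dimension

mod = AddPositionEmbs()
x = jnp.ones((2, 197, 768))
params = mod.init(jax.random.PRNGKey(0), x)
print(mod.apply(params, x).shape)                        # (2, 197, 768)
```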
Finally, the classification task is carried out, corresponding to lines 72-93 of the code above:
```python
class VisionTransformer(nn.Module):
    ...
```
Since the Transformer encoder does not change the shape of the data it processes, the output still has shape \(\mathbb{R}^{n \times (N + 1) \times D}\). `x = x[:, 0]` extracts the classification token from each of the \(n\) sentences, giving an output of shape \(\mathbb{R}^{n \times D}\), which is finally fed into a `Dense` layer to obtain the classification vector.
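The final step, sketched as its own Flax module (num_classes is an assumed value; in the paper the pre-training head is an MLP with one hidden layer, while fine-tuning uses a single linear layer, and only the linear version is sketched here):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ClassificationHead(nn.Module):
    """Map the class-token output of the encoder to class logits (sketch)."""
    num_classes: int

    @nn.compact
    def __call__(self, x):                               # x: (n, N + 1, D), encoder output
        cls_token = x[:, 0]                              # (n, D): token 0 of every sequence
        return nn.Dense(self.num_classes, name='head')(cls_token)   # (n, num_classes)

mod = ClassificationHead(num_classes=1000)
x = jnp.ones((2, 197, 768))
params = mod.init(jax.random.PRNGKey(0), x)
print(mod.apply(params, x).shape)                        # (2, 1000)
```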
3.2 Fine-tuning And Higher Resolution
4 Experiments
4.1 Setup
Model | Layers | Hidden size \(D\) | MLP size | Heads | Params |
---|---|---|---|---|---|
ViT-Base | 12 | 768 | 3072 | 12 | 86M |
ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
 | Ours-JFT (ViT-H/14) | Ours-JFT (ViT-L/16) | Ours-i21k (ViT-L/16) | BiT-L (ResNet152x4) | Noisy Student (EfficientNet-L2) |
---|---|---|---|---|---|
ImageNet | 88.55 ±0.04 | 87.76 ±0.03 | 85.30 ±0.02 | 87.54 ±0.02 | 88.4/88.5∗ |
ImageNet ReaL | 90.72 ±0.05 | 90.54 ±0.03 | 88.62 ±0.05 | 90.54 | 90.55 |
CIFAR-10 | 99.50 ±0.06 | 99.42 ±0.03 | 99.15 ±0.03 | 99.37 ±0.06 | - |
CIFAR-100 | 94.55 ±0.04 | 93.90 ±0.05 | 93.25 ±0.05 | 93.51 ±0.08 | - |
Oxford-IIIT Pets | 97.56 ±0.03 | 97.32 ±0.11 | 94.67 ±0.15 | 96.62 ±0.23 | - |
Oxford Flowers-102 | 99.68 ±0.02 | 99.74 ±0.00 | 99.61 ±0.02 | 99.63 ±0.03 | - |
VTAB (19 tasks) | 77.63 ±0.23 | 76.28 ±0.46 | 72.72 ±0.21 | 76.29 ±1.70 | - |
TPUv3-core-days | 2.5k | 0.68k | 0.23k | 9.9k | 12.3k |
4.2 Comparison to State of the Art

4.3 Pre-training Data Requirements


Pre-trained on | Evaluated on | ViT-B/16 | ViT-B/32 | ViT-L/16 | ViT-L/32 | ViT-H/14 |
---|---|---|---|---|---|---|
ImageNet | CIFAR-10 | 98.13 | 97.77 | 97.86 | 97.94 | - |
 | CIFAR-100 | 87.13 | 86.31 | 86.35 | 87.07 | - |
 | ImageNet | 77.91 | 73.38 | 76.53 | 71.16 | - |
 | ImageNet ReaL | 83.57 | 79.56 | 82.19 | 77.83 | - |
 | Oxford Flowers-102 | 89.49 | 85.43 | 89.66 | 86.36 | - |
 | Oxford-IIIT-Pets | 93.81 | 92.04 | 93.64 | 91.35 | - |
ImageNet-21k | CIFAR-10 | 98.95 | 98.79 | 99.16 | 99.13 | 99.27 |
 | CIFAR-100 | 91.67 | 91.97 | 93.44 | 93.04 | 93.82 |
 | ImageNet | 83.97 | 81.28 | 85.15 | 80.99 | 85.13 |
 | ImageNet ReaL | 88.35 | 86.63 | 88.40 | 85.65 | 88.70 |
 | Oxford Flowers-102 | 99.38 | 99.11 | 99.61 | 99.19 | 99.51 |
 | Oxford-IIIT-Pets | 94.43 | 93.02 | 94.73 | 93.09 | 94.82 |
JFT-300M | CIFAR-10 | 99.00 | 98.61 | 99.38 | 99.19 | 99.50 |
 | CIFAR-100 | 91.87 | 90.49 | 94.04 | 92.52 | 94.55 |
 | ImageNet | 84.15 | 80.73 | 87.12 | 84.37 | 88.04 |
 | ImageNet ReaL | 88.85 | 86.27 | 89.99 | 88.28 | 90.33 |
 | Oxford Flowers-102 | 99.56 | 99.27 | 99.56 | 99.45 | 99.68 |
 | Oxford-IIIT-Pets | 95.80 | 93.40 | 97.11 | 95.83 | 97.56 |
4.4 Scaling Study

name | Epochs | ImageNet | ImageNet ReaL | CIFAR-10 | CIFAR-100 | Pets | Flowers | exaFLOPs |
---|---|---|---|---|---|---|---|---|
ViT-B/32 | 7 | 80.73 | 86.27 | 98.61 | 90.49 | 93.40 | 99.27 | 55 |
ViT-B/16 | 7 | 84.15 | 88.85 | 99.00 | 91.87 | 95.80 | 99.56 | 224 |
ViT-L/32 | 7 | 84.37 | 88.28 | 99.19 | 92.52 | 95.83 | 99.45 | 196 |
ViT-L/16 | 7 | 86.30 | 89.43 | 99.38 | 93.46 | 96.81 | 99.66 | 783 |
ViT-L/16 | 14 | 87.12 | 89.99 | 99.38 | 94.04 | 97.11 | 99.56 | 1567 |
ViT-H/14 | 14 | 88.08 | 90.36 | 99.50 | 94.71 | 97.11 | 99.71 | 4262 |
ResNet50x1 | 7 | 77.54 | 84.56 | 97.67 | 86.07 | 91.11 | 94.26 | 50 |
ResNet50x2 | 7 | 82.12 | 87.94 | 98.29 | 89.20 | 93.43 | 97.02 | 199 |
ResNet101x1 | 7 | 80.67 | 87.07 | 98.48 | 89.17 | 94.08 | 95.95 | 96 |
ResNet152x1 | 7 | 81.88 | 87.96 | 98.82 | 90.22 | 94.17 | 96.94 | 141 |
ResNet152x2 | 7 | 84.97 | 89.69 | 99.06 | 92.05 | 95.37 | 98.62 | 563 |
ResNet152x2 | 14 | 85.56 | 89.89 | 99.24 | 91.92 | 95.75 | 98.75 | 1126 |
ResNet200x3 | 14 | 87.22 | 90.15 | 99.34 | 93.53 | 96.32 | 99.04 | 3306 |
R50x1+ViT-B/32 | 7 | 84.90 | 89.15 | 99.01 | 92.24 | 95.75 | 99.46 | 106 |
R50x1+ViT-B/16 | 7 | 85.58 | 89.65 | 99.14 | 92.63 | 96.65 | 99.40 | 274 |
R50x1+ViT-L/32 | 7 | 85.68 | 89.04 | 99.24 | 92.93 | 96.97 | 99.43 | 246 |
R50x1+ViT-L/16 | 7 | 86.60 | 89.72 | 99.18 | 93.64 | 97.03 | 99.40 | 859 |
R50x1+ViT-L/16 | 14 | 87.12 | 89.76 | 99.31 | 93.89 | 97.36 | 99.11 | 1668 |
4.5 Inspecting Vision Transformer


4.6 Self-supervision
5 Conclusion
Appendix
A Multihead Self-attention
B Experiment Details
B.1 Training
Models | Dataset | Epochs | Base LR | LR decay | Weight decay | Dropout |
---|---|---|---|---|---|---|
ViT-B/{16,32} | JFT-300M | 7 | \(8 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
ViT-L/32 | JFT-300M | 7 | \(6 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
ViT-L/16 | JFT-300M | 7/14 | \(4 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
ViT-H/14 | JFT-300M | 14 | \(3 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
R50x{1,2} | JFT-300M | 7 | \(10^{-3}\) | linear | 0.1 | 0.0 |
R101x1 | JFT-300M | 7 | \(8 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
R152x{1,2} | JFT-300M | 7 | \(6 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
R50+ViT-B/{16,32} | JFT-300M | 7 | \(8 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
R50+ViT-L/32 | JFT-300M | 7 | \(2 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
R50+ViT-L/16 | JFT-300M | 7/14 | \(4 \cdot 10^{-4}\) | linear | 0.1 | 0.0 |
ViT-B/{16,32} | ImageNet-21k | 90 | \(10^{-3}\) | linear | 0.03 | 0.1 |
ViT-L/{16,32} | ImageNet-21k | 30/90 | \(10^{-3}\) | linear | 0.03 | 0.1 |
ViT-∗ | ImageNet | 300 | \(3 \cdot 10^{-3}\) | cosine | 0.3 | 0.1 |
B.1.1 Fine-tuning
Dataset | Steps | Base LR |
---|---|---|
ImageNet | 20000 | {0.003, 0.01, 0.03, 0.06} |
CIFAR100 | 10000 | {0.001, 0.003, 0.01, 0.03} |
CIFAR10 | 10000 | {0.001, 0.003, 0.01, 0.03} |
Oxford-IIIT Pets | 500 | {0.001, 0.003, 0.01, 0.03} |
Oxford Flowers-102 | 500 | {0.001, 0.003, 0.01, 0.03} |
VTAB (19 tasks) | 2500 | 0.01 |
B.1.2 Self-supervision
C Additional Results
D Additional Analyses
D.1 Sgd Vs. Adam For Resnets
Dataset | ResNet50 (Adam) | ResNet50 (SGD) | ResNet152x2 (Adam) | ResNet152x2 (SGD) |
---|---|---|---|---|
ImageNet | 77.54 | 78.24 | 84.97 | 84.37 |
CIFAR10 | 97.67 | 97.46 | 99.06 | 99.07 |
CIFAR100 | 86.07 | 85.17 | 92.05 | 91.06 |
Oxford-IIIT Pets | 91.11 | 91.00 | 95.37 | 94.79 |
Oxford Flowers-102 | 94.26 | 92.06 | 98.62 | 99.32 |
Average | 89.33 | 88.79 | 94.01 | 93.72 |
D.2 Transformer Shape

D.3 Head Type And Class Token

D.4 Positional Embedding

Pos. Emb. | Default/Stem | Every Layer | Every Layer-Shared |
---|---|---|---|
No Pos. Emb. | 0.61382 | N/A | N/A |
1-D Pos. Emb. | 0.64206 | 0.63964 | 0.64292 |
2-D Pos. Emb. | 0.64001 | 0.64046 | 0.64022 |
Rel. Pos. Emb. | 0.64032 | N/A | N/A |
D.5 Empirical Computational Costs

D.6 Axial Attention

D.7 Attention Distance

D.8 Attention Maps
D.9 Objectnet Results
D.10 VTAB Breakdown

Model | Caltech101 | CIFAR-100 | DTD | Flowers102 | Pets | Sun397 | SVHN | Camelyon | EuroSAT | Resisc45 | Retinopathy | Clevr-Count | Clevr-Dist | DMLab | dSprites-Loc | dSprites-Ori | KITTI-Dist | sNORB-Azim | sNORB-Elev | Mean |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ViT-H/14 (JFT) | 95.3 | 85.5 | 75.2 | 99.7 | 97.2 | 65.0 | 88.9 | 83.3 | 96.7 | 91.4 | 76.6 | 91.7 | 63.8 | 53.1 | 79.4 | 63.3 | 84.5 | 33.2 | 51.2 | 77.6 |
ViT-L/16 (JFT) | 95.4 | 81.9 | 74.3 | 99.7 | 96.7 | 63.5 | 87.4 | 83.6 | 96.5 | 89.7 | 77.1 | 86.4 | 63.1 | 49.7 | 74.5 | 60.5 | 82.2 | 36.2 | 51.1 | 76.3 |
ViT-L/16 (I21k) | 90.8 | 84.1 | 74.1 | 99.3 | 92.7 | 61.0 | 80.9 | 82.5 | 95.6 | 85.2 | 75.3 | 70.3 | 56.1 | 41.9 | 74.7 | 64.9 | 79.9 | 30.5 | 41.7 | 72.7 |