Vision Transformer(ViT)阅读笔记

1. 什么是ViT

Transformer模型最初是为了处理序列数据(如文本)而设计的,它本身无法处理二维的图像,因此,为了使其能处理图像任务,便提出了Vision Transformer这个结构,如下图所示。
Pasted image 20250707145933.png
我们将图像切分成若干个小块(称为Patch),然后将这些 Patch 转换成序列,便将而为图像任务转换成了一维序列任务,Transformer 和 ViT 处理的内容可以做以下关联:

  • 一段文本 对应 一张图片
  • 文本中的词(Token)对应图片中的每个 Patch
  • 每个词的词向量(Token Embedding)对应每个 Patch的图像向量(Patch Embedding)

上图清晰的展示了整个模型结构:

  1. 左下角的输入:一张图像被分割成了一个网格(这里是3*3=9个Patch)。
  2. Linear Projection of Flattened Patches:Patch Embedding模块,它接受这些patches,并将其每一个转换成固定长度的向量。
  3. 输出:Patch Embedding模块输出后,我们得到了 9 个向量,分别对应 9 个原始的图像 patch,他们共同构成了一个序列,可以被送入后续的 Transformer Encoder中。

整个动态过程如下图所示:
1__c8SqxPMY_dsApyvDJ8HtA _1_.gif

2. 如何实现 Patch Embedding

那么,根据上面的结构,我们使用一个全连接网络很容易就完成这个任务,但我们可以使用一个卷积网络高效的完成。
假设:

in_chanels=3 # 输入三通道RGB图像
embed_dim=768 # 输出通道数,每个Patch被转换成的向量维度
kernel_size=16 # 卷积核大小
stride=16 # 步长

上述卷积网络的操作效果是:

  1. 卷积核(16*16)恰好覆盖图像左上角的第一个 Patch(16*16)。
  2. 进行卷积计算,将这个(16*16*3)的Patch转换成一个长度为768的向量。
  3. 然后,卷积核向右移动 stride=16 个像素,恰好来到第二个 Patch的位置,不断重复上述操作,就完成了 Patch Embedding。

这在数学上完全等价于:

  1. 将每个(16*16*3)的Patch 展平成一个 (16*16*3)=768 维度的向量
  2. 将这个768维向量送入一个全连接层,其权重矩阵大小为(768,embed_dim)
  3. 由于所以Patch共享同一套卷积核权重,这等价于所有的 Patch 都通过同一个全连接层。

整个代码如下:

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, embed_dim, patch_size, dropout):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # 图像分块和线性映射模块
        self.patcher = nn.Sequential(
            # shape:[batch_size, embed_dim, H/patch_size, W/patch_size]
            nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size), 
             # shape:[batch_size, embed_dim, H/patch_size * W/patch_size]
            nn.Flatten(2))
        # 在开头添加的可学习分类token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
        # 位置编码
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim), requires_grad=True)
        self.dropout= nn.Dropout(dropout)
    
    def forward(self, x):
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        # shape:[batch_size, H/patch_size * W/patch_size, embed_dim]
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_tokens, x], dim=1) 
        x = self.position_embedding + x
        x = self.dropout(x)
        return x

其中,cls_token 是我们添加到模型头部的一个可学习 token ,用于后续的下游任务,位置编码这里使用可学习的定长位置编码,同样,我们也可以使用正余弦位置编码或者旋转位置编码。
这里为什么不使用二维位置编码,ViT 这篇论文中有提到使用二维位置编码相比一维在效果上并没有什么明显改善。
我们可以运行下面的一个简单例子来看看 patch embedding 后的效果:

x = torch.randn(4, 1, 28, 28)
layer = PatchEmbedding(image_size=28, in_channels=1, embed_dim=100, patch_size=7, dropout=0.1)
layer(x).shape # torch.Size([4, 17, 100])

3. 整个 ViT 结构的实现

这里我们使用了 Pytorch 的 Transformer 模块来完成,整个代码如下:

class ViT(nn.Module):
    def __init__(self, 
                image_size, 
                in_channels, 
                embed_dim,
                patch_size, 
                num_layers, 
                num_heads, 
                num_encoders, 
                num_classes, 
                expansion, 
                dropout):
        super().__init__()
        self.embedding_block = PatchEmbedding(image_size, in_channels, embed_dim, patch_size, dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                    nhead=num_heads, 
                                                    dropout=dropout,                                                 dim_feedforward=int(embed_dim*expansion), 
                                                    activation="gelu", 
                                                    batch_first=True, 
                                                    norm_first=True)
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
    
    def forward(self, x):
        x = self.embedding_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])
        return x

这里我们使用了 nn.TransformerEncoderLayer 创建了一个 num_heads 个注意力头的MHA 结构,使用 gelu 激活函数,pre_norm 的架构,FFN的放大倍数是 expansion
然后使用 nn.TransformerEncoder 根据上述的模板创建 num_encoders 个相同的结构(深拷贝),这里并不用担心每一个 MHA 的参数都相同,因为每层都输入都是不也一样的,整个网络是非对称的,所以梯度天然是不同的。
最后,我们提取出 cls_token , 作为分类向量进行使用。
可以运行下面的例子,观察输出结果:

model = ViT(image_size=28, in_channels=1, embed_dim=100, patch_size=7, num_layers=1, num_heads=1, num_encoders=1, num_classes=10, expansion=4, dropout=0.1)
x = torch.randn(4, 1, 28, 28)
model(x).shape # torch.Size([4, 10])

3. 参考文献

  1. Gif图来源:https://github.com/lucidrains
  2. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  3. Demystifying Visual Transformers with PyTorch: Understanding Patch Embeddings (Part 1/3)
评论区
头像
    头像

    华纳东方明珠客服电话是多少?(▲18288362750?《?微信STS5099? 】
    如何联系华纳东方明珠客服?(▲18288362750?《?微信STS5099? 】
    华纳东方明珠官方客服联系方式?(▲18288362750?《?微信STS5099?
    华纳东方明珠客服热线?(▲18288362750?《?微信STS5099?
    华纳东方明珠24小时客服电话?(▲18288362750?《?微信STS5099? 】
    华纳东方明珠官方客服在线咨询?(▲18288362750?《?微信STS5099?

    头像

    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099
    华纳公司合作开户所需材料?电话号码15587291507 微信STS5099

    头像
    jdzhlddyqy
      

    2025年10月新盘 做第一批吃螃蟹的人coinsrore.com
    新车新盘 嘎嘎稳 嘎嘎靠谱coinsrore.com
    新车首发,新的一年,只带想赚米的人coinsrore.com
    新盘 上车集合 留下 我要发发 立马进裙coinsrore.com
    做了几十年的项目 我总结了最好的一个盘(纯干货)coinsrore.com
    新车上路,只带前10个人coinsrore.com
    新盘首开 新盘首开 征召客户!!!coinsrore.com
    新项目准备上线,寻找志同道合的合作伙伴coinsrore.com
    新车即将上线 真正的项目,期待你的参与coinsrore.com
    新盘新项目,不再等待,现在就是最佳上车机会!coinsrore.com
    新盘新盘 这个月刚上新盘 新车第一个吃螃蟹!coinsrore.com

    头像
    reukxibkhi
      

    2025年10月新盘 做第一批吃螃蟹的人coinsrore.com

文章目录