Vision Transformer(1):ViT源码逐行阅读解析

2 篇文章 3 订阅
订阅专栏

 上图是Vision Transformer原文的模型结构展示,可以看到模型包含了几个核心模块:

 Vision Transformer:

        1. Embedding模块

        2.Transformer Encoder模块

                2.1 NormLayer ( × depth )

                        2.1.1 Multi-Head Attention层

                                 关于Attention机制的详细解析

                        2.1.2 MLP多层感知器

        3.MLP-Head 模块映射为类别

自底向上摸索是在未知中探索的不可缺少的方式,但通过摸索后,发现自顶向下能更好的阐述清楚整个逻辑。

一、ViT & Embedding

假设训练数据维度为(64, 3, 224, 224),意味着有64张三通道的224*224的图像。

设定参数dim=128意味着编码向量长度为128。

ViT中出现的PreNorm、Attention、FeedForward、Transformer后续解释

class ViT(nn.Module):
    '''
    :param
        *: input data
        image_size: 等边图像尺寸
        patch_size: patch的尺寸
        num_classes: 分类类别
        dim: 为每一个patch编码的长度
        depth: Encoder的深度,也就是连接encoder的数目
        heads: 多头注意力中头的数目
        mlp_dim: 多层感知器中隐含层的维度
        pool: 使用cls token还是使用均值池化
        channel: 图像的通道数
        dim_head: 注意力机制中一个头的输入维度
        dropout: NormLayer中dropout的参数比例
        emb_dropout: Embedding中的dropout比例
    :return 分类结果(64, 2)
    '''
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        # image_size就是每一张图像的长和宽,通过pair函数便捷明了的表现
        # patch_size就是图像的每一个patch的长和宽
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        # 保证图像可以整除为若干个patch
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算出每一张图片会被切割为多少个patch
        # 假设输入维度(64, 3, 224, 224), num_patches = 49
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 每一个patch数组大小, patch_dim = 3*32*32=3072
        patch_dim = channels * patch_height * patch_width
        # cls就是分类的Token, mean就是均值池化
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # embeding操作:假设输入维度(64, 3, 224, 224),那么经过Rearange层后变成了(64, 7*7=49, 32*32*3=3072)
        self.to_patch_embedding = nn.Sequential(
            # 将图片分割为b*h*w个三通道patch,b表示输入图像数量
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # 经过线性全连接后,维度变成(64, 49, 128)
            nn.Linear(patch_dim, dim),
        )
        # dim张图像,每张图像需要num_patches个向量进行编码
        # 位置编码(1, 50, 128) 本应该为49,但因为cls表示类别需要增加一个
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # CLS类别token,(1, 1, 128)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 设置dropout
        self.dropout = nn.Dropout(emb_dropout)
        # 初始化Transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # pool默认是cls进行分类
        self.pool = pool
        self.to_latent = nn.Identity()
        # 多层感知用于将最终特征映射为2个类别
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # 第一步,原始图像ebedding,进行了图像切割以及线性变换,变成x->(64, 49, 128)
        x = self.to_patch_embedding(img)
        # 得到原始图像数目和单图像的patches数量, b=64, n=49
        b, n, _ = x.shape
        # (1, 1, 128) -> (64, 1, 128) 为每一张图像设置一个cls的token
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 将cls token加入到数据中 -> (64, 50, 128)
        x = torch.cat((cls_tokens, x), dim=1)
        # x(64, 50, 128)添加位置编码(1, 50, 128)
        x += self.pos_embedding[:, :(n + 1)]
        # 经过dropout层防止过拟合
        x = self.dropout(x)

        x = self.transformer(x)
        # 进行均值池化
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # 最终进行分类映射
        return self.mlp_head(x)

二、Transformer

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        # 设定depth个encoder相连,并添加残差结构
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        # 每次取出包含Norm-attention和Norm-mlp这两个的ModuleList,实现残差结构
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

1.Norm层

class PreNorm(nn.Module):
    '''
    :param  dim 输入维度
            fn 前馈网络层,选择Multi-Head Attn和MLP二者之一
    '''
    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm: ( a - mean(last 2 dim) ) / sqrt( var(last 2 dim) )
        # 数据归一化的输入维度设定,以及保存前馈层
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    # 前向传播就是将数据归一化后传递给前馈层
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

2.MLP层

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

3.Attention层

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # 表示1/(sqrt(dim_head))用于消除误差,保证方差为1,避免向量内积过大导致的softmax将许多输出置0的情况
        # 可以看原文《attention is all you need》中关于Scale Dot-Product Attention如何抑制内积过大
        self.scale = dim_head ** -0.5
        # dim =  > 0 时,表示mask第d维度,对相同的第d维度,进行softmax
        # dim =  < 0 时,表示mask倒数第d维度,对相同的倒数第d维度,进行softmax
        self.attend = nn.Softmax(dim = -1)
        # 生成qkv矩阵,三个矩阵被放在一起,后续会被分开
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # 如果是多头注意力机制则需要进行全连接和防止过拟合,否则输出不做更改
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 分割成q、k、v三个矩阵
        # qkv为 inner_dim * 3,其中inner_dim = heads * dim_head
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv的维度是(3, inner_dim = heads * dim_head)
        # 'b n (h d) -> b h n d' 重新按思路分离出8个头,一共8组q,k,v矩阵
        # rearrange后维度变成 (3, heads, dim, dim_head)
        # 经过map后,q、k、v维度变成(1, heads, dim, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # query * key 得到对value的注意力预测,并通过向量内积缩放防止softmax无效化部分参数
        # heads * dim * dim
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # 对最后一个维度进行softmax后得到预测的概率值
        attn = self.attend(dots)
        # 乘积得到预测结果
        # out -> heads * dim * dim_head
        out = torch.matmul(attn, v)
        # 重组张量,将heads维度重新还原
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

4.其他部分

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

三、MLP-Head模块

self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

【计算机视觉】ViTVision Transformer 讲解
LJR的博客
03-18 850
ViT vision transformer BERT NLP CV 图像分类 CLS encoder 全局平均池化 Global Average Pooling GAP patch 注意力 attention 归纳偏置 ResNet BiT CNNs 局部相关性(locality)和平移不变性(translation equivariance) Hybrid 预训练 微调 pretrain fine-tune
Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions论文以及代码解析
Hide on bush
09-27 3236
Pyramid Vision Transformer1. Abstract2. Introduction3. Method3.1. Overall Architecture3.2. Transformer Encoder3.3. More Details4. PVT-V1代码解析4.1. main脚本4.2. pvt脚本4.2.1. PyramidVisionTransformer类的init4.2.2. PyramidVisionTransformer类的forward 论文地址:PVT-V1版本论文
ViT源码讲解
carambola_的博客
08-24 2238
ViT模型源码
vision_transformer
03-16
视觉变压器 作者:Alexey Dosovitskiy *&dagger;,Lucas Beyer *,Alexander Kolesnikov *,Dirk Weissenborn *,翟小华*,Thomas Unterthiner,Mostafa Dehghani,Matthias Minderer,Georg Heigold,Sylvain Gelly,Jakob Uszkoreit和Neil Houlsby *&dagger;。 (*)等于技术贡献,(&dagger;)等于建议。 由Andreas Steiner准备的开源发行版。 注意:此存储库是从分支和修改的。 介绍 在此存储库中,我们从论文 释放模型,这些模型已在 ( imagenet21k )数据集上进行了预训练。 我们提供用于微调 / 已发布模型的代码。 模型概述:我们将图像分割成固定大小的块,线性地嵌入每个块,添加位置嵌入,然后将所得的矢量序列馈送到标
ViT中的Postion Embedding(位置编码)详解:数据从一维到二维的变化
最新发布
介绍AI领域相关知识
08-23 1941
Transformer有效的解决了长距离依赖问题,并且有良好的可扩展性,适用于处理序列化的数据,NLP中的语句刚好就是序列化的数据,但是在计算机视觉中,图像属于二维数据,那么如何在二维数据中应用到transformer呢,针对这个问题,ViT的作者提出一种位置编码策略,将一张图片切分成相同大小的块,然后给每个块进行位置编码成为一个序列,然后再使用transformer进行训练。本篇内容带大家详细了解一下ViT中的位置编码。
VIT源码解读
qq_52093995的博客
08-09 615
patch_size是选择多大的区域进行分块提取特征,n_patches一共有多少块(图像宽/patch_size宽)x(图像高/patch_size高)patch_embeddings卷积stride为patch_size,提取特征时的卷积不重叠提取特征。图像的第一步:图特征向量提取。cls_token为(1,1,768)进行expand将维度复制到B,这里的B是16,每个数据都要有对应的cls_token。输入(16,3,224,224)训练batch大小,3通道RGB,图像大小224x224。
VIT 源码详解
qq_52053775的博客
08-10 3580
参数说明:数据集: --name cifar10-100_500 --dataset cifar10哪个版本的模型: --model_type ViT-B_16预训练权重: --pretrained_dir checkpoint/ViT-B_16.npz 对于图像编码,以VIT - B/16为例,首先用卷积核大小为16*16、步长为16的卷积,对图像进行变换,此时图像维度变成16 * 768 * 14 * 14,再变换维度为[16, 19
vit源码阅读
Tomatoccc的博客
05-10 678
vit源码阅读-pytorch 原论文链接:https://arxiv.org/abs/2010.11929 源码来源:timm包中的VisionTransformer模型。 刚开始博客,写的不好请多多包含。如有错误,请指出,感谢。 模型结构 词嵌入 词嵌入接口调用if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_cha
Vision Transformer源码详解
m0_53374472的博客
11-10 1567
本篇文章主要分享视觉Transformer的Pytorch实现和代码细节问题,着重在于Vision Transfomer的Pytorch实现,
Vision Transformer 源码解读
02-03
在文本任务中大量使用了Transformer 架构,因为文本数据是一个序列非常好的契合Transformer 架构。 可是如何将一张图像展开成一个序列呢? 将一个文本数据使用Transformer 进行特征提取需要把文本embbeding成一个...
2020:ViTVision Transformer)【将Transformer应用在图像分类的模型】【当拥有足够多的数据进行预训练的时候,ViT的表现就会超过ResNets,反之不如】
u013250861的博客
02-19 2248
ViT作者团队出品,在CNN和Transformer大火的背景下,舍弃了卷积和注意力机制,提出了MLP-Mixer,一个完全基于MLPs的结构,其MLPs有两种类型,分别是和,前者独立作用于image patches(融合通道信息),后者跨image patches作用(融合空间信息)。实验结果表明该结构和SOTA方法同样出色,证明了convolution和attention不是必要操作,如果将其替换为简单的MLP,模型依然可以完美work。
VisionTransformer-Pytorch
05-14
视觉变压器火炬 此项目是从 / 和 / ,为您提供了现成的API,可让您像EfficientNet一样容易地使用VisionTransformer。 快速开始 使用pip install vision_transformer_pytorch安装,并使用以下命令加载经过预训练的VisionTransformer: from vision_transformer_pytorch import VisionTransformer model = VisionTransformer.from_pretrained('ViT-B_16') 关于视觉变压器PyTorch Vision Transformer Pytorch是Vision Transformer的PyTorch重新实现,它基于常用的深度学习库的最佳实践之一 ,以及的优雅实现工具 。 在这个项目中,我们旨在使我们的PyTorch实施尽
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
weixin_56836871的博客
02-06 2507
动机: 为啥挑这篇文章,因为效果炸裂,各种改款把各种数据集霸榜了:语义分割/分类/目标检测,前10都有它 Swin Transformer, that capably serves as a general-purpose backbone for computer vision. 【CC】接着VIT那篇论文挖的坑,transfomer能否做为CV领域的backbone,VIT里面只做了分类的尝试,留了检测/语义分割的坑,这篇文章直接回答swin transfomer可以 Transformer from
Vision Transformer 全面代码解析
人工智能曾小健
08-20 908
注意力机制允许模型在处理输入序列时,关注到最重要的部分,而多头自注意力则通过多个独立的注意力头来同时关注不同的特征子空间,提高了模型的表达能力。虽然我们已经完成了VisionTransformer的所有代码分析和搭建过程,但为了让模型更加易于使用和调用,我们还需要对其进行进一步的封装。在完成了所有必要模块的创建之后,我们现在要做的就是将它们组合起来,构建我们的VisionTransformer模型。在这些阶段,模型的输出是基于所有路径的贡献,而不是被随机“丢弃”了一些路径的情况。,以提高模型的泛化能力。
Vision Transformer代码
分享计算机视觉,C++,网络摄像头研发,音视频开发,嵌入式等知识。
09-07 228
【代码】Vision Transformer代码。
轻松理解ViT(Vision Transformer)原理及源码
qq_17027283的博客
06-27 2462
轻松理解ViT(Vision Transformer)原理及源码
Transformer实战-系列教程5:Vision Transformer 源码解读3
机器学习与软件工程
02-05 1089
Transformer实战-系列教程5:Vision Transformer 源码解读3
Transformer实战-系列教程4:Vision Transformer 源码解读2
机器学习与软件工程
02-03 1008
Transformer实战-系列教程3:Vision Transformer 源码解读1
可变形的Tranformer算法详解与源码——DAT:Vision Transformer with Deformable Attention
qq_52053775的博客
08-26 2189
和分别表示变形的键嵌入和值嵌入。具体来说,我们将采样函数(·;·)设置为一个双线性插值,使其可微:其中和索引了上的所有位置。由于g只在最接近的4个积分点上不为零,因此它简化了等式(8)到4个地点的加权平均值。与现有的方法类似,我们对q、k、v进行多头注意,并采用相对位置偏移r。注意头的输出表述为:其中对应于之前的工作[26]之后的位置嵌入,同时有一些适应。细节将在本节后面解释。每个头的特征被连接在一起,并通过Wo进行投影,得到最终的输出z为等式(3)....
写文章

热门文章

  • 朴素贝叶斯分类算法(matlab实现) 11134
  • Yolov5(1):Detect源码逐行解析 6933
  • Vision Transformer(1):ViT源码逐行阅读解析 4512
  • 服务器Ubuntu系统----网络安全ufw规则配置 4378
  • Netlogo笔记06:狼羊追逐 3521

分类专栏

  • 爬虫
  • C++ 6篇
  • Mall
  • spring源码 1篇
  • transformer 2篇
  • 回忆 3篇
  • Node.js 4篇
  • 深度学习与机器学习 6篇
  • netlogo与元胞自动机 7篇
  • yolo 2篇
  • 笔记 1篇
  • MySQL 2篇
  • Opencv 1篇
  • Linux 3篇
  • 动态规划 4篇
  • Python 1篇
  • 图论 3篇
  • Matlab 5篇

最新评论

  • Yolov5(1):Detect源码逐行解析

    。七十二。: 你好,为什么文中有:下方的图中,用四个黄颜色圈出来的部分,分别表示为4个Head部分。 他的输出不是只有三个head吗?这里是什么意思?

  • Netlogo笔记07:蚁群算法实现TSP问题可视化

    AboutHutdeminbai: 想回你吧你都问两年了,现在回答你的问题没意义;不回你吧你都问两年了没人给你回答你伤心

  • 朴素贝叶斯分类算法(matlab实现)

    精神焕发830: 后面那个矩阵数据也是错的 包括他的概率算的也不对吧

  • Netlogo笔记07:蚁群算法实现TSP问题可视化

    伶俜594: 为什么报错house-num没有定义啊

  • 朴素贝叶斯分类算法(matlab实现)

    勿笑葱: 应该是1才对吧

最新文章

  • Typescript 实现字典树检索数据
  • Spring源码学习:setter循环依赖
  • Sails.js自动化Api实践与测试
2023年2篇
2022年8篇
2021年8篇
2020年18篇
2019年15篇

目录

目录

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43元 前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

尼卡尼卡尼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或 充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值

玻璃钢生产厂家三门峡商场美陈绿植墙重庆公园景观玻璃钢美陈雕塑菏泽卡通玻璃钢雕塑北京商场创意商业美陈价格陕西人物玻璃钢雕塑设计甘南玻璃钢雕塑定制石家庄直销玻璃钢雕塑畅销全国玻璃钢电影动漫雕塑厂家直销达州玻璃钢雕塑摆件施工方法正宗玻璃钢雕塑厂家服务介绍宁波深圳玻璃钢雕塑广西玻璃钢花盆价格苏州玻璃钢人物雕塑价格范围半边地球仪玻璃钢雕塑哪家有玻璃钢花盆儿童画玻璃钢立体人物雕塑玻璃钢香蕉雕塑供应商湘潭玻璃钢造型雕塑长沙人物玻璃钢雕塑价格银川大型玻璃钢雕塑定制三门峡玻璃钢雕塑报价艺术玻璃钢雕塑哪家好浮雕玻璃钢人物雕塑北京市商场美陈公司报价济源玻璃钢人物园林不锈钢雕塑太原玻璃钢海豚雕塑价格曲阳玻璃钢雕塑企业管庄商场美陈效果图扬州玻璃钢人物雕塑设计泸州玻璃钢雕塑摆件工程香港通过《维护国家安全条例》两大学生合买彩票中奖一人不认账让美丽中国“从细节出发”19岁小伙救下5人后溺亡 多方发声单亲妈妈陷入热恋 14岁儿子报警汪小菲曝离婚始末遭遇山火的松茸之乡雅江山火三名扑火人员牺牲系谣言何赛飞追着代拍打萧美琴窜访捷克 外交部回应卫健委通报少年有偿捐血浆16次猝死手机成瘾是影响睡眠质量重要因素高校汽车撞人致3死16伤 司机系学生315晚会后胖东来又人满为患了小米汽车超级工厂正式揭幕中国拥有亿元资产的家庭达13.3万户周杰伦一审败诉网易男孩8年未见母亲被告知被遗忘许家印被限制高消费饲养员用铁锨驱打大熊猫被辞退男子被猫抓伤后确诊“猫抓病”特朗普无法缴纳4.54亿美元罚金倪萍分享减重40斤方法联合利华开始重组张家界的山上“长”满了韩国人?张立群任西安交通大学校长杨倩无缘巴黎奥运“重生之我在北大当嫡校长”黑马情侣提车了专访95后高颜值猪保姆考生莫言也上北大硕士复试名单了网友洛杉矶偶遇贾玲专家建议不必谈骨泥色变沉迷短剧的人就像掉进了杀猪盘奥巴马现身唐宁街 黑色着装引猜测七年后宇文玥被薅头发捞上岸事业单位女子向同事水杯投不明物质凯特王妃现身!外出购物视频曝光河南驻马店通报西平中学跳楼事件王树国卸任西安交大校长 师生送别恒大被罚41.75亿到底怎么缴男子被流浪猫绊倒 投喂者赔24万房客欠租失踪 房东直发愁西双版纳热带植物园回应蜉蝣大爆发钱人豪晒法院裁定实锤抄袭外国人感慨凌晨的中国很安全胖东来员工每周单休无小长假白宫:哈马斯三号人物被杀测试车高速逃费 小米:已补缴老人退休金被冒领16年 金额超20万

玻璃钢生产厂家 XML地图 TXT地图 虚拟主机 SEO 网站制作 网站优化