🔭 ViT 家族神经网络数据流对比(9模型并列)

ViT-Base/16
ViT-Large/16
DeiT
Swin-T
MAE
BEiT
DINO
CaiT
MobileViT
图像输入
x∈(B,3,224,224)
Patch提取
划分16×16 patch reshape→(B,196,768) 每patch展平为768维
Patch线性投影
x=patches·W_E+b_E W_E∈ℝ^(768×768) out: (B,196,768)
[CLS]拼接
cls∈(1,768)→广播(B,1,768) concat([cls,x_patch],dim=1) out: (B,197,768)
Position Embedding广播
E_pos∈ℝ^(197,768) 广播→(B,197,768) x←x+E_pos_broadcasted out: (B,197,768)
Dropout
x←Dropout(x,p=0.1) (B,197,768)
Encoder × 12层 Pre-LN(H=768,N=197,4H=3072)
① LayerNorm
x_n=(x-μ)/√(σ²+ε)·γ+β (B,197,768) 不变
② 多头自注意力(7步)
步骤1-整体QKV投影: Q=X·W_Q,W_Q∈ℝ^(768×768) K=X·W_K,V=X·W_V Q/K/V∈(B,197,768) 步骤2-来源:自注意力 Q/K/V均来自X∈(B,197,768) 步骤3-多头拆分: d_k=768/12=64 Q_i=Q[:,:,i*64:(i+1)*64] Q_i∈(B,197,64);K_i/V_i同 等价reshape: Q→(B,12,197,64) 步骤4-△Q·Kᵀ产生分数 ∈(B,12,197,197) V_i只被加权聚合不参与分数 步骤5-单头计算: head_i=Softmax(Q_i·K_iᵀ/√64) ·V_i∈(B,197,64) 步骤6-多头拼接: Concat(h1..h12)→(B,197,768) 步骤7-输出投影: ·W_O,W_O∈ℝ^(768×768) out: (B,197,768)
③ 残差加法①
x←x+attn_out (B,197,768)
④ LayerNorm②
x_n=(x-μ)/√(σ²+ε)·γ+β (B,197,768)
⑤ Linear升维
out=x·W+b W∈ℝ^(768×3072) out: (B,197,3072)
⑥ GELU
GELU(x)=x·Φ(1.702x) (B,197,3072)
⑦ Linear降维
W∈ℝ^(3072×768) out: (B,197,768)
⑧ 残差加法②
x←x+ffn_out (B,197,768)
取[CLS] token
x_cls=x[:,0,:] (B,197,768)→(B,768)
LayerNorm输出前
x=(x-μ)/√(σ²+ε)·γ+β (B,768)
分类头
logits=x_cls·W_cls+b W_cls∈ℝ^(768×C) out: (B,C) 图像分类头,C=1000(ImageNet)
图像输入
x∈(B,3,224,224)
Patch提取
划分16×16 patch →(B,196,768)
Patch线性投影
x=patches·W_E+b_E W_E∈ℝ^(768×1024) out: (B,196,1024)
[CLS]拼接
cls∈(1,1024)→(B,1,1024) concat→(B,197,1024)
Position Embedding广播
E_pos∈ℝ^(197,1024) 广播→(B,197,1024) x←x+E_pos_broadcasted
Dropout
x←Dropout(x,p=0.1)
Encoder × 24层 Pre-LN(H=1024,N=197,4H=4096)
① LayerNorm①
x_n=(x-μ)/√(σ²+ε)·γ+β
② 多头自注意力(7步)
步骤1-整体投影: Q=X·W_Q,W_Q∈ℝ^(1024×1024) Q/K/V∈(B,197,1024) 步骤2-来源:自注意力 Q/K/V均来自X∈(B,197,1024) 步骤3-拆分:d_k=1024/16=64 Q_i∈(B,197,64);K_i/V_i同 reshape→(B,16,197,64) 步骤4-△Q·Kᵀ∈(B,16,197,197) V_i仅被加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√64)·V_i∈(B,197,64) 步骤6-Concat(h1..h16) →(B,197,1024) 步骤7-·W_O∈ℝ^(1024×1024) out: (B,197,1024)
③ 残差加法①
x←x+attn_out
④ LayerNorm②
⑤ Linear升维
W∈ℝ^(1024×4096) out: (B,197,4096)
⑥ GELU
GELU(x)=x·Φ(1.702x)
⑦ Linear降维
W∈ℝ^(4096×1024) out: (B,197,1024)
⑧ 残差加法②
x←x+ffn_out
取[CLS]
x_cls=x[:,0,:] (B,197,1024)→(B,1024)
LayerNorm
分类头
W_cls∈ℝ^(1024×C) out: (B,C) 图像分类头,C=1000
图像输入
x∈(B,3,224,224)
Patch提取+线性投影
W_E∈ℝ^(768×768) out: (B,196,768)
[CLS]+[DIST]拼接
cls,dist各∈(1,768) →广播(B,2,768) concat→(B,198,768) △额外蒸馏token[DIST]
Position Embedding广播
E_pos∈ℝ^(198,768) 广播→(B,198,768) x←x+E_pos_broadcasted
Dropout
x←Dropout(x,p=0.1)
Encoder × 12层 Pre-LN(H=768,N=198)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(768×768) Q/K/V∈(B,198,768) 步骤2-来源:自注意力 含[CLS]和[DIST]的198 token 步骤3:d_k=768/12=64 Q_i∈(B,198,64) reshape→(B,12,198,64) 步骤4-△Q·Kᵀ∈(B,12,198,198) [DIST]参与全局注意力 V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√64)·V_i 步骤6-Concat→(B,198,768) 步骤7-·W_O∈ℝ^(768×768)
③ 残差加法①
④ LayerNorm②
⑤ Linear升维
W∈ℝ^(768×3072)
⑥ GELU
⑦ Linear降维
W∈ℝ^(3072×768)
⑧ 残差加法②
取[CLS]和[DIST]
cls=x[:,0,:],dist=x[:,1,:] 各→(B,768)
双分类头
logits_cls=cls·W_c+b logits_dist=dist·W_d+b 各out: (B,C) cls头从真实标签学 dist头从teacher模型学 推理时平均softmax
图像输入
x∈(B,3,224,224)
4×4 Patch切分
→(B,3136,48)
Patch线性投影
W_E∈ℝ^(48×96) out: (B,3136,96)
Stage1 ×2层(W-MSA+SW-MSA交替)
W-MSA Pre-LN LN①
窗口多头自注意力(7步)
步骤1-重排+投影: (B,3136,96)→(B·win,49,96) Q=X_w·W_Q∈(B·win,49,96) 步骤2-来源:自注意力 窗口内49个token(7×7) 步骤3:d_k=96/3=32 Q_i∈(B·win,49,32) reshape→(B·win,3,49,32) 步骤4-△Q·Kᵀ∈(B·win,3,49,49) 窗口内计算,复杂度O(N·w²) V_i仅加权聚合 步骤5-head_i=Softmax( (Q_i·K_iᵀ+B_rel)/√32)·V_i B_rel∈ℝ^(49×49)相对位置偏置 步骤6-Concat→(B·win,49,96) 重排回(B,3136,96) 步骤7-·W_O∈ℝ^(96×96)
残差①+LN②+FFN+残差②
FFN: 96→384→GELU→96
SW-MSA(移位窗口)
循环移位(shift=3,3) 注意力分数加掩码M A←A+M,M∈{0,-∞} 屏蔽跨窗口注意力 其余结构同W-MSA
Patch Merging ×3次
concat 2×2邻域+Linear 96→192→384→768 Stage2-4各×2/6/2层 头数3→6→12→24,d_k=32
全局平均池化
x=mean(x,dim=1) (B,49,768)→(B,768)
LayerNorm
分类头
W_cls∈ℝ^(768×C) out: (B,C) 图像分类头,C=1000 也用于检测/分割下游
图像输入
x∈(B,3,224,224)
Patch提取+线性投影
→(B,196,768)
随机掩码75%
丢弃被掩码patch →(B,~49,768)
固定sin/cos位置编码
E_pos∈ℝ^(196,768) 取可见patch位置对应行 out: (B,~49,768)
Encoder × 12层 Pre-LN(H=768,N≈49)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(768×768) Q/K/V∈(B,~49,768) 步骤2-来源:自注意力 Q/K/V来自可见patch X 步骤3:d_k=768/12=64 Q_i∈(B,~49,64) reshape→(B,12,~49,64) 步骤4-△Q·Kᵀ∈(B,12,~49,~49) 仅在可见patch间计算 V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√64)·V_i 步骤6-Concat→(B,~49,768) 步骤7-·W_O∈ℝ^(768×768)
③④ 残差+LN②
⑤⑥⑦ FFN+GELU+降维
768→3072→GELU→768
⑧ 残差加法②
LayerNorm(Encoder输出)
Linear投影+插入mask_token
W_proj∈ℝ^(768×512) 插入mask_token到缺失位置 →(B,196,512)
补全sin/cos位置编码
全序列196个token
Decoder × 8层(仅训练阶段,推理时丢弃)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(512×512) Q/K/V∈(B,196,512) 步骤2:含补入mask_token 步骤3:d_k=512/16=32 Q_i∈(B,196,32) reshape→(B,16,196,32) 步骤4-△Q·Kᵀ∈(B,16,196,196) V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√32)·V_i 步骤6-Concat→(B,196,512) 步骤7-·W_O∈ℝ^(512×512)
FFN
512→2048→GELU→512
像素重建头(仅训练)
W_rec∈ℝ^(512×768) 对masked patch算MSE out: (B,196,768) 预测masked patch的原始RGB 推理时丢弃此头
图像输入
x∈(B,3,224,224)
dVAE离散化(训练时)
encoder_dVAE(patch)→整数ID 词表大小8192 →(B,196)整数序列
Patch提取+线性投影
→(B,196,768)
掩码占位(~40%)
mask_token替换被掩patch 不丢弃,保留196个位置 out: (B,196,768)
相对位置偏置
B_rel∈ℝ^(196,196) 无绝对Position Embedding
Encoder × 12层 Pre-LN(H=768,N=196)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(768×768) Q/K/V∈(B,196,768) 步骤2-来源:自注意力 含mask_token占位的196 token 步骤3:d_k=768/12=64 Q_i∈(B,196,64) reshape→(B,12,196,64) 步骤4-△Q·Kᵀ∈(B,12,196,196) V_i仅加权聚合 步骤5-head_i=Softmax( (Q_i·K_iᵀ+B_rel)/√64)·V_i B_rel为相对位置偏置 步骤6-Concat→(B,196,768) 步骤7-·W_O∈ℝ^(768×768)
③④ 残差+LN②
⑤⑥⑦ Linear+GELU+Linear
768→3072→GELU→768
⑧ 残差加法②
预测头(训练)
W_pred∈ℝ^(768×8192) 只对masked位置算CE out: (B,196,8192) 预测dVAE离散token ID
微调分类头
GAP→Linear(768,C) out: (B,C)
图像双路增强
学生:local crop 教师:global crop
Patch提取+线性投影
→(B,196,768)
[CLS]拼接
cls∈(1,768)→(B,1,768) concat→(B,197,768)
Position Embedding广播
E_pos∈ℝ^(197,768) 广播→(B,197,768) x←x+E_pos_broadcasted
Dropout(学生有,教师无)
Encoder × 12层 Pre-LN(学生网络)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(768×768) Q/K/V∈(B,197,768) 步骤2-来源:自注意力 步骤3:d_k=768/12=64 Q_i∈(B,197,64) reshape→(B,12,197,64) 步骤4-△Q·Kᵀ∈(B,12,197,197) V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√64)·V_i 步骤6-Concat→(B,197,768) 步骤7-·W_O∈ℝ^(768×768)
③④⑤⑥⑦⑧ 残差/LN/FFN
LayerNorm
取[CLS]
(B,197,768)→(B,768)
Projection Head(3层MLP)
Linear①:W₁∈ℝ^(768×2048) GELU Linear②:W₂∈ℝ^(2048×2048) GELU Linear③:W₃∈ℝ^(2048×256) out=out/‖out‖ (L2归一化) out: (B,256) 256维L2归一化特征 用于自监督对比损失
EMA教师+损失
θ_t←m·θ_t+(1-m)·θ_s m≈0.996 L=-Σp_t·log(p_s) centering:c←m_c·c+(1-m_c) ·mean(z_t)
图像输入
x∈(B,3,224,224)
Patch提取+线性投影
W_E∈ℝ^(768×768) out: (B,196,768)
Position Embedding广播
E_pos∈ℝ^(196,768) x←x+E_pos_broadcasted 此阶段无[CLS]
SA-Block × 24层(仅196 patch,Pre-LN)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(768×768) Q/K/V∈(B,196,768) 步骤2-来源:自注意力 Patches X∈(B,196,768) 步骤3:d_k=768/16=48 Q_i∈(B,196,48) reshape→(B,16,196,48) 步骤4-△Q·Kᵀ∈(B,16,196,196) V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√48)·V_i 步骤6-Concat→(B,196,768) 步骤7-·W_O∈ℝ^(768×768)
LayerScale残差①
x←x+α·attn_out,α=1e-4
LayerNorm②+FFN+LayerScale残差②
FFN:768→3072→GELU→768 x←x+α·ffn_out
[CLS]拼接(24层SA后)
cls∈(1,768)→(B,1,768) concat→(B,197,768) △24层SA后才引入[CLS]
CA-Block × 2层(交叉注意力)
LayerNorm([CLS]和patch分别)
交叉注意力(Cross-Attention 7步)
步骤1:Q=[CLS]·W_Q W_Q∈ℝ^(768×768),Q∈(B,1,768) K=Patches·W_K∈(B,196,768) V=Patches·W_V∈(B,196,768) 步骤2-△交叉注意力: Q来自[CLS]∈(B,1,768)(1个token) K/V来自Patches∈(B,196,768) 步骤3:d_k=768/16=48 Q_i∈(B,1,48) K_i∈(B,196,48);V_i∈(B,196,48) △Q只有1个token,K/V有196个 步骤4-△Q·Kᵀ∈(B,16,1,196) 1个[CLS]查询196个patch位置 V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√48)·V_i∈(B,1,48) 步骤6-Concat→(B,1,768) 只更新[CLS] 步骤7-·W_O∈ℝ^(768×768)
残差(仅[CLS])+FFN+残差
LayerNorm;取[CLS]
(B,197,768)→(B,768)
分类头
W_cls∈ℝ^(768×C) out: (B,C) 图像分类头,C=1000
图像输入
x∈(B,3,256,256)
3×3 Conv stem
BN+SiLU,stride=2 out: (B,16,128,128)
MV2 Block × 多层
DWConv+BN+SiLU +PWConv+BN,残差
特征图折叠
(B,C,H,W)→(B,P,N,C) P个局部区域,每区域N token
MobileViT Block × L层(局部区域内,Pre-LN,无位置编码)
① LayerNorm①
② 多头自注意力(7步)
步骤1:Q=X·W_Q W_Q∈ℝ^(C×C) Q/K/V∈(B·P,N,C) 步骤2-来源:自注意力 同一局部区域N个token 步骤3:d_k=C/4 Q_i∈(B·P,N,C/4) reshape→(B·P,4,N,C/4) 步骤4-△Q·Kᵀ∈(B·P,4,N,N) 仅局部区域内计算 V_i仅加权聚合 步骤5-head_i=Softmax( Q_i·K_iᵀ/√(C/4))·V_i 步骤6-Concat→(B·P,N,C) 步骤7-·W_O∈ℝ^(C×C)
③ 残差加法①
④ LayerNorm②
⑤ Linear升维
W∈ℝ^(C×2C)
⑥ GELU
⑦ Linear降维
W∈ℝ^(2C×C)
⑧ 残差加法②
折叠回+PWConv融合
(B,P,N,C)→(B,C,H,W) PWConv融合局部全局特征
MV2+全局平均池化
x=mean(x,dim=[2,3]) →(B,C_last)
分类头
W_cls∈ℝ^(C_last×C) out: (B,C) 图像分类头,轻量化移动端 <6M参数
模型LH每头d_kFFN位置编码注意力范围Decoder参数量预训练擅长任务
ViT-Base/161276864768→3072→768可学习绝对全局(197×197)86MImageNet-21K监督图像分类微调下游
ViT-Large/16241024641024→4096→1024可学习绝对全局(197×197)307MJFT-300M监督高精度分类
DeiT1276864768→3072→768可学习绝对全局(198×198)86M蒸馏+ImageNet数据高效分类
Swin-T2+2+6+296→768324×H相对位置偏置局部窗口7×7/移位28MImageNet监督检测分割分类
MAE12+876864768→3072→768sin/cos固定全局(~49)有(仅训练,推理丢弃)86M(Enc)MIM像素重建自监督预训练
BEiT1276864768→3072→768相对位置偏置全局(196×196)86MMIM+dVAE token自监督预训练
DINO1276864768→3072→768可学习绝对全局(197×197)86M自监督EMA对比无标签特征学习
CaiT24+276848768→3072→768可学习绝对SA全局/CA交叉有(CA,始终使用)86MImageNet监督深层ViT分类
MobileViT混合变化C/42C局部区域内<6M监督移动端轻量化