Vision Transformer Code¶
1 Patch construction¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# step1: convert image to embedding vector sequence
def image2emb_naive(image, patch_size, weight):
    # image shape: bs × channel × height × width
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return patch_embedding

# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)  # model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels

# patch-splitting method
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
# 2D-convolution method
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
print(patch_embedding_naive)
print(patch_embedding_conv)
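The two routes should produce identical embeddings: unfold-then-matmul is mathematically the same operation as a stride-`patch_size` convolution with the reshaped kernel. A quick numerical check (a minimal sketch; the tolerance just absorbs floating-point round-off):
Python
# sanity check: the naive and conv paths compute the same patch embedding
assert torch.allclose(patch_embedding_naive, patch_embedding_conv, atol=1e-5)
print("naive and conv embeddings match")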
Annotated version:
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# step1: convert image to embedding vector sequence
def image2emb_naive(image, patch_size, weight):
    # image shape: bs × channel × height × width = 1,3,8,8
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    # before the transpose: patch shape = torch.Size([1, 48, 4]), with patch_size = 4
    # 1: batch size
    # 48 = 3*4*4 (the pixels of one input region covered by the kernel)
    # 4: sliding a 3×4×4 kernel over the 1,3,8,8 input at stride 4 yields 4 input regions
    # transpose(-1,-2) → 1,4,48
    patch_embedding = patch @ weight
    # 1,4,48 @ 48,8 = 1 × 4 × 8
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    # image = bs,ic,image_h,image_w = 1,3,8,8
    # kernel = 8 × 3 × 4 × 4
    # stride = patch_size = 4
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    # output size: (h - k + 2p)/s + 1 = (8 - 4 + 0)/4 + 1 = 2
    # conv_output = 1 × 8 × 2 × 2
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    # reshape: 1 × 8 × 4
    # transpose(-1,-2): 1 × 4 × 8
    # (the image is split into 4 patches; each patch, originally 48 pixels, is projected down to an 8-dim representation)
    return patch_embedding

# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)  # model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels

# patch-splitting method
patch_embedding_naive = image2emb_naive(image, patch_size, weight)

# conv version:
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
# weight = 48 × 8
# transpose(0,1): 8 × 48
# reshape: 8 × 3 × 4 × 4

# 2D-convolution method
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
# image = bs,ic,image_h,image_w = 1,3,8,8
# kernel = 8 × 3 × 4 × 4
# patch_size = 4
print(patch_embedding_naive)
print(patch_embedding_conv)
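In practice, ViT implementations usually express this projection as a single `nn.Conv2d` layer so the kernel is registered as a trainable parameter. A minimal sketch of that formulation (the `PatchEmbed` class name is illustrative, not from the original code):
Python
class PatchEmbed(nn.Module):
    # patch embedding as a strided convolution (hypothetical helper)
    def __init__(self, in_channels=3, model_dim=8, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, model_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, image):  # image: bs × ic × h × w
        out = self.proj(image)  # bs × model_dim × h/patch × w/patch
        return out.flatten(2).transpose(-1, -2)  # bs × num_patches × model_dim

patch_embed = PatchEmbed()
print(patch_embed(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 4, 8])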
2 CLS token embedding¶
Python
# step2: prepend CLS token embedding
# patch_embedding_conv = 1 × 4 × 8
# cls_token_embedding = 1 × 1 × 8
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
# the CLS token (embedding dim 8) occupies the first position of the sequence,
# so we concatenate along dim=1 (the sequence dimension)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
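Inside a trainable model, the CLS token would normally be an `nn.Parameter` shared across the batch rather than a fresh `torch.randn` tensor. A minimal sketch (the module name is illustrative):
Python
class ClsTokenPrepend(nn.Module):
    # hypothetical module-style version of step2
    def __init__(self, model_dim=8):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))

    def forward(self, patch_embedding):  # bs × num_patches × model_dim
        cls = self.cls_token.expand(patch_embedding.shape[0], -1, -1)  # share one token across the batch
        return torch.cat([cls, patch_embedding], dim=1)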
3 Position embedding¶
Python
# step3: add position embedding
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding
Annotated version:
Python
# step3: add position embedding
# max_num_token = 16
# model_dim = 8
# position_embedding_table = 16,8
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
# token_embedding = 1,5,8 (bs, 5 positions: 1 CLS token + 4 patch tokens, model_dim = 8)
# seq_len = 5
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
# position_embedding_table[:seq_len] = position_embedding_table[:5] takes the first 5 rows (each 8-dim)
# [:5] indexes along the first dimension
# position_embedding_table[:seq_len] = 5,8
# [token_embedding.shape[0],1,1] = [1,1,1]
# position_embedding = 1,5,8
token_embedding += position_embedding
# token_embedding = 1,5,8
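As with the CLS token, a module would keep the table as an `nn.Parameter`; and because a (seq_len, model_dim) slice broadcasts against a (bs, seq_len, model_dim) tensor, the `torch.tile` can be dropped entirely. A minimal sketch (the module name is illustrative):
Python
class LearnedPositionEmbedding(nn.Module):
    # hypothetical module-style version of step3
    def __init__(self, max_num_token=16, model_dim=8):
        super().__init__()
        self.table = nn.Parameter(torch.randn(max_num_token, model_dim))

    def forward(self, token_embedding):  # bs × seq_len × model_dim
        seq_len = token_embedding.shape[1]
        return token_embedding + self.table[:seq_len]  # (seq_len, model_dim) broadcasts over the batch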
4 Transformer Encoder¶
Python
# step4: pass embedding to Transformer Encoder
# d_model = model_dim = 8
# batch_first=True because token_embedding is bs × seq_len × model_dim
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# token_embedding = 1,5,8 (think of it as 5 tokens, each embedded in 8 dimensions)
encoder_output = transformer_encoder(token_embedding)
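Note that without `batch_first=True`, `nn.TransformerEncoderLayer` expects input shaped seq_len × bs × model_dim, so the 1 × 5 × 8 batch-first tensor would be silently misread as a length-1 sequence with batch size 5. The encoder preserves the input shape, which is easy to verify:
Python
print(encoder_output.shape)  # torch.Size([1, 5, 8]) = bs × (1 CLS + 4 patch tokens) × model_dim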
5 Classification head¶
Python
# step5: do classification
cls_token_output = encoder_output[:, 0, :]  # take the CLS token's output: bs × model_dim
linear_layer = nn.Linear(model_dim, num_classes)  # num_classes and label are defined in the complete code below
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
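To actually train these pieces, the leaf tensors and module parameters would go into an optimizer and the loss would be backpropagated. A minimal sketch of one gradient step, assuming the names defined above (not part of the original post):
Python
# hypothetical single training step over the tensors and modules defined above
params = [cls_token_embedding, position_embedding_table,
          *transformer_encoder.parameters(), *linear_layer.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)
optimizer.zero_grad()
loss.backward()
optimizer.step()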
6 Complete code¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# step1: convert image to embedding vector sequence
def image2emb_naive(image, patch_size, weight):
    # image shape: bs × channel × height × width
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return patch_embedding

# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
max_num_token = 16
num_classes = 10
label = torch.randint(10, (bs,))
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)  # model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels
patch_embedding_naive = image2emb_naive(image, patch_size, weight)  # patch-splitting method
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))  # oc*ic*kh*kw
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)  # 2D-convolution method
# print(patch_embedding_naive)
# print(patch_embedding_conv)

# step2: prepend CLS token embedding
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)

# step3: add position embedding
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding

# step4: pass embedding to Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# step5: do classification
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
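Everything above can also be packaged into one `nn.Module`, so that the CLS token, position table, encoder, and head are all registered as trainable parameters. A minimal consolidated sketch (the `SimpleViT` name and constructor arguments are illustrative, not from the original post):
Python
class SimpleViT(nn.Module):
    def __init__(self, ic=3, patch_size=4, model_dim=8, max_num_token=16,
                 num_classes=10, nhead=8, num_layers=6):
        super().__init__()
        self.patch_embed = nn.Conv2d(ic, model_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        self.position_embedding_table = nn.Parameter(torch.randn(max_num_token, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, image):  # bs × ic × h × w
        patches = self.patch_embed(image).flatten(2).transpose(-1, -2)  # bs × num_patches × model_dim
        tokens = torch.cat([self.cls_token.expand(image.shape[0], -1, -1), patches], dim=1)
        tokens = tokens + self.position_embedding_table[:tokens.shape[1]]
        return self.head(self.encoder(tokens)[:, 0, :])  # classify from the CLS token's output

model = SimpleViT()
print(model(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 10])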