Transform结构详解 + 手撕版本
Transform
下图是Transform 的整体架构,由decoder和encoder构成,下面分部分对Transform的结构进行分析 (下图来自于论文Attention is all you need)。
(论文链接https://arxiv.org/pdf/1706.03762.pdf)
1.Encoder
Encoder 主要是用来对句子的输入进行编码,下面用 ”我爱学习“ 这句话为例子解析编码过程。
首先是以词向量的形式进行输入,并且这里的词向量加入了positional encoding,也就是位置信息,来标定 ’我‘ ’爱‘ ’学‘ ’习‘ 这四个词向量的位置。
下一步就是将融合了位置信息的词向量输入到self-Attention 中进行编码
1.1 Self-Attention
- Self-Attention 的过程如下图所示
- 以第一个向量为例,对它进行变换生成Q,K,V。首先Q输入到第一个Attention中,然后分别和所有的K生成权重,然后根据权重对所有的V加权求和得到第一个Attention输出,其他的Attention类似,最终可以得到四个红色的向量。
- 注:图中只画出了第一个Attention的输入关系,其他的类似。
- 在输出红色的向量之后,红色的向量还要再加上原来的输入向量,然后再进行Norm操作,得到黄色的向量。
1.2 Multi-Head Attention
- 所谓多头注意力机制,也就是多次进行自注意力机制编码,如下图所示。
1.3 Feed Forward
- FF层比较简单,即为将Muti-head Attention 输出的向量通过全连接层,并与自己相加。
2 Decoder
2.1 Masked-Multi-head Attention
-
Masked 的意思即为在产生第n个编码的时候只能考虑第n个和第n个之前的信息,不能考虑之后的信息。
-
注:只画出了V的输入关系。
2.2 Cross Attention
- 右边绿色的向量为解码器的输入,在机器翻译任务中就是,要翻译成的语言比如说’I love learning’ ,0表示开始产生。Cross Attention 就是将decoder的编码作为Q和将Encoder的输出作为K,V进行Multi-Head Self-Attention。
3 Decoder 训练过程
- Decoder在训练过程中是并行进行的,也就是说‘I love learning’ 是同步输入解码器的,输入的是标准答案‘I love learning’
- 在inference过程中,不是并行需要一个一个的输出,先产生 I, 在根据I产生love ,当然可能产生的不是love而是like。
4 手撕代码
import torch
import torchvision
class Layernorm_m(torch.nn.Module):
def __init__(self):
super(Layernorm_m,self).__init__()
pass
def forward(self,x):
mean = torch.mean(x, dim = 2)
std = torch.std(x, dim = 2)
return (x - mean[:, :, None]) / std[:,:,None]
class Attention(torch.nn.Module):
def __init__(self):
super(Attention,self).__init__()
self.Wq = torch.nn.Linear(512,512,bias= False)
self.Wk = torch.nn.Linear(512, 512,bias= False)
self.Wv = torch.nn.Linear(512, 512,bias= False)
self.fc = torch.nn.Linear(512, 512,bias= False)
self.layernorm = Layernorm_m()
def forward(self,x):
res = x
q = self.Wq(x)
k = self.Wk(x)
v = self.Wv(x)
#q* k.T * v
A = q.bmm(k.permute(0,2,1)) / torch.sqrt(torch.tensor(512,dtype = torch.float32))
A = torch.softmax(A, dim = -1)
x = A.bmm(v)
x = self.fc(x)
return self.layernorm(x + res)
class PoswiseFeedForwardNet(torch.nn.Module):
def __init__(self):
super(PoswiseFeedForwardNet,self).__init__()
self.fc = torch.nn.Linear(512,512)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(512,512)
self.layerNorm = Layernorm_m()
def forward(self,x):
res = x
x = self.fc(x)
x = self.relu(x)
x = self.fc2(x)
return self.layerNorm(x + res)
class Encoder(torch.nn.Module):
def __init__(self):
super(Encoder,self).__init__()
self.attention = Attention()
self.ffn = PoswiseFeedForwardNet()
def forward(self,x):
x = self.attention(x)
x = self.ffn(x)
return x
x = torch.randn((4,16,512))
encoder = Encoder()
x = encoder(x)
print(x)
pass
Coraline-: 啊啊啊,多写点
m0_70624260: 复现出来效果怎么样啊
DBY9909: 需要自己写一下处理数据集的代码
m0_70624260: 想问一下博主这篇论文的代码跑通了吗
秋风——落叶: 你的能运行?为什么我运行后springboot-Demo在类APIController那段代码里面会显示找不到类RegisterRequest