Tiny Transformer:从零开始构建简化版Transformer模型

引言

        自然语言处理(NLP)与计算机视觉(CV)有显著差异,各自任务的独特性决定了它们适用的模型架构。在CV中,卷积神经网络(CNN)长期占据主导地位,而在NLP领域,循环神经网络(RNN)和长短期记忆网络(LSTM)曾是主流。然而,这些传统模型在处理长序列时效率较低,难以捕捉长期依赖关系。

        针对这些问题,Vaswani等人在2017年提出了一种全新的、完全基于注意力机制的模型——Transformer。该模型解决了RNN串行计算的效率问题,并通过自注意力机制有效处理了长序列的长期依赖问题。本文将带领大家一步步构建一个简化版的Transformer模型,称之为Tiny Transformer,帮助大家深入理解其工作原理。

1. 注意力机制

        Transformer的核心是注意力机制,它通过计算Query、Key和Value之间的相关性,动态地为不同位置分配注意力权重。我们将通过多头注意力机制(Multi-Head Attention)来扩展这种计算,以便模型能同时关注多个不同的相关性。

1.1 什么是Attention?

        Attention机制通过计算Query(查询向量)与Key(键向量)之间的相似度来为Value(值向量)加权求和。它的本质是根据当前输入的每个词与其他词的相关性动态调整注意力分布。

        例如,给定一个句子,我们可以通过Attention机制来计算每个词对其他词的关注程度。Attention公式如下:

1.2 Multi-Head Attention

        多头注意力机制扩展了单头注意力的概念,通过并行化多个注意力头来捕获序列中不同层次的相关性。每个注意力头对输入进行独立的Attention计算,然后将所有头的输出拼接起来,形成最终的输出。

import torch.nn as nn
import torch
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.fc_out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        
        attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.fc_out(attn_output)
2. 编码器和解码器

        Transformer的结构包括编码器(Encoder)和解码器(Decoder),二者均由多层的注意力机制和前馈神经网络(Feed-Forward Neural Network, FFN)组成。

2.1 编码器

        编码器的主要任务是对输入序列进行编码,并生成上下文表示供解码器使用。每个编码器层包括一个自注意力层和一个前馈网络。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, d_model)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        attn_output = self.mha(x)
        x = self.layernorm1(x + self.dropout(attn_output))
        
        ffn_output = self.ffn(x)
        return self.layernorm2(x + self.dropout(ffn_output))
2.2 解码器

        解码器的结构与编码器类似,但它包含了一个额外的“交叉注意力”层,用于将编码器的输出作为上下文信息输入,结合解码器自身的输入进行生成。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, d_model)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out):
        attn_output1 = self.mha1(x)
        x = self.layernorm1(x + self.dropout(attn_output1))
        
        attn_output2 = self.mha2(x, enc_out, enc_out)
        x = self.layernorm2(x + self.dropout(attn_output2))
        
        ffn_output = self.ffn(x)
        return self.layernorm3(x + self.dropout(ffn_output))
3. 位置编码

        Transformer由于完全摒弃了递归结构,不能自然捕捉输入序列中的位置信息。因此,位置编码(Positional Encoding)被引入,用于为每个词添加位置信息。位置编码通过正弦和余弦函数为不同位置生成独特的表示。

import math
import torch

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
4. 完整的Transformer模型

        有了上面各个模块后,我们可以将它们组合成一个完整的Transformer模型。该模型包括一个嵌入层、多个编码器层、解码器层以及一个线性层用于生成输出。

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, ff_hidden_dim, dropout):
        super(Transformer, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_decoder_layers)])
        
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        src = self.positional_encoding(self.src_embedding(src))
        tgt = self.positional_encoding(self.tgt_embedding(tgt))
        
        for layer in self.encoder_layers:
            src = layer(src)
        
        for layer in self.decoder_layers:
            tgt = layer(tgt, src)
        
        return self.fc_out(tgt)
结语

        本文通过逐步实现简化版的Transformer,展示了Transformer模型的核心组成部分——多头注意力、编码器-解码器架构、位置编码等。通过这些模块,Transformer能够高效处理序列数据,实现并行计算,广泛应用于自然语言处理、机器翻译等任务。

        Transformer的灵活性和强大的性能使其成为现代深度学习的基石。在掌握了这些基本模块后,大家可以进一步研究更复杂的模型,如BERT、GPT等预训练模型,以更好地理解和应用Transformer在实际任务中的强大能力。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.niftyadmin.cn/n/5688518.html

相关文章

物流行业中的AI平台架构与智能化应用

随着物流行业的迅速发展,尤其是电商、仓储、运输的需求日益增多,AI技术逐渐成为推动物流企业高效运营、提升服务水平的关键力量。AI平台架构为物流行业的各个环节提供了智能化解决方案,助力物流企业在仓储管理、运输调度、客户服务等方面实现…

重生之我们在ES顶端相遇第 18 章 - Script 使用(进阶)

文章目录 0. 前言1. 基本使用2. 读请求中访问文档字段2.1 遍历 List2.2 判断对象存不存在2.3 判断值是否为空2.4 总结 3. 写请求中访问文档字段3.1 数字相加3.2 字符串相加3.3 将字符串转为数组 0. 前言 在前面部分,我们介绍了 ES 的基本使用和要掌握的基础性读写原…

微信小程序数据操作指南:从绑定到更新

微信小程序数据操作指南:从绑定到更新 在微信小程序开发中,数据操作是核心环节之一。微信小程序提供了一系列简洁而强大的数据操作方法,帮助开发者轻松实现数据的绑定、更新和渲染。本文将详细介绍微信小程序中常用的数据操作方法&#xff0…

MySQL 启动失败 (code=exited, status=1/FAILURE) 异常解决方案

目录 前言1. 问题描述2. 查看错误日志文件2.1 确认日志文件路径2.2 查看日志文件内容 3. 定位问题3.1 问题分析 4. 解决问题4.1 注释掉错误配置4.2 重启 MySQL 服务 5. 总结结语 前言 在日常运维和开发过程中,MySQL数据库的稳定运行至关重要。然而,MySQ…

视频加字幕免费软件哪个好用?详细介绍6款字幕编辑软件的优缺点!码住!

视频加字幕免费软件哪个好用?在视频制作和编辑的过程中,字幕的添加是不可或缺的一环。它不仅能帮助观众更好地理解视频内容,还能提升视频的专业度和观赏性。然而,面对市场上琳琅满目的视频加字幕软件,如何选择一款既免…

推荐 uniapp 相对好用的海报生成插件

插件地址:自定义canvas样式海报 - DCloud 插件市场 兼容性也是不错的:

C++11 异步操作 std::future类

阅读导航 引言一、异步的概念二、应用场景1. 异步任务处理2. 并发控制3. 结果获取 三、使用示例1. 使用std::async关联异步任务💻示例代码说明 2. 使用std::packaged_task和std::future配合(1)定义std::packaged_task(2&#xff0…

【MySQL】子查询、合并查询、表的连接

目录 一、子查询 1、单行子查询 显示SMITH同一部门的员工信息 2、多行子查询 in关键字 查询和10号部门的工作岗位相同的雇员的名字、岗位、工资、部门号,但是筛选出的雇员的部门不能有10号部门 all关键字 查询工资比30号部门中所有雇员工资高的雇员的姓名、…