• 将多个 TransformerEncoderLayer 层堆叠起来,形成一个完整的 Transformer 编码器


    该函数作用是将多个 TransformerEncoderLayer 层堆叠起来,形成一个完整的 Transformer 编码器。以下是这个类的主要部分的解释:

    • encoder_layer: 这是一个 TransformerEncoderLayer 类的实例,表示编码器层的构建模块。编码器由多个这样的层叠加而成。

    • num_layers: 这是编码器中的子编码器层数。也就是说,编码器由多少个 encoder_layer 堆叠而成。

    • norm: 这是可选的层归一化组件,用于在编码器的输出上应用层归一化。

    • forward 函数:这个函数执行编码器的前向传播过程。它接受输入序列 src,以及可选的掩码 mask 和序列键掩码 src_key_padding_mask。然后,它迭代遍历每个子编码器层(由 self.layers 组成),并将输入序列 src 传递给每一层。最后,如果指定了层归一化组件 norm,则应用层归一化并返回输出。

    这个类的主要作用是组装多个编码器层,使得它们可以一层一层地处理输入序列,并生成编码器的输出。这个输出通常用作后续任务的输入,例如序列到序列任务、文本分类等。

    1. class TransformerEncoder(Module):
    2. r"""TransformerEncoder is a stack of N encoder layers
    3. Args:
    4. encoder_layer: an instance of the TransformerEncoderLayer() class (required).
    5. num_layers: the number of sub-encoder-layers in the encoder (required).
    6. norm: the layer normalization component (optional).
    7. Examples::
    8. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    9. >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    10. >>> src = torch.rand(10, 32, 512)
    11. >>> out = transformer_encoder(src)
    12. """
    13. __constants__ = ['norm']
    14. def __init__(self, encoder_layer, num_layers, norm=None):
    15. super(TransformerEncoder, self).__init__()
    16. self.layers = _get_clones(encoder_layer, num_layers)
    17. self.num_layers = num_layers
    18. self.norm = norm
    19. def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
    20. r"""Pass the input through the encoder layers in turn.
    21. Args:
    22. src: the sequence to the encoder (required).
    23. mask: the mask for the src sequence (optional).
    24. src_key_padding_mask: the mask for the src keys per batch (optional).
    25. Shape:
    26. see the docs in Transformer class.
    27. """
    28. output = src
    29. for mod in self.layers:
    30. output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
    31. if self.norm is not None:
    32. output = self.norm(output)
    33. return output

  • 相关阅读:
    Android 13.0 第三方应用默认横屏显示
    迈动互联获“ISO20000信息技术服务管理体系认证证书”
    VM17虚拟机设置网络,本地使用工具连接虚拟机
    我用ChatGPT写了一个简单的Python自动化测试脚本
    Flare Network,跨越互操作性三难困境
    【SSM】SpringMVC系列——SpringMVC概述
    Qt | QListView、QListWidget、QTableView、QTableWidget的使用示例及区别
    uniapp echarts 适配H5与微信小程序
    涨姿势了,有意思的气泡 Loading 效果
    【数据结构与算法】二叉树OJ练习题
  • 原文地址:https://blog.csdn.net/vivi_cin/article/details/132900466