• 基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask


    基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask

    flyfish

    Word-level Language Modeling using RNN and Transformer

    word_language_model

    PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transformer 模型进行词级语言建模 。默认情况下,训练使用Wikitext-2数据集,generate.py可以使用训练好的模型来生成新文本。

    源码地址
    https://github.com/pytorch/examples/tree/main/word_language_model

    文件:model.py

    import torch
    import matplotlib.pyplot as plt
    import numpy as np
    
    def _generate_square_subsequent_mask(sz):
        return torch.log(torch.tril(torch.ones(sz, sz)))
    
    # 设置矩阵大小
    sz = 5
    mask = _generate_square_subsequent_mask(sz)
    
    # 将 mask 转换为 numpy 数组,方便可视化
    mask_np = mask.numpy()
    
    # 可视化
    plt.imshow(mask_np, cmap='viridis')
    plt.colorbar()
    plt.title("Square Subsequent Mask")
    plt.show()
    

    可视化图示
    在可视化结果中,你会看到一个下三角矩阵,其值为 0 的部分为下三角部分,值为负无穷的部分为上三角部分。图像中通常负无穷会被显示为一种不同的颜色。

    这样,你可以直观地理解生成的掩码矩阵的结构和作用。这个掩码矩阵主要用于 Transformer 模型中,以确保模型在预测时只能看到当前时刻及之前的时刻信息,而不能看到未来的信息。
    在这里插入图片描述
    结果
    运行这段代码,你会看到一个 5x5 的矩阵,其中下三角部分是 0(因为 log(1) = 0),上三角部分是负无穷(由于 log(0) 是负无穷)。

    def _generate_square_subsequent_mask(sz):
        return torch.log(torch.tril(torch.ones(sz, sz)))
    
    # 设置矩阵大小
    sz = 5
    mask = _generate_square_subsequent_mask(sz)
    
    # 打印矩阵
    print(mask)
    

    输出

    tensor([[0., -inf, -inf, -inf, -inf],
            [0., 0., -inf, -inf, -inf],
            [0., 0., 0., -inf, -inf],
            [0., 0., 0., 0., -inf],
            [0., 0., 0., 0., 0.]])
    

    在数学上,定义对数函数时,log(0) 是未定义的,但在计算中,我们处理这种情况的方式是认为 log(0) 的极限值是负无穷。因此,计算机通常会返回负无穷来表示这种情况。

    在 PyTorch 中,torch.log(0) 的结果是 -inf(负无穷)。这是因为对数函数是单调递增的,并且在接近0时值会急剧下降到负无穷。

  • 相关阅读:
    Qt QImage和QPixmap区别
    数字图像滤波的本质
    docker-compose搭建私有Gitlab
    linux多处理器并发访问共享资源---自旋锁
    好多自恋性数
    白盒 SDK 加密 —— Go 语言中直调 C 动态库实现
    【CFD小工坊】浅水方程的离散及求解方法
    使用STM32怎么喂狗 (IWDG)
    【C++】类和对象(下)
    PHP:namespace 关键字和 __NAMESPACE__ 常量
  • 原文地址:https://blog.csdn.net/flyfish1986/article/details/139316412