• [文档] torch.distributions.Categorical


    这段时间看见分布式、并行之类的词语就害怕,结果这个是 distributions,分布,就是一些表征分布的函数们

    签名:

    torch.distributions.categorical.Categorical(probs=None, 
                                                logits=None, 
                                                validate_args=None)
    
    • 1
    • 2
    • 3

    Creates a categorical distribution parameterized by either probs or logits (but not both).
    创建一个离散的类别分布,参数由 probslogits, 二者其一指定

    If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.


    如果 probs 是个一维的长度为K的张量,则每个元素是索引对应的类别的相对概率,将通过该概率进行采样
    probs 是N维张量,前N-1维只会被视为响应的 Batch, 这个不好翻译,直接看例子把

    第一个例子(官方Demo):

    >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
    >>> m.sample()  # 等概率的返回 0, 1, 2, 3
    
    • 1
    • 2

    第二个例子:

    probs = torch.FloatTensor([[0.05,0.1,0.85],[0.05,0.05,0.9]])
     
    dist = Categorical(probs)
    print(dist)
    # Categorical(probs: torch.Size([2, 3]))
     
    index = dist.sample()
    print(index.numpy())
    # 很大概率会是 [2 2]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    如果 probs 变量的 -1 维度上求和不为1,Categorical 内部也会帮你归一化,然后再 sample

    第二个例子:

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> probs.shape
    torch.Size([1, 2, 3])
    
    >>> dist = Categorical(probs)
    >>> index = dist.sample()
    >>> index.shape
    torch.Size([1, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    上边这个例子,用来理解这句话:

    If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.

    probs.shape[X, Y, N, s]index.shape[X, Y, N]Categorical 只在最后一维上采样


    最后补一句,Categorical对象,有.entropy() 方法,用来计算熵

    >>> probs = torch.FloatTensor([[[0.05,0.1,0.85],[0.05,0.05,0.9]]])
    >>> dist = Categorical(probs)
    >>> dist.entropy()
    tensor([[0.5182, 0.3944]])
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    mysql创建schema和用户
    ubantu数据库安装以及使用——mysql+redis
    CentOS 7 双网卡bond 网卡mac 相同的处理
    javascript二叉树相关的知识
    在线图片转文字怎么转?这种方法大家可以学会
    CSS布局 | flex布局
    安装shap-e(openai开源的3D模型生成框架)踩过的一些坑
    Linux 反弹shell
    【问题解决】蓝牙显示已配对,无法连接,蓝牙设备显示在其他设备中。
    20240405,数据类型,运算符,程序流程结构
  • 原文地址:https://blog.csdn.net/HaoZiHuang/article/details/126356668