• Pytorch:张量的索引操作


    参考神君

    一、索引

    1.使用整数索引访问单个元素

    import torch
    x=torch.tensor([[1,2,3],[4,5,6]])
    lst = [[1,2,3],[4,5,6]]
    print(x[0][1])#等价于x[0,1]
    print(lst[0][1])
    
    • 1
    • 2
    • 3
    • 4
    • 5

    输出:

    tensor(2)
    2
    
    • 1
    • 2

    2.使用多个整数索引访问多个元素*

    list没有这种索引方式。这种索引方式,通过在索引中,给定多个列表或多个张量,索引按这个列表或张量的次序访问元素。

    a.示例详解

    考虑以下的 3x3 张量:

    1 2 3
    4 5 6
    7 8 9
    
    • 1
    • 2
    • 3

    假设我们想从中选择元素 26。这两个元素分别位于:

    • 2 在第0行第1列
    • 6 在第1行第2列

    要使用 fancy indexing 来选择这些元素,可以做如下操作:

    import torch
    
    # 创建一个 3x3 的张量
    tensor = torch.tensor([[1, 2, 3],
                           [4, 5, 6],
                           [7, 8, 9]])
    
    # 定义行索引和列索引
    rows = torch.tensor([0, 1])
    columns = torch.tensor([1, 2])
    
    # 使用行索引和列索引进行选择
    selected_elements = tensor[rows, columns]
    
    print(selected_elements)  # 输出 tensor([2, 6])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在上面的例子中,rowscolumns 是两个张量,分别指定了要访问的行和列的索引。当你传递 rowscolumns 到原始张量 tensor 时,PyTorch 会解释这样的索引方式:

    • 对于每对 (row, column) 索引,选择对应的元素。
    • rows[0]columns[0] 组合指定了元素 tensor[0, 1](即 2)。
    • rows[1]columns[1] 组合指定了元素 tensor[1, 2](即 6)。

    这种索引非常灵活,允许你从张量中快速选择一个不规则的元素集合。这对于机器学习和数据预处理任务尤其有用,因为经常需要从数据集中提取特定的样本或特征。

    b.示例

    import torch
    x=torch.tensor([[1,2,3],[4,5,6]])
    print(x[[0,1],[1,1]])
    
    • 1
    • 2
    • 3

    输出:

    tensor([2, 5])#x[0][1]和x[1][1]
    
    • 1

    当我们只给定一个维度的时候,也是一样的。

    y=torch.tensor([1,2,3,4,5,6])
    print(y[[0,1]])
    
    • 1
    • 2

    输出:

    tensor([1, 2]) # y[0]和y[1]
    
    • 1

    c.一维索引示例

    import torch
    x=torch.tensor([[1,2,3],[4,5,6],[4,5,6],[4,5,6],[4,5,6],[4,5,6],[4,5,6]])
    print(x[[0,1,5,4,2,6,3]])
    
    • 1
    • 2
    • 3

    3.使用负数索引从张量的末尾开始计数

    如果某个维度的索引是-1,则在该维取末尾元素。

    import torch
    x=torch.tensor([[1,2,3],[4,5,6]])
    print(x[[0,1],[1,-1]])
    
    • 1
    • 2
    • 3

    输出:

    tensor([2, 6])
    
    • 1

    4.使用布尔索引访问满足条件的元素*

    list没有这种索引方式。

      高级索引是一种在 PyTorch 和 NumPy 中常用的索引方法,它允许你从数组或张量中选择复杂的、非连续的数据子集。高级索引可以通过传递整数数组、张量或列表来实现,而这些索引方式相比基本的切片提供了更大的灵活性。在 PyTorch 中,使用高级索引时,索引操作的结果通常会形成一个新的张量,不与原始数据共享内存。

    高级索引的几种常见形式:

    1. 整数数组索引
      使用整数数组进行索引时,你可以指定要访问的每个维度上的索引位置。这种方式可以从张量中选择任意位置的数据,而这些数据可以是非连续和非规则的。

    2. 布尔(掩码)索引
      布尔索引允许你使用布尔数组(通常是逻辑条件的结果)来选择张量的元素。这种方法非常适用于基于条件的筛选。

    a.张量的元素级布尔操作

    在 PyTorch(以及其他类似的库,如 NumPy)中,使用张量进行布尔表达式的操作本质上是一种称为“元素级”或“元素对元素”的操作。当你在布尔表达式中使用张量时,PyTorch 会自动应用广播和矢量化操作,使得表达式能够逐元素地计算结果。这种处理方式使得代码不仅可读性好,而且效率高,非常适合科学计算和机器学习任务。

    元素级布尔操作:

    在 PyTorch 中,布尔操作(例如比较操作 <, >, <=, >=, ==, !=)都是元素级的。这意味着每个操作都是在输入张量的对应元素间独立进行的,并生成一个布尔类型的张量,其中的每个元素都是单个比较的结果。

    原理解释:

    1. 矢量化
      矢量化是指使用优化的库例程一次处理整个数组(或张量),而不是在Python层面上使用循环处理数组的每个元素。这减少了循环的开销,提高了执行速度,尤其是在底层使用如C/C++等编译语言实现的情况下。

    2. 广播
      广播是一种灵活处理不同形状张量的方法。当在两个不同大小的张量上进行操作时,较小的张量会自动“扩展”其维度以匹配较大张量的形状。例如,如果你有一个形状为(3,1)的张量和一个形状为(1,4)的张量,那么在操作中,每个张量都会广播到(3,4)以进行元素级操作。

    示例:

    考虑两个张量 AB,大小分别为 (3,3)(3,)

    import torch
    
    A = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
    B = torch.tensor([2, 2, 2])
    
    # 元素级比较
    C = A > B
    print(C)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    输出:

    tensor([[False, False,  True],
            [ True,  True,  True],
            [ True,  True,  True]])
    
    • 1
    • 2
    • 3

    在这里,B 被广播到与 A 相同的形状 (3,3),每个元素的比较都是独立进行的。

    b.布尔索引示例

    布尔索引可以用来选择满足特定条件的元素。例如,选择张量中所有大于5的元素:

    # 创建一个 3x4 的张量
    tensor = torch.tensor([[1, 2, 3, 4],
                           [5, 6, 7, 8],
                           [9, 10, 11, 12]])
    
    # 布尔索引
    mask = tensor > 5
    selected_elements = tensor[mask]
    print(mask)
    print(selected_elements)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    输出:

    tensor([[False, False, False, False],
            [False,  True,  True,  True],
            [ True,  True,  True,  True]])
    tensor([ 6,  7,  8,  9, 10, 11, 12])
    
    • 1
    • 2
    • 3
    • 4

    5.torch.where()函数根据条件选择元素

    torch.where() 是一个非常有用的函数,在 PyTorch 中用于根据条件从两个张量中选择元素。该函数的工作原理相似于 NumPy 中的 np.where(),允许在条件为真时从一个张量选择元素,条件为假时从另一个张量选择元素。

    a.函数原型

    torch.where() 函数的基本语法如下:

    torch.where(condition, x, y)
    
    • 1
    • condition: 一个布尔张量,其中的每个元素都对应于 xy 张量中相应位置的条件检查。
    • x: 当条件为真时将从这个张量中选择元素。
    • y:当条件为假时将从这个张量中选择元素。

    结果张量的每个位置会根据 condition 张量在相应位置的值是真还是假,从 xy 张量中选择值。


    torch.where(condition)
    
    • 1
    • 当仅提供 condition参数时,此函数返回满足条件的元素的索引。这可以用于找出满足特定条件的所有元素的位置。返回的是一个元组,其中每个元素是一个张量,分别代表满足条件的元素在各个维度上的索引。
    # 创建温度张量
    temperatures = torch.tensor([-5, 13, -2, 8, -1])
    condition = temperatures < 0
    indexs = torch.where(condition)
    corrected_temperatures = temperatures[indexs]
    print(corrected_temperatures)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    输出:

    tensor([-5, -2, -1])
    
    • 1

    b.示例

    假设有一个张量代表温度值,想要将所有低于零度的值设置为零:

    # 创建温度张量
    temperatures = torch.tensor([-5, 13, -2, 8, -1])
    
    # 使用 torch.where() 来修正负值
    # 注意,这里的temperatures<0是之前提到的元素级布尔操作,生成的是tensor([True,False,True,False,True])
    corrected_temperatures = torch.where(temperatures < 0, torch.tensor(0), temperatures)
    
    print(corrected_temperatures)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    or:

    # 创建温度张量
    temperatures = torch.tensor([-5, 13, -2, 8, -1])
    condition = temperatures < 0
    indexs = torch.where(condition)
    for i in indexs:
    	temperatures[i]=0
    print(temperatures)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    tensor([ 0, 13,  0,  8,  0])
    
    • 1

    import torch
    A = torch.tensor([5,1,2])
    B = torch.tensor([2,4,7])
    C = torch.where(A<B,A,B)
    print(C)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    输出:

    tensor([2, 1, 2])
    
    • 1

    6. torch.take()函数按索引从张量中选择元素

    在 PyTorch 中,.take() 函数是一个实用的张量操作,它用于从输入张量中按照指定的索引来提取元素。这个函数允许你将输入张量视为一维张量,并使用一维索引从中选择元素。这种方式特别适用于从多维张量中按特定顺序选择元素,而不必担心张量的原始维度。

    a. .take() 函数的基本用法

    函数如下:

    tensor.take(indices)  or torch.take(tensor,indices)
    
    • 1
    • indices:一个包含要提取的元素索引的一维张量。

    这个函数按照 indices 提供的索引从输入张量中取出元素。索引假设输入张量是一维的,并按照行优先(C样式)顺序展开。

    b.示例:

    假设你有一个二维张量,想根据特定的索引列表从中选择元素。下面是如何使用 .take() 来实现这一点的示例:

    import torch
    
    # 创建一个二维张量
    tensor = torch.tensor([[1, 2, 3],
                           [4, 5, 6],
                           [7, 8, 9]])
    
    # 定义一维索引张量
    indices = torch.tensor([0, 4, 8])
    
    # 使用.take()选择元素
    selected_elements = tensor.take(indices)
    
    print(selected_elements)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在这个例子中,selected_elements 将包含由 indices 指定的位置的元素。输出将是:

    tensor([1, 5, 9])
    
    • 1

    这里的索引 0, 4, 8 分别对应张量展开后的第1,第5,第9个元素(考虑到从零开始索引)。

  • 相关阅读:
    基于eXosip2实现的客户端和服务端
    linux-checklist命令行
    吐血整理,最全Pytest自动化测试框架快速上手(超详细)
    y135.第七章 服务网格与治理-Istio从入门到精通 -- 网格和SSO(二一)
    《web课程设计》 基于HTML+CSS+JavaScript实现中国水墨风的小学学校网站模板(6个网页)
    图的遍历概述
    javax.net.ssl.SSLException: Unrecognized SSL message, plaintext connection
    Day8 尚硅谷JUC——JUC概述
    git and svn 行尾风格配置强制为lf
    从零开始的Hadoop学习(六)| HDFS读写流程、NN和2NN工作机制、DataNode工作机制
  • 原文地址:https://blog.csdn.net/m0_63997099/article/details/137906359