• 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数



    前言

    本篇作为后期文章“特征融合”的基础。
    特征融合分早融合和晚融合,早融合里的重要手段是concat和add

    一、torch.cat()函数 拼接只存在h,w(高,宽)的图像

    torch.cat()可以将多个张量合并为一个张量,我们接下来从简单到复杂一点点来盘这个函数

    我们首先随机生成两个形状一致的张量:

    import torch
    A =torch.rand(3,2)  #单通道,高为3.宽为2的张量
    B=torch.rand(3,3)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述

    让这个张量在第0维度进行拼接,也就是在高这个维度进行拼接:

    C=torch.cat((A,B),dim=0)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3

    在这里插入图片描述
    可以看到高变成了3+3,宽不变

    让这个张量在第1维度进行拼接,也就是在宽这个维度进行拼接:

    C=torch.cat((A,B),dim=1)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3

    在这里插入图片描述
    可以看到,高不变,宽变成了2+2

    在第0维度拼接时,高可以不一样,但是宽需要一致,不然会报错:

    import torch
    A =torch.rand(3,3)  #单通道,高为3.宽为2的张量
    B=torch.rand(4,3)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    C=torch.cat((A,B),dim=0)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    不报错:
    在这里插入图片描述

    import torch
    A =torch.rand(3,3)  #单通道,高为3.宽为2的张量
    B=torch.rand(3,5)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    C=torch.cat((A,B),dim=0)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    直接报错:
    在这里插入图片描述
    在第1维度拼接时,高必须一致,宽可以不一样,不然会报错:

    import torch
    A =torch.rand(3,3)  #单通道,高为3.宽为2的张量
    B=torch.rand(3,5)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    C=torch.cat((A,B),dim=1)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    不报错:
    在这里插入图片描述

    import torch
    A =torch.rand(3,3)  #单通道,高为3.宽为2的张量
    B=torch.rand(4,3)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    C=torch.cat((A,B),dim=1)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    在这里插入图片描述

    二、torch.cat() 拼接存在c,h,w(通道,高,宽)的图像

    我们随机生成两个3通道的2X2图像

    import torch
    A =torch.rand(3,2,2)  #单通道,高为3.宽为2的张量
    B=torch.rand(3,2,2)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述
    在这里插入图片描述

    让他们在第0维度进行拼接(通道维度拼接):
    在这里插入图片描述
    相当于通道数堆叠了,变成了六个通道

    让他们在第1维度进行拼接(高维度拼接):
    在这里插入图片描述
    让他们在第2维度进行拼接(宽维度拼接):
    在这里插入图片描述
    这两个堆叠结果就和之前的方法一样了

    三、torch.add()使张量对应元素直接相加

    import torch
    A =torch.rand(3,2,2)  #单通道,高为3.宽为2的张量
    B=torch.rand(3,2,2)   #单通道,高为2.宽为3的张量
    print(A)
    print(B)
    C=torch.add(A,B)
    print(C)
    print(C.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    张量A:
    在这里插入图片描述
    张量B:
    在这里插入图片描述
    相加后张量:
    在这里插入图片描述
    当然也可以不用add(A,B) 用A+B

  • 相关阅读:
    [springboot源码分析]-Conditional
    十六、Java 数组
    开源组件 | 一款好用的小程序生成图片库
    Java序列化和Json格式的转化
    AndroidStudio中虚拟机(AVD)无法启动,出现unable to locate adb错误
    SQLite 3.43.0 发布,又有啥新功能?
    [Servlet 4]Bean与DAO设计模式
    红外特征吸收峰特征总结(主要基团的红外特征吸收峰)
    Python工程师Java之路(p)Module和Package
    mybatis
  • 原文地址:https://blog.csdn.net/weixin_46274756/article/details/127870206