• pytorch中nn.functional和nn.Module的区别


    在pytorch中有两个常用的模块,分别是nn.functionalnn.Module,这二者能够实现一些网络层的定义,对于nn中大多数的layer,在nn.functional中都有对应的函数算子

    这二者但也并非相同,还是存在区别的:

    • nn.Module中实现的layer是一个特殊的类,继承了nn.Module,例如 class Layer(nn.Module),对于继承了nn.Module的层结构,他们能够自动提取可学习参数,并且内部已经实现好了forward函数。
    • nn.functional实现的layer是一个函数,可以直接调用,不需要实例化,对于这些函数内部是没有可学习参数的。
    x = torch.randn(2, 3)
    
    y_1 = nn.ReLU()(x)
    y_2 = nn.functional.relu(x)
    
    • 1
    • 2
    • 3
    • 4

    那么为什么同样功能要设计两个接口呢?

    对于模型具有可学习参数,例如Conv2d、Linear等,最好使用nn.Module,因为继承了nn.Module能够自动提取可学习参数,也可以使用nn.functional来实现,但是这样较为复杂,需要自己手动设置参数Parameter然后传入。

    如果模型不具备可学习参数,例如ReLU、MaxPool2d等,使用nn.functional和nn.Module都可以。

    但是有特例,nn.Dropout,最好使用nn.Module,虽然它没有可学习参数,但是这个层有个特点就是训练和推理不同,如果使用nn.Module来实现,这时就可以使用model.train()和model.eval()来区分。

    x = torch.randn(10, 4)
    w = torch.randn(3, 4)
    b = torch.randn(1, 3)
    
    y = F.linear(x, w, b)
    print(y.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    需要手动定义参数Parameter来实现,还有一种写法:

    class Linear(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.randn(3, 4)
            self.b = torch.randn(1, 3)
        
        def forward(self, x):
            return F.linear(x, w, b)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    这两种写法是一样的,对于刚入门一般都是采用下面写法,利用定义好的层,采用自定义参数写法来说能够更灵活,能够实现更为复杂的操作。

    class Linear(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 3)
        
        def forward(self, x):
            return self.linear(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
  • 相关阅读:
    数据结构-二叉树的基本操作
    如何判断结构体是否相等?能否用 memcmp 函数判断结构体相等?
    Servlet
    LeetCode 0142.环形链表 II
    Java中什么是多态?多态的优势和劣势是什么?
    Redis 异常三连环
    主流开发语言和开发环境介绍
    学完 Fluent 官方基础教程,你离一名合格Fluent 流体工程师还有多远?
    【数据分析实战】kaggle项目:bike sharing demand
    vue2入门--->非单文件组件(html直接使用组件)
  • 原文地址:https://blog.csdn.net/m0_47256162/article/details/127829709