torch.no_grad()方法就像一个循环,其中循环中的每个张量都将requires_grad设置为False。这意味着,当前与当前计算图相连的具有梯度的张量现在与当前图分离了我们将不再能够计算关于该张量的梯度。直到张量在循环内,它才与当前图分离。一旦用梯度定义的张量脱离了循环,它就会再次附着到当前图上。此方法禁用梯度计算,从而减少计算的内存消耗。
示例:在这个例子中,我们将用requires_grad=true定义一个张量a,然后我们将使用张量a在torch.no_grad()中定义一个函数B。现在张量a在循环中,所以requires_grad被设置为false。
- # Python3
- # import necessary libraries
- import torch
-
- # define a tensor
- A = torch.tensor(1., requires_grad=True)
- print("Tensor-A:", A)
-
- # define a function using A tensor
- # inside loop
- with torch.no_grad():
- B = A + 1
- print("B:-", B)
-
- # check gradient
- print("B.requires_grad=", B.requires_grad)
OUTPUT
- Tensor-A: tensor(1., requires_grad=True)
- B:- tensor(2.)
- B.requires_grad= False