• PyTorch: Backpropagation and Optimizers


    When training a neural network, the most commonly used algorithm is backpropagation.

    To support backpropagation, PyTorch ships with a built-in automatic differentiation engine called torch.autograd.

    import torch
    
    x = torch.ones(5)  # input tensor
    y = torch.zeros(3)  # expected output
    w = torch.randn(5, 3, requires_grad=True)  # requires_grad=True tells autograd to track operations on this tensor
    b = torch.randn(3, requires_grad=True)  # same here: gradients will be accumulated in b.grad after backward()
    z = torch.matmul(x, w) + b
    # print(f"x:{x}\ny:{y}\nw:{w}\nb:{b}\nz:{z}\n")
    
    loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
    # print(loss)
    
    print(f"Gradient function for z = {z.grad_fn}")
    print(f"Gradient function for loss = {loss.grad_fn}")
    
    loss.backward()
    print(w.grad)
    print(b.grad)
    
    # How do we turn gradient tracking off, and why would we? Once the network is trained and we only run inference, tracking is no longer needed
    z = torch.matmul(x, w) + b
    print(z.requires_grad)
    
    with torch.no_grad():
        z = torch.matmul(x, w) + b
    print(z.requires_grad)
    
    # Alternatively, call detach() on the tensor itself
    z = torch.matmul(x, w) + b
    z_det = z.detach()
    print(z_det.requires_grad)
    
    inp = torch.eye(5, requires_grad=True)
    out = (inp+1).pow(2)
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"First call\n{inp.grad}")
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"\nSecond call\n{inp.grad}")
    inp.grad.zero_()
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"\nCall after zeroing gradients\n{inp.grad}")
    
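    Where do those gradient values come from? For binary_cross_entropy_with_logits with the default mean reduction, the gradient of the loss with respect to the logits is (sigmoid(z) - y) / len(z), and by the chain rule w.grad is the outer product of x with that vector. The following sketch is my own addition (reusing the x, y, w, b defined above) and simply checks autograd's result against this analytic formula:

    # Verification sketch (assumes x, y, w, b and the earlier loss.backward() call above)
    with torch.no_grad():
        z_check = torch.matmul(x, w) + b                   # recompute the logits
        grad_z = (torch.sigmoid(z_check) - y) / y.numel()  # dloss/dz for mean-reduced BCE-with-logits
        manual_w_grad = torch.outer(x, grad_z)             # dloss/dw, shape (5, 3)
        manual_b_grad = grad_z                             # dloss/db
    print(torch.allclose(w.grad, manual_w_grad))  # expected: True
    print(torch.allclose(b.grad, manual_b_grad))  # expected: True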

    First we need to build a neural network and load some training data.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor, Lambda
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )
    
    train_dataloader = DataLoader(training_data, batch_size=64)
    test_dataloader = DataLoader(test_data, batch_size=64)
    
    
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )
    
        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits
    
    
    model = NeuralNetwork().to(device)
    print(model)
    
    # learning rate
    learning_rate = 1e-3
    # number of samples per batch
    batch_size = 64
    # number of training epochs
    epochs = 5
    
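    Before writing the training loop, it can help to sanity-check the shapes flowing through the pipeline. This quick sketch is my own addition; it pulls one batch from the train_dataloader defined above and runs it through the model:

    # Sketch: inspect one batch and the model's output shape
    X, y = next(iter(train_dataloader))
    print(X.shape)       # torch.Size([64, 1, 28, 28]) - 64 grayscale 28x28 images
    print(y.shape)       # torch.Size([64])            - 64 class labels
    logits = model(X.to(device))
    print(logits.shape)  # torch.Size([64, 10])        - raw scores for the 10 classes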

    Common loss functions include nn.MSELoss (mean squared error),

    nn.NLLLoss (negative log-likelihood, used for classification),

    and nn.CrossEntropyLoss (which combines nn.LogSoftmax and nn.NLLLoss).
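
    To make that relationship concrete, here is a small sketch of my own (not from the original post) showing that nn.CrossEntropyLoss applied to raw logits gives the same value as nn.LogSoftmax followed by nn.NLLLoss:

    # Equivalence sketch: CrossEntropyLoss == LogSoftmax + NLLLoss
    import torch
    from torch import nn

    logits = torch.randn(4, 10)            # 4 samples, 10 classes
    targets = torch.tensor([1, 0, 3, 9])   # ground-truth class indices

    ce = nn.CrossEntropyLoss()(logits, targets)
    nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
    print(torch.allclose(ce, nll))  # expected: True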

    Optimizers

    Optimization is the process of adjusting model parameters at each training step to reduce the model's error; the optimization algorithm defines how that adjustment is done. Here we use stochastic gradient descent (SGD). Inside the training loop below, optimization happens in three steps: optimizer.zero_grad() resets the accumulated gradients, loss.backward() computes fresh gradients, and optimizer.step() updates the parameters.

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    def train_loop(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        for batch, (X, y) in enumerate(dataloader):
            X = X.to(device)
            y = y.to(device)
            # Compute prediction and loss
            pred = model(X)
            loss = loss_fn(pred, y)
    
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    
    
    def test_loop(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        test_loss, correct = 0, 0
    
        with torch.no_grad():
            for X, y in dataloader:
                X = X.to(device)
                y = y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
    # 损失函数
    loss_fn = nn.CrossEntropyLoss()
    # 优化器
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print("Done!")
    
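    After training, the earlier point about disabling gradient tracking applies: inference should run under torch.no_grad(). This short sketch is my own addition; it classifies one test image with the trained model, using the class-name list that torchvision's FashionMNIST dataset exposes:

    # Sketch: predict a single test sample without tracking gradients
    classes = test_data.classes  # FashionMNIST class names
    model.eval()
    X, y = test_data[0]
    with torch.no_grad():
        pred = model(X.unsqueeze(0).to(device))  # add a batch dimension
        predicted = classes[pred.argmax(1).item()]
    print(f"Predicted: {predicted}, Actual: {classes[y]}")
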
  • Original article: https://blog.csdn.net/weixin_43903639/article/details/126907969