• 【看不懂才怪系列】一套通俗的基于Pytorch的网络训练代码模板


    俗话说得好:“如果看不懂别人的代码,那一定是别人的代码写得不好!”
    既然公开了代码就要对代码的可读性和正确性负责,但是看了很多源码,总是会遇到一些坑,比如:环境配置没讲清楚 -> (好不容易跑通了) 效果并没有论文中说的那么好-> (想分析代码吧)代码写得乱七八糟看不懂。。。

    为了让小白能够快速上手,代码肯定需要由总到分叙述,就像讲故事一样,循序渐进才能够一目了然。下面的模板是我根据自己的理解总结的网络训练最基本的模板,大家根据自己需要再添加:

    1. 构造主函数
    if __name__ == "__main__":
        # 使用GPU训练
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        
        # 模型训练的基本预设参数
        batch_size = 16 # Batch Size
        num_epochs = 100 # 训练迭代次数
        learning_rate = 0.0001 # 学习率
        root = 'data/train' # 数据集位置
        
        ##### 导入数据集
     	train_loader, val_loader = get_loader(root, batch_size, shuffle=True)
    
        # 搭建模型框架
        model = Model(device).to(device)
        
        # 开始训练
        train(device, model, num_epochs, learning_rate, train_loader, val_loader)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    下面将具体讲解:

    使用GPU训练
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    • 1
    模型训练的基本预设参数
        batch_size = 16 # Batch Size
        num_epochs = 100 # 训练迭代次数
        learning_rate = 0.0001 # 学习率
        root = 'data/train' # 数据集位置
    
    • 1
    • 2
    • 3
    • 4
    导入数据集
     train_loader, val_loader = get_loader(root, batch_size, shuffle=True)
    
    • 1
    搭建模型框架
        model = Model(device).to(device)
    
    • 1
    开始训练
        train(device, model, num_epochs, learning_rate, train_loader, val_loader)
    
    • 1
    2.1 模型搭建

    编码器和解码器

    class Model(nn.Module):
        def __init__(self, device):
            super(Model, self).__init__()#super的目的在于继承nn.Module并使用__init__初始化了nn.Module里的参数
            self.encoder = Encoder(device) # 编码器
            self.decoder = Decoder(device) # 解码器
    
        def forward(self, x):
            x = x.float()
            p1 = self.encoder(x)
            p2 = self.decoder(p1)
            return p2
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    2.2 开始训练
    def train(device, model, num_epochs, learning_rate, train_loader, val_loader):
        print('start training ...........')
        #优化器设置
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        #损失函数设置
        criterion = LossCalculation(device)
        train_losses, val_losses = [], []
        
        #迭代训练
        for epoch in range(epochs):
        	#训练集
            train_epoch_loss = fit(epoch, model, optimizer, criterion, device, train_loader, phase='training')
            #测试集
            val_epoch_loss = fit(epoch, model, optimizer, criterion, device, val_loader, phase='validation')
            print('-----------------------------------------')
    		
    		#保存最优的训练参数
            if epoch == 0 or val_epoch_loss <= np.min(val_losses):
                torch.save(model.state_dict(), 'output/weight.pth')
    		
    		#保存训练和测试的loss结果,为画图做准备
            train_losses.append(train_epoch_loss)
            val_losses.append(val_epoch_loss)
    
    		#绘制结果图
            write_figures('output', train_losses, val_losses)
            write_log('output', epoch, train_epoch_loss, val_epoch_loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    2.2.1 优化器设置

    根据自己需求设置Adam或SGD,学习率可以设置固定或自适应学习率

     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    • 1
    2.2.2 损失函数设置
        criterion = LossCalculation(device)
    
    • 1
    class LossCalculation(nn.Module):
        def __init__(self, device):
            super(LossCalculation, self).__init__()
    
        def forward(self, inputs, outputs, targets):
            #inputs, outputs, targets 分别是输入图像、预测图像和label
            batch_size, _, width, height = outputs.shape
            total_loss = 0.0 # 一个epoch的损失函数计算
            
            for b in range(batch_size):
            	#根据自己需要添加
            	total_loss += xxx
    
            return total_loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    2.2.3 迭代训练

    这里面最重要的就是自定义的fit函数

    def fit(epoch, model, optimizer, criterion, device, data_loader, phase='training'):
        if phase == 'training':
        	#启用一些特定的层(BN,Dropout),设置为训练状态
            model.train()
        else:
        	#禁用一些特定的层(BN,Dropout),设置为测试状态
            model.eval()
    
        running_loss = 0
    
        for inputs, targets in tqdm(data_loader):
        	#我们使用gpu训练,那图像也必须输入到显存中
            inputs = inputs.to(device)
            targets = targets.to(device)
    
            if phase == 'training':
            	#每一batch图像的梯度初始化为零,把loss涉及的权重的导数变成0
                optimizer.zero_grad()
                #输出预测图像
                outputs = model(inputs)
            else:
            	#强制之后的内容不进行计算图构建及梯度计算
                with torch.no_grad():
                    outputs = model(inputs)
    
            # 计算一个epoch中的一个batch的loss
            loss = criterion(inputs, outputs, targets, separate_loss=False)
            #累加一个epoch中的loss
            running_loss += loss.item()
    
            if phase == 'training':
            	#loss回传
                loss.backward()
                # 更新所有的参数
                optimizer.step()
                
    	#计算一个epoch的loss
        epoch_loss = running_loss / len(data_loader.dataset)#除以训练集或测试集图片数目
    	
    	#打印一个epoch的训练结果
        print('[%d][%s] loss: %.4f' % (epoch, phase, epoch_loss))
        
        return epoch_loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    完结,相互交流!
  • 相关阅读:
    「OSS中间件系列」Minio的文件服务的存储模型及整合Java客户端访问的实战指南
    SQL Server中row_number函数用法介绍
    【JVM笔记】导出内存映像(dump)文件
    Django——视图层
    vue+element-ui el-descriptions 详情渲染组件二次封装(Vue项目)
    js变量的声明带var与不带的区别
    【玩儿】Win 11 安装安卓子系统
    数仓学习之DWD学习
    设计模式:干掉if else的几种方法
    数据库基础知识(面试)
  • 原文地址:https://blog.csdn.net/qq_41598072/article/details/128039202