• pytorch_lightning模型训练加速技巧与涨点技巧


    pytorch-lightning 是建立在pytorch之上的高层次模型接口。

    pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow.

    pytorch-lightning 有以下一些引人注目的功能:

    • 可以不必编写自定义循环,只要指定loss计算方法即可。

    • 可以通过callbacks非常方便地添加CheckPoint参数保存、early_stopping 等功能。

    • 可以非常方便地在单CPU、多CPU、单GPU、多GPU乃至多TPU上训练模型。

    • 可以通过调用torchmetrics库,非常方便地添加Accuracy,AUC,Precision等各种常用评估指标。

    • 可以非常方便地实施多批次梯度累加、半精度混合精度训练、最大batch_size自动搜索等技巧,加快训练过程。

    • 可以非常方便地使用SWA(随机参数平均)、CyclicLR(学习率周期性调度策略)与auto_lr_find(最优学习率发现)等技巧 实现模型涨点。

    一般按照如下方式 安装和 引入 pytorch-lightning 库。

    1. #安装
    2. pip install pytorch-lightning
    1. #引入
    2. import pytorch_lightning as pl

    顾名思义,它可以帮助我们漂亮(pl)地进行深度学习研究。😋😋 You do the research. Lightning will do everything else.⭐️⭐️

    参考文档:

    • pl_docs: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html

    • pl_template:https://github.com/PyTorchLightning/deep-learning-project-template

    • torchmetrics: https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html

    公众号后台回复关键词:pl,获取本文jupyter notebook源代码。

    一,pytorch-lightning的设计哲学

    pytorch-lightning 的核心设计哲学是将 深度学习项目中的 研究代码(定义模型) 和 工程代码 (训练模型) 相互分离。

    用户只需专注于研究代码(pl.LightningModule)的实现,而工程代码借助训练工具类(pl.Trainer)统一实现。

    更详细地说,深度学习项目代码可以分成如下4部分:

    • 研究代码 (Research code),用户继承LightningModule实现。

    • 工程代码 (Engineering code),用户无需关注通过调用Trainer实现。

    • 非必要代码 (Non-essential research code,logging, etc...),用户通过调用Callbacks实现。

    • 数据 (Data),用户通过torch.utils.data.DataLoader实现,也可以封装成pl.LightningDataModule。

    二,pytorch-lightning使用范例

    下面我们使用minist图片分类问题为例,演示pytorch-lightning的最佳实践。

    1,准备数据

    1. import torch 
    2. from torch import nn 
    3. from torchvision import transforms as T
    4. from torchvision.datasets import MNIST
    5. from torch.utils.data import DataLoader,random_split
    6. import pytorch_lightning as pl 
    7. from torchmetrics import Accuracy
    1. class MNISTDataModule(pl.LightningDataModule):
    2.     def __init__(self, data_dir: str = "./minist/"
    3.                  batch_size: int = 32,
    4.                  num_workers: int =4):
    5.         super().__init__()
    6.         self.data_dir = data_dir
    7.         self.batch_size = batch_size
    8.         self.num_workers = num_workers
    9.     def setup(self, stage = None):
    10.         transform = T.Compose([T.ToTensor()])
    11.         self.ds_test = MNIST(self.data_dir, train=False,transform=transform,download=True)
    12.         self.ds_predict = MNIST(self.data_dir, train=False,transform=transform,download=True)
    13.         ds_full = MNIST(self.data_dir, train=True,transform=transform,download=True)
    14.         self.ds_train, self.ds_val = random_split(ds_full, [550005000])
    15.     def train_dataloader(self):
    16.         return DataLoader(self.ds_train, batch_size=self.batch_size,
    17.                           shuffle=True, num_workers=self.num_workers,
    18.                           pin_memory=True)
    19.     def val_dataloader(self):
    20.         return DataLoader(self.ds_val, batch_size=self.batch_size,
    21.                           shuffle=False, num_workers=self.num_workers,
    22.                           pin_memory=True)
    23.     def test_dataloader(self):
    24.         return DataLoader(self.ds_test, batch_size=self.batch_size,
    25.                           shuffle=False, num_workers=self.num_workers,
    26.                           pin_memory=True)
    27.     def predict_dataloader(self):
    28.         return DataLoader(self.ds_predict, batch_size=self.batch_size,
    29.                           shuffle=False, num_workers=self.num_workers,
    30.                           pin_memory=True)
    1. data_mnist = MNISTDataModule()
    2. data_mnist.setup()
    1. for features,labels in data_mnist.train_dataloader():
    2.     print(features.shape)
    3.     print(labels.shape)
    4.     break
    1. torch.Size([3212828])
    2. torch.Size([32])

    2,定义模型

    1. net = nn.Sequential(
    2.     nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
    3.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    4.     nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
    5.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    6.     nn.Dropout2d(p = 0.1),
    7.     nn.AdaptiveMaxPool2d((1,1)),
    8.     nn.Flatten(),
    9.     nn.Linear(64,32),
    10.     nn.ReLU(),
    11.     nn.Linear(32,10)
    12. )
    13. class Model(pl.LightningModule):
    14.     
    15.     def __init__(self,net,learning_rate=1e-3):
    16.         super().__init__()
    17.         self.save_hyperparameters()
    18.         self.net = net
    19.         self.train_acc = Accuracy()
    20.         self.val_acc = Accuracy()
    21.         self.test_acc = Accuracy() 
    22.         
    23.         
    24.     def forward(self,x):
    25.         x = self.net(x)
    26.         return x
    27.     
    28.     
    29.     #定义loss
    30.     def training_step(self, batch, batch_idx):
    31.         x, y = batch
    32.         preds = self(x)
    33.         loss = nn.CrossEntropyLoss()(preds,y)
    34.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    35.     
    36.     #定义各种metrics
    37.     def training_step_end(self,outputs):
    38.         train_acc = self.train_acc(outputs['preds'], outputs['y']).item()    
    39.         self.log("train_acc",train_acc,prog_bar=True)
    40.         return {"loss":outputs["loss"].mean()}
    41.     
    42.     #定义optimizer,以及可选的lr_scheduler
    43.     def configure_optimizers(self):
    44.         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    45.     
    46.     def validation_step(self, batch, batch_idx):
    47.         x, y = batch
    48.         preds = self(x)
    49.         loss = nn.CrossEntropyLoss()(preds,y)
    50.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    51.     def validation_step_end(self,outputs):
    52.         val_acc = self.val_acc(outputs['preds'], outputs['y']).item()    
    53.         self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    54.         self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)
    55.     
    56.     def test_step(self, batch, batch_idx):
    57.         x, y = batch
    58.         preds = self(x)
    59.         loss = nn.CrossEntropyLoss()(preds,y)
    60.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    61.     
    62.     def test_step_end(self,outputs):
    63.         test_acc = self.test_acc(outputs['preds'], outputs['y']).item()    
    64.         self.log("test_acc",test_acc,on_epoch=True,on_step=False)
    65.         self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    66.     
    67. model = Model(net)
    68. #查看模型大小
    69. model_size = pl.utilities.memory.get_model_size_mb(model)
    70. print("model_size = {} M \n".format(model_size))
    71. model.example_input_array = [features]
    72. summary = pl.utilities.model_summary.ModelSummary(model,max_depth=-1)
    73. print(summary)
    1. model_size = 0.218447 M 
    2.    | Name      | Type              | Params | In sizes         | Out sizes       
    3. ---------------------------------------------------------------------------------------
    4. 0  | net       | Sequential        | 54.0 K | [3212828]  | [3210]        
    5. 1  | net.0     | Conv2d            | 320    | [3212828]  | [32322626]
    6. 2  | net.1     | MaxPool2d         | 0      | [32322626] | [32321313]
    7. 3  | net.2     | Conv2d            | 51.3 K | [32321313] | [326499]  
    8. 4  | net.3     | MaxPool2d         | 0      | [326499]   | [326444]  
    9. 5  | net.4     | Dropout2d         | 0      | [326444]   | [326444]  
    10. 6  | net.5     | AdaptiveMaxPool2d | 0      | [326444]   | [326411]  
    11. 7  | net.6     | Flatten           | 0      | [326411]   | [3264]        
    12. 8  | net.7     | Linear            | 2.1 K  | [3264]         | [3232]        
    13. 9  | net.8     | ReLU              | 0      | [3232]         | [3232]        
    14. 10 | net.9     | Linear            | 330    | [3232]         | [3210]        
    15. 11 | train_acc | Accuracy          | 0      | ?                | ?               
    16. 12 | val_acc   | Accuracy          | 0      | ?                | ?               
    17. 13 | test_acc  | Accuracy          | 0      | ?                | ?               
    18. ---------------------------------------------------------------------------------------
    19. 54.0 K    Trainable params
    20. 0         Non-trainable params
    21. 54.0 K    Total params
    22. 0.216     Total estimated model params size (MB)

    3,训练模型

    1. pl.seed_everything(1234)
    2. ckpt_callback = pl.callbacks.ModelCheckpoint(
    3.     monitor='val_loss',
    4.     save_top_k=1,
    5.     mode='min'
    6. )
    7. early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss',
    8.                patience=3,
    9.                mode = 'min')
    10. # gpus=0 则使用cpu训练,gpus=1则使用1个gpu训练,gpus=2则使用2个gpu训练,gpus=-1则使用所有gpu训练,
    11. # gpus=[0,1]则指定使用0号和1号gpu训练, gpus="0,1,2,3"则使用0,1,2,3号gpu训练
    12. # tpus=1 则使用1个tpu训练
    13. trainer = pl.Trainer(max_epochs=20,   
    14.      #gpus=0, #单CPU模式
    15.      gpus=0, #单GPU模式
    16.      #num_processes=4,strategy="ddp_find_unused_parameters_false", #多CPU(进程)模式
    17.      #gpus=[0,1,2,3],strategy="dp", #多GPU的DataParallel(速度提升效果一般)
    18.      #gpus=[0,1,2,3],strategy=“ddp_find_unused_parameters_false" #多GPU的DistributedDataParallel(速度提升效果好)
    19.      callbacks = [ckpt_callback,early_stopping],
    20.      profiler="simple") 
    21. #断点续训
    22. #trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')
    23. #训练模型
    24. trainer.fit(model,data_mnist)
    1. Epoch 8100%
    2. 1876/1876 [01:44<00:0017.93it/s, loss=0.0603, v_num=0, train_acc=1.000, val_acc=0.985]

    4,评估模型

    result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path='best')
    1. --------------------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9966545701026917'test_loss'0.010617421939969063}
    4. --------------------------------------------------------------------------------
    result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path='best')
    1. --------------------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9865999817848206'test_loss'0.042671505361795425}
    4. --------------------------------------------------------------------------------
    result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path='best')
    1. --------------------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.987500011920929'test_loss'0.047178059816360474}
    4. --------------------------------------------------------------------------------

    5,使用模型

    1. data,label = next(iter(data_module.test_dataloader()))
    2. model.eval()
    3. prediction = model(data)
    4. print(prediction)
    1. tensor([[-13.0112,  -2.8257,  -1.8588,  -3.6137,  -0.3307,  -5.4953-19.7282,
    2.           15.9651,  -8.0379,  -2.2925],
    3.         [ -6.0261,  -2.5480,  13.4140,  -5.5701-10.2049,  -6.4469,  -3.7119,
    4.           -6.0732,  -6.0826,  -7.7339],
    5.           ...
    6.         [-16.7028,  -4.9060,   0.4400,  24.4337-12.8793,   1.5085-17.9232,
    7.           -3.0839,   0.5491,   1.9846],
    8.         [ -5.0909,  10.1805,  -8.2528,  -9.2240,  -1.8044,  -4.0296,  -8.2297,
    9.           -3.1828,  -5.9361,  -4.8410]], grad_fn=)

    6,保存模型

    最优模型默认保存在 trainer.checkpoint_callback.best_model_path 的目录下,可以直接加载。

    1. print(trainer.checkpoint_callback.best_model_path)
    2. print(trainer.checkpoint_callback.best_model_score)
    1. lightning_logs/version_10/checkpoints/epoch=8-step=15470.ckpt
    2. tensor(0.0376, device='cuda:0')
    1. model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    2. trainer_clone = pl.Trainer(max_epochs=3,gpus=1
    3. result = trainer_clone.test(model_clone,data_module.test_dataloader())
    4. print(result)
    1. --------------------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9887999892234802'test_loss'0.03627564385533333}
    4. --------------------------------------------------------------------------------
    5. [{'test_acc'0.9887999892234802'test_loss'0.03627564385533333}]

    三,训练加速技巧

    下面重点介绍pytorch_lightning 模型训练加速的一些技巧。

    • 1,使用多进程读取数据(num_workers=4)

    • 2,使用锁业内存(pin_memory=True)

    • 3,使用加速器(gpus=4,strategy="ddp_find_unused_parameters_false")

    • 4,使用梯度累加(accumulate_grad_batches=6)

    • 5,使用半精度(precision=16,batch_size=2*batch_size)

    • 6,自动搜索最大batch_size(auto_scale_batch_size='binsearch')

    (注:过大的batch_size对模型学习是有害的。)

    详细原理,可以参考:

    https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html

    我们将训练代码封装成如下脚本形式,方便后面测试使用。

    1. %%writefile mnist_cnn.py
    2. import torch 
    3. from torch import nn 
    4. from argparse import ArgumentParser
    5. import torchvision 
    6. from torchvision import transforms as T
    7. from torchvision.datasets import MNIST
    8. from torch.utils.data import DataLoader,random_split
    9. import pytorch_lightning as pl
    10. from torchmetrics import Accuracy
    11. #================================================================================
    12. # 一,准备数据
    13. #================================================================================
    14. class MNISTDataModule(pl.LightningDataModule):
    15.     def __init__(self, data_dir: str = "./minist/"
    16.                  batch_size: int = 32,
    17.                  num_workers: int =4,
    18.                  pin_memory:bool =True):
    19.         super().__init__()
    20.         self.data_dir = data_dir
    21.         self.batch_size = batch_size
    22.         self.num_workers = num_workers
    23.         self.pin_memory = pin_memory
    24.     def setup(self, stage = None):
    25.         transform = T.Compose([T.ToTensor()])
    26.         self.ds_test = MNIST(self.data_dir, download=True,train=False,transform=transform)
    27.         self.ds_predict = MNIST(self.data_dir, download=True, train=False,transform=transform)
    28.         ds_full = MNIST(self.data_dir, download=True, train=True,transform=transform)
    29.         self.ds_train, self.ds_val = random_split(ds_full, [550005000])
    30.     def train_dataloader(self):
    31.         return DataLoader(self.ds_train, batch_size=self.batch_size,
    32.                           shuffle=True, num_workers=self.num_workers,
    33.                           pin_memory=self.pin_memory)
    34.     def val_dataloader(self):
    35.         return DataLoader(self.ds_val, batch_size=self.batch_size,
    36.                           shuffle=False, num_workers=self.num_workers,
    37.                           pin_memory=self.pin_memory)
    38.     def test_dataloader(self):
    39.         return DataLoader(self.ds_test, batch_size=self.batch_size,
    40.                           shuffle=False, num_workers=self.num_workers,
    41.                           pin_memory=self.pin_memory)
    42.     def predict_dataloader(self):
    43.         return DataLoader(self.ds_predict, batch_size=self.batch_size,
    44.                           shuffle=False, num_workers=self.num_workers,
    45.                           pin_memory=self.pin_memory)
    46.     
    47.     @staticmethod
    48.     def add_dataset_args(parent_parser):
    49.         parser = ArgumentParser(parents=[parent_parser], add_help=False)
    50.         parser.add_argument('--batch_size'type=intdefault=32)
    51.         parser.add_argument('--num_workers'type=intdefault=4)
    52.         parser.add_argument('--pin_memory'type=booldefault=True)
    53.         return parser
    54. #================================================================================
    55. # 二,定义模型
    56. #================================================================================
    57. net = nn.Sequential(
    58.     nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
    59.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    60.     nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
    61.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    62.     nn.Dropout2d(p = 0.1),
    63.     nn.AdaptiveMaxPool2d((1,1)),
    64.     nn.Flatten(),
    65.     nn.Linear(64,32),
    66.     nn.ReLU(),
    67.     nn.Linear(32,10)
    68. )
    69. class Model(pl.LightningModule):
    70.     
    71.     def __init__(self,net,learning_rate=1e-3):
    72.         super().__init__()
    73.         self.save_hyperparameters()
    74.         self.net = net
    75.         self.train_acc = Accuracy()
    76.         self.val_acc = Accuracy()
    77.         self.test_acc = Accuracy() 
    78.         
    79.         
    80.     def forward(self,x):
    81.         x = self.net(x)
    82.         return x
    83.     
    84.     
    85.     #定义loss
    86.     def training_step(self, batch, batch_idx):
    87.         x, y = batch
    88.         preds = self(x)
    89.         loss = nn.CrossEntropyLoss()(preds,y)
    90.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    91.     
    92.     #定义各种metrics
    93.     def training_step_end(self,outputs):
    94.         train_acc = self.train_acc(outputs['preds'], outputs['y']).item()    
    95.         self.log("train_acc",train_acc,prog_bar=True)
    96.         return {"loss":outputs["loss"].mean()}
    97.     
    98.     #定义optimizer,以及可选的lr_scheduler
    99.     def configure_optimizers(self):
    100.         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    101.     
    102.     def validation_step(self, batch, batch_idx):
    103.         x, y = batch
    104.         preds = self(x)
    105.         loss = nn.CrossEntropyLoss()(preds,y)
    106.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    107.     def validation_step_end(self,outputs):
    108.         val_acc = self.val_acc(outputs['preds'], outputs['y']).item()    
    109.         self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    110.         self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)
    111.     
    112.     def test_step(self, batch, batch_idx):
    113.         x, y = batch
    114.         preds = self(x)
    115.         loss = nn.CrossEntropyLoss()(preds,y)
    116.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    117.     
    118.     def test_step_end(self,outputs):
    119.         test_acc = self.test_acc(outputs['preds'], outputs['y']).item()    
    120.         self.log("test_acc",test_acc,on_epoch=True,on_step=False)
    121.         self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    122.     
    123.     @staticmethod
    124.     def add_model_args(parent_parser):
    125.         parser = ArgumentParser(parents=[parent_parser], add_help=False)
    126.         parser.add_argument('--learning_rate'type=float, default=1e-3)
    127.         return parser
    128.     
    129. #================================================================================
    130. # 三,训练模型
    131. #================================================================================
    132.     
    133. def main(hparams):
    134.     pl.seed_everything(1234)
    135.     
    136.     data_mnist = MNISTDataModule(batch_size=hparams.batch_size,
    137.                                  num_workers=hparams.num_workers)
    138.     
    139.     model = Model(net,learning_rate=hparams.learning_rate)
    140.     
    141.     ckpt_callback = pl.callbacks.ModelCheckpoint(
    142.         monitor='val_loss',
    143.         save_top_k=1,
    144.         mode='min'
    145.     )
    146.     early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss',
    147.                    patience=3,
    148.                    mode = 'min')
    149.     
    150.     trainer = pl.Trainer.from_argparse_args( 
    151.         hparams,
    152.         max_epochs=10,
    153.         
    154.         callbacks = [ckpt_callback,early_stopping]
    155.     ) 
    156.     
    157.     
    158.     if hparams.auto_scale_batch_size is not None:
    159.         #搜索不发生OOM的最大batch_size
    160.         max_batch_size = trainer.tuner.scale_batch_size(model,data_mnist,
    161.                         mode=hparams.auto_scale_batch_size)
    162.         data_mnist.batch_size = max_batch_size
    163.         
    164.         #等价于
    165.         #trainer.tune(model,data_mnist)
    166.         
    167.     
    168.     #gpus=0, #单CPU模式
    169.     #gpus=1, #单GPU模式
    170.     #num_processes=4,strategy="ddp_find_unused_parameters_false", #多CPU(进程)模式
    171.     #gpus=4,strategy="dp", #多GPU(dp速度提升效果一般)
    172.     #gpus=4,strategy=“ddp_find_unused_parameters_false" #多GPU(ddp速度提升效果好)
    173.     trainer.fit(model,data_mnist)
    174.     result = trainer.test(model,data_mnist,ckpt_path='best')
    175. if __name__ == "__main__":
    176.     parser = ArgumentParser()
    177.     parser = MNISTDataModule.add_dataset_args(parser)
    178.     parser = Model.add_model_args(parser)
    179.     parser = pl.Trainer.add_argparse_args(parser)
    180.     hparams = parser.parse_args()
    181.     main(hparams)

    1,使用多进程读取数据(num_workers=4)

    使用多进程读取数据,可以避免数据加载过程成为性能瓶颈。

    • 单进程读取数据(num_workers=0, gpus=1): 1min 18s

    • 多进程读取数据(num_workers=4, gpus=1): 59.7s

    1. %%time
    2. #单进程读取数据(num_workers=0)
    3. !python3 mnist_cnn.py --num_workers=0 --gpus=1
    1. ------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9857000112533569'test_loss'0.04885349050164223}
    4. --------------------------------------------------------------------------------
    5. CPU times: user 4.67 s, sys: 2.14 s, total: 6.81 s
    6. Wall time: 2min 50s
    1. %%time
    2. #多进程读取数据(num_workers=4)
    3. !python3 mnist_cnn.py --num_workers=4 --gpus=1
    1. ---------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9764000177383423'test_loss'0.0820135846734047}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00163.40it/s]
    6. CPU times: user 1.56 s, sys: 647 ms, total: 2.21 s
    7. Wall time: 59.7 s

    2,使用锁业内存(pin_memory=True)

    锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘)

    因此锁业内存比非锁业内存读写效率更高,copy到GPU上也更快速。

    当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。

    因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。

    • 非锁业内存存储数据(pin_memory=False, gpus=1): 1min

    • 锁业内存存储数据(pin_memory=True, gpus=1): 59.5s

    1. %%time
    2. #非锁业内存存储数据(pin_memory=False)
    3. !python3 mnist_cnn.py --pin_memory=False --gpus=1
    1. ----------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9812999963760376'test_loss'0.06231774762272835}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00171.69it/s]
    6. CPU times: user 1.59 s, sys: 619 ms, total: 2.21 s
    7. Wall time: 1min
    1. %%time
    2. #锁业内存存储数据(pin_memory=True)
    3. !python3 mnist_cnn.py --pin_memory=True --gpus=1
    1. ---------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9757999777793884'test_loss'0.08017424494028091}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00174.58it/s]
    6. CPU times: user 1.54 s, sys: 677 ms, total: 2.22 s
    7. Wall time: 59.5 s

    3,使用加速器(gpus=4,strategy="ddp_find_unused_parameters_false")

    pl 可以很方便地应用单CPU、多CPU、单GPU、多GPU乃至多TPU上训练模型。

    以下几种情况训练耗时统计如下:

    • 单CPU: 2min 17s

    • 单GPU:  59.4 s

    • 4个GPU(dp模式): 1min

    • 4个GPU(ddp模式): 38.9 s

    一般情况下,如果是单机多卡,建议使用 ddp模式,因为dp模式需要非常多的data和model传输,非常耗时。

    1. %%time
    2. #单CPU
    3. !python3 mnist_cnn.py --gpus=0
    1. -----------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9790999889373779'test_loss'0.07223792374134064}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|████████████████████████████████| 313/313 [00:05<00:0055.95it/s]
    6. CPU times: user 2.67 s, sys: 740 ms, total: 3.41 s
    7. Wall time: 2min 17s
    1. %%time
    2. #单GPU
    3. !python3 mnist_cnn.py --gpus=1
    1. ---------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9778000116348267'test_loss'0.06929327547550201}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00171.04it/s]
    6. CPU times: user 1.83 s, sys: 488 ms, total: 2.32 s
    7. Wall time: 1min 3s
    1. %%time
    2. #多GPU,dp模式(为公平比较,batch_size=32*4
    3. !python3 mnist_cnn.py --gpus=4 --strategy="dp" --batch_size=128
    1. ------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9790999889373779'test_loss'0.06855566054582596}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|██████████████████████████████████| 79/79 [00:02<00:0038.55it/s]
    6. CPU times: user 1.2 s, sys: 553 ms, total: 1.75 s
    7. Wall time: 1min
    1. %%time
    2. #多GPU,ddp模式
    3. !python3 mnist_cnn.py --gpus=4 --strategy="ddp_find_unused_parameters_false"
    1. ---------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9732000231742859'test_loss'0.08606339246034622}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:0085.79it/s]
    6. CPU times: user 784 ms, sys: 387 ms, total: 1.17 s
    7. Wall time: 38.9 s

    4,使用梯度累加(accumulate_grad_batches=6)

    梯度累加就是累加多个batch的梯度,然后用累加的梯度更新一次参数,使用梯度累加相当于增大batch_size.

    由于更新参数的计算量略大于简单梯度求和的计算量(对于大部分优化器而言),使用梯度累加会让速度略有提升。

    • 4个GPU(ddp模式): 38.9 s

    • 4个GPU(ddp模式)+梯度累加: 36.9 s

    1. %%time
    2. #多GPU,ddp模式, 考虑梯度累加
    3. !python3 mnist_cnn.py --accumulate_grad_batches=6 --gpus=4 --strategy="ddp_find_unused_parameters_false"
    1. ----------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9603000283241272'test_loss'0.1400066614151001}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:0089.10it/s]
    6. CPU times: user 749 ms, sys: 402 ms, total: 1.15 s
    7. Wall time: 36.9 s

    5,使用半精度(precision=16)

    通过precision可以设置 double (64), float (32), bfloat16 ("bf16"), half (16) 精度的训练。

    默认是float(32) 标准精度,bfloat16 ("bf16")是混合精度。

    如果选择 half(16) 半精度,并同时增大batch_size为原来2倍, 通常训练速度会提升3倍左右。

    1. %%time 
    2. #半精度
    3. !python3 mnist_cnn.py --precision=16 --batch_size=64 --gpus=1

    6,自动搜索最大batch_size(auto_scale_batch_size="power")

    !python3 mnist_cnn.py --auto_scale_batch_size="power"  --gpus=1

    四,训练涨分技巧

    pytorch_lightning 可以非常容易地支持以下训练涨分技巧:

    • SWA(随机参数平均): 调用pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging实现。

    • CyclicLR(学习率周期性调度策略): 设置 lr_scheduler 为 torch.optim.lr_scheduler.CyclicLR实现。

    • auto_lr_find最优学习率发现: 设置 pl.Trainer(auto_lr_find = True)实现。

    参考论文:

    • Cyclical Learning Rates for Training Neural Networks 【https://arxiv.org/pdf/1506.01186.pdf】

    • Averaging Weights Leads to Wider Optima and Better Generalization【https://arxiv.org/abs/1803.05407】

    我们将代码整理成如下形式,以便后续测试使用。

    1. %%writefile mnist_cnn.py
    2. import torch 
    3. from torch import nn 
    4. from argparse import ArgumentParser
    5. import numpy as np 
    6. import torchvision 
    7. from torchvision import transforms as T
    8. from torchvision.datasets import MNIST
    9. from torch.utils.data import DataLoader,random_split
    10. import pytorch_lightning as pl
    11. from torchmetrics import Accuracy
    12. #================================================================================
    13. # 一,准备数据
    14. #================================================================================
    15. class MNISTDataModule(pl.LightningDataModule):
    16.     def __init__(self, data_dir: str = "./minist/"
    17.                  batch_size: int = 32,
    18.                  num_workers: int =4,
    19.                  pin_memory:bool =True):
    20.         super().__init__()
    21.         self.data_dir = data_dir
    22.         self.batch_size = batch_size
    23.         self.num_workers = num_workers
    24.         self.pin_memory = pin_memory
    25.     def setup(self, stage = None):
    26.         transform = T.Compose([T.ToTensor()])
    27.         self.ds_test = MNIST(self.data_dir, download=True,train=False,transform=transform)
    28.         self.ds_predict = MNIST(self.data_dir, download=True, train=False,transform=transform)
    29.         ds_full = MNIST(self.data_dir, download=True, train=True,transform=transform)
    30.         ds_train, self.ds_val = random_split(ds_full, [590001000])
    31.         #为加速训练,随机取10000
    32.         indices = np.arange(59000)
    33.         np.random.shuffle(indices)
    34.         self.ds_train = torch.utils.data.dataset.Subset(
    35.             ds_train,indices = indices[:3000]) 
    36.     def train_dataloader(self):
    37.         return DataLoader(self.ds_train, batch_size=self.batch_size,
    38.                           shuffle=True, num_workers=self.num_workers,
    39.                           pin_memory=self.pin_memory)
    40.     def val_dataloader(self):
    41.         return DataLoader(self.ds_val, batch_size=self.batch_size,
    42.                           shuffle=False, num_workers=self.num_workers,
    43.                           pin_memory=self.pin_memory)
    44.     def test_dataloader(self):
    45.         return DataLoader(self.ds_test, batch_size=self.batch_size,
    46.                           shuffle=False, num_workers=self.num_workers,
    47.                           pin_memory=self.pin_memory)
    48.     def predict_dataloader(self):
    49.         return DataLoader(self.ds_predict, batch_size=self.batch_size,
    50.                           shuffle=False, num_workers=self.num_workers,
    51.                           pin_memory=self.pin_memory)
    52.     
    53.     @staticmethod
    54.     def add_dataset_args(parent_parser):
    55.         parser = ArgumentParser(parents=[parent_parser], add_help=False)
    56.         parser.add_argument('--batch_size'type=intdefault=32)
    57.         parser.add_argument('--num_workers'type=intdefault=8)
    58.         parser.add_argument('--pin_memory'type=booldefault=True)
    59.         return parser
    60. #================================================================================
    61. # 二,定义模型
    62. #================================================================================
    63. net = nn.Sequential(
    64.     nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
    65.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    66.     nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
    67.     nn.MaxPool2d(kernel_size = 2,stride = 2),
    68.     nn.Dropout2d(p = 0.1),
    69.     nn.AdaptiveMaxPool2d((1,1)),
    70.     nn.Flatten(),
    71.     nn.Linear(64,32),
    72.     nn.ReLU(),
    73.     nn.Linear(32,10)
    74. )
    75. class Model(pl.LightningModule):
    76.     
    77.     def __init__(self,net,
    78.                  learning_rate=1e-3,
    79.                  use_CyclicLR = False,
    80.                  epoch_size=500):
    81.         super().__init__()
    82.         self.save_hyperparameters() #自动创建self.hparams
    83.         self.net = net
    84.         self.train_acc = Accuracy()
    85.         self.val_acc = Accuracy()
    86.         self.test_acc = Accuracy() 
    87.         
    88.         
    89.     def forward(self,x):
    90.         x = self.net(x)
    91.         return x
    92.     
    93.     
    94.     #定义loss
    95.     def training_step(self, batch, batch_idx):
    96.         x, y = batch
    97.         preds = self(x)
    98.         loss = nn.CrossEntropyLoss()(preds,y)
    99.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    100.     
    101.     #定义各种metrics
    102.     def training_step_end(self,outputs):
    103.         train_acc = self.train_acc(outputs['preds'], outputs['y']).item()    
    104.         self.log("train_acc",train_acc,prog_bar=True)
    105.         return {"loss":outputs["loss"].mean()}
    106.     
    107.     #定义optimizer,以及可选的lr_scheduler
    108.     def configure_optimizers(self):
    109.         optimizer = torch.optim.RMSprop(self.parameters(), lr=self.hparams.learning_rate)
    110.         if not self.hparams.use_CyclicLR:
    111.             return optimizer 
    112.         max_lr = self.hparams.learning_rate
    113.         base_lr = max_lr/4.0
    114.         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
    115.             base_lr=base_lr,max_lr=max_lr,
    116.             step_size_up=5*self.hparams.epoch_size,cycle_momentum=False)
    117.         self.print("set lr = "+str(max_lr))
    118.         
    119.         return ([optimizer],[scheduler])
    120.     
    121.     def validation_step(self, batch, batch_idx):
    122.         x, y = batch
    123.         preds = self(x)
    124.         loss = nn.CrossEntropyLoss()(preds,y)
    125.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    126.     def validation_step_end(self,outputs):
    127.         val_acc = self.val_acc(outputs['preds'], outputs['y']).item()    
    128.         self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    129.         self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)
    130.     
    131.     def test_step(self, batch, batch_idx):
    132.         x, y = batch
    133.         preds = self(x)
    134.         loss = nn.CrossEntropyLoss()(preds,y)
    135.         return {"loss":loss,"preds":preds.detach(),"y":y.detach()}
    136.     
    137.     def test_step_end(self,outputs):
    138.         test_acc = self.test_acc(outputs['preds'], outputs['y']).item()    
    139.         self.log("test_acc",test_acc,on_epoch=True,on_step=False)
    140.         self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)
    141.     
    142.     @staticmethod
    143.     def add_model_args(parent_parser):
    144.         parser = ArgumentParser(parents=[parent_parser], add_help=False)
    145.         parser.add_argument('--learning_rate'type=float, default=7e-3)
    146.         parser.add_argument('--use_CyclicLR'type=booldefault=False)
    147.         return parser
    148.     
    149. #================================================================================
    150. # 三,训练模型
    151. #================================================================================
    152.     
    153. def main(hparams):
    154.     pl.seed_everything(1234)
    155.     
    156.     data_mnist = MNISTDataModule(batch_size=hparams.batch_size,
    157.                                  num_workers=hparams.num_workers)
    158.     data_mnist.setup()
    159.     epoch_size = len(data_mnist.ds_train)//data_mnist.batch_size
    160.     
    161.     model = Model(net,learning_rate=hparams.learning_rate,
    162.                   use_CyclicLR = hparams.use_CyclicLR,
    163.                   epoch_size=epoch_size)
    164.     
    165.     ckpt_callback = pl.callbacks.ModelCheckpoint(
    166.         monitor='val_acc',
    167.         save_top_k=3,
    168.         mode='max'
    169.     )
    170.     
    171.     early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_acc',
    172.                    patience=16,
    173.                    mode = 'max')
    174.     callbacks = [ckpt_callback,early_stopping]
    175.     if hparams.use_swa:
    176.         callbacks.append(pl.callbacks.StochasticWeightAveraging())
    177.         
    178.     trainer = pl.Trainer.from_argparse_args( 
    179.         hparams,
    180.         max_epochs=1000,
    181.         callbacks = callbacks) 
    182.     
    183.     print("hparams.auto_lr_find=",hparams.auto_lr_find)
    184.     if hparams.auto_lr_find:
    185.         
    186.         #搜索学习率范围
    187.         lr_finder = trainer.tuner.lr_find(model,
    188.           datamodule = data_mnist,
    189.           min_lr=1e-08,
    190.           max_lr=1,
    191.           num_training=100,
    192.           mode='exponential',
    193.           early_stop_threshold=4.0
    194.           )
    195.         lr_finder.plot() 
    196.         lr = lr_finder.suggestion()
    197.         model.hparams.learning_rate = lr 
    198.         print("suggest lr=",lr)
    199.         
    200.         del model 
    201.         
    202.         hparams.learning_rate = lr
    203.         model = Model(net,learning_rate=hparams.learning_rate,
    204.                   use_CyclicLR = hparams.use_CyclicLR,
    205.                   epoch_size=epoch_size)
    206.         
    207.         #等价于
    208.         #trainer.tune(model,data_mnist)
    209.         
    210.     trainer.fit(model,data_mnist)
    211.     train_result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path='best')
    212.     val_result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path='best')
    213.     test_result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path='best')
    214.     
    215.     print("train_result:\n")
    216.     print(train_result)
    217.     print("val_result:\n")
    218.     print(val_result)
    219.     print("test_result:\n")
    220.     print(test_result)
    221.     
    222.     
    223. if __name__ == "__main__":
    224.     parser = ArgumentParser()
    225.     parser.add_argument('--use_swa'default=False, type=bool)
    226.     parser = MNISTDataModule.add_dataset_args(parser)
    227.     parser = Model.add_model_args(parser)
    228.     parser = pl.Trainer.add_argparse_args(parser)
    229.     hparams = parser.parse_args()
    230.     main(hparams)

    1,SWA 随机权重平均 (pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging)

    • 平凡方式训练:test_acc = 0.9581000208854675

    • SWA随机权重:test_acc = 0.963100016117096

    1. #平凡方式训练
    2. !python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false"
    1. ------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9581000208854675'test_loss'0.14859822392463684}
    4. --------------------------------------------------------------------------------
    1. #使用SWA随机权重
    2. !python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false" --use_swa=True
    1. -----------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.963100016117096'test_loss'0.18146753311157227}
    4. --------------------------------------------------------------------------------

    2,CyclicLR学习率调度策略(torch.optim.lr_scheduler.CyclicLR)

    • 平凡方式训练:test_acc = 0.9581000208854675

    • SWA随机权重:test_acc = 0.963100016117096

    • SWA随机权重 + CyClicLR学习率调度策略: test_acc = 0.9688000082969666

    !python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false" --use_swa=True --use_CyclicLR=True
    1. ------------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9688000082969666'test_loss'0.11470437049865723}
    4. --------------------------------------------------------------------------------

    3, 最优学习率搜索(auto_lr_find=True)

    • 平凡方式训练:test_acc = 0.9581000208854675

    • SWA随机权重:test_acc = 0.963100016117096

    • SWA随机权重 + CyClicLR学习率调度策略: test_acc = 0.9688000082969666

    • SWA随机权重 + CyClicLR学习率调度策略 + 最优学习率搜索:test_acc = 0.9693999886512756

    !python3 mnist_cnn.py --gpus=1  --auto_lr_find=True --use_swa=True --use_CyclicLR=True
    1. ---------------------------------------------------------------
    2. DATALOADER:0 TEST RESULTS
    3. {'test_acc'0.9693999886512756'test_loss'0.11024412512779236}
    4. --------------------------------------------------------------------------------
    5. Testing: 100%|███████████████████████████████| 313/313 [00:02<00:00137.85it/s]

    以上。

    万水千山总是情,点个在看行不行?😋

    公众号后台回复关键词:pl,获取本文jupyter notebook源代码。

    04fcaeb92eb3c446af1697f3d653408b.png

  • 相关阅读:
    数智未来 持续创新 | 易趋受邀出席CIAS 2022中国数智汽车峰会
    关于视频超分辨率的学习-day1
    Python输入输出、遍历文件夹(input、stdin、os.path)
    盒子模型详解
    数据结构六:线性表之顺序栈的设计
    第三届 “鹏城杯”(初赛)
    Java专题训练——21天学习挑战赛
    v73.结构
    企业自己申报高企会遇到哪些问题,如何处理?
    win10 环境下Python 3.8按装fastapi paddlepaddle 进行身份证及营业执照的识别2
  • 原文地址:https://blog.csdn.net/Python_Ai_Road/article/details/126446693