• Testing the sparse-to-dense depth completion model


    Original link:

    GitHub - fangchangma/sparse-to-dense.pytorch: ICRA 2018 "Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image" (PyTorch Implementation) — https://github.com/fangchangma/sparse-to-dense.pytorch

    1. Data description

    In the sample images, the first column is the original scene image, the second is the sparse depth, the third is the dense depth, and the fourth is the model's prediction.

    The model has three training modes: (1) the first-column RGB image as input, with the third column as the label; (2) the first and second columns merged into 4-channel data as input, with the third column as the label; (3) the second column as input, with the third column as the label.

    The second column is obtained by sampling the third. Two sampling strategies are implemented in the dataloaders/dense_to_sparse.py script, so if you want to run your own dataset you only need to prepare the first- and third-column data. A minimal sampling sketch follows.
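    For intuition, here is a minimal, self-contained sketch of what the sparsifier does (uniform random sampling of the dense depth) and of how rgbd mode stacks the result onto the RGB image as a fourth channel. The helper name uniform_sparse_depth is illustrative only; the repo's real implementations are UniformSampling.dense_to_sparse in dataloaders/dense_to_sparse.py and MyDataloader.create_rgbd in dataloaders/dataloader.py.

    import numpy as np

    def uniform_sparse_depth(depth, num_samples=200):
        # illustrative stand-in for UniformSampling.dense_to_sparse:
        # keep roughly num_samples random valid pixels, zero out the rest
        valid = depth > 0.0
        prob = float(num_samples) / max(int(valid.sum()), 1)
        mask_keep = (np.random.uniform(0.0, 1.0, depth.shape) < prob) & valid
        sparse = np.zeros_like(depth)
        sparse[mask_keep] = depth[mask_keep]
        return sparse

    # dummy data standing in for one (rgb, depth) pair from an .h5 file
    rgb = np.random.rand(228, 304, 3)
    depth = np.random.rand(228, 304) * 10.0
    sparse = uniform_sparse_depth(depth, num_samples=200)
    rgbd = np.append(rgb, np.expand_dims(sparse, axis=2), axis=2)
    print(rgbd.shape)  # (228, 304, 4) -- the 4-channel input of mode 2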

    The dataset is about 30 GB, which is fairly large. I downloaded it and re-shared it on Baidu Netdisk: https://pan.baidu.com/s/1SzQhDVZBJSy9gnMkr4UCMQ (extraction code: tpql).

    After extraction you get a set of .h5 files. An .h5 file is just a storage format whose internal structure is similar to a dictionary; the code already has a loading function for it (the h5_loader function in dataloaders/dataloader.py), and each file contains an rgb array and the corresponding depth array, as the snippet below shows.
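    A quick way to sanity-check one of the extracted files, mirroring what h5_loader does (the path is a hypothetical example; point it at any of your .h5 files):

    import h5py
    import numpy as np

    path = 'data/nyudepthv2/val/official/00001.h5'  # hypothetical example path

    with h5py.File(path, 'r') as h5f:
        print(list(h5f.keys()))         # expect ['depth', 'rgb']
        rgb = np.array(h5f['rgb'])      # stored as 3 x H x W (h5_loader transposes it)
        depth = np.array(h5f['depth'])  # H x W depth map
        print(rgb.shape, depth.shape, depth.min(), depth.max())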

    2. Running the val data (labels required, rgbd mode)

    The project performs evaluation through its evaluate mode. After downloading the data and the model, create a data folder named data, put the extracted dataset in it, and run python main.py --evaluate model_best.pth. This automatically creates a results folder and produces a stitched comparison image, comparison_7.png. The expected layout is sketched below.
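    For reference, create_data_loaders in main.py hard-codes the split paths under data/<dataset>/, and the folder scan in dataloader.py requires at least one subfolder per split, so the layout should look roughly like this (the subfolder names are examples, not requirements):

    data/
      nyudepthv2/
        train/
          <scene folders>/*.h5
        val/
          official/*.h5
    model_best.pth    # the downloaded checkpoint, next to main.py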

    3. Testing on your own test data (no labels required, rgbd mode)

    To get the same kind of output as val, and to run in batch, the val pipeline has to be modified in quite a few places. All the modified scripts are listed below.

    (1) Data loading

            nyu_dataloader.py

    import numpy as np
    import dataloaders.transforms as transforms
    from dataloaders.dataloader import MyDataloader

    iheight, iwidth = 480, 640  # raw image size

    class NYUDataset(MyDataloader):
        def __init__(self, root, type, sparsifier=None, modality='rgb'):
            super(NYUDataset, self).__init__(root, type, sparsifier, modality)
            self.output_size = (228, 304)

        def train_transform(self, rgb, depth):
            s = np.random.uniform(1.0, 1.5)  # random scaling
            depth_np = depth / s
            angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
            do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

            # perform 1st step of data augmentation
            transform = transforms.Compose([
                transforms.Resize(250.0 / iheight),  # this is for computational efficiency, since rotation can be slow
                transforms.Rotate(angle),
                transforms.Resize(s),
                transforms.CenterCrop(self.output_size),
                transforms.HorizontalFlip(do_flip)
            ])
            rgb_np = transform(rgb)
            rgb_np = self.color_jitter(rgb_np)  # random color jittering
            rgb_np = np.asfarray(rgb_np, dtype='float') / 255
            depth_np = transform(depth_np)
            return rgb_np, depth_np

        def val_transform(self, rgb, depth):
            depth_np = depth
            transform = transforms.Compose([
                transforms.Resize(240.0 / iheight),
                transforms.CenterCrop(self.output_size),
            ])
            rgb_np = transform(rgb)
            rgb_np = np.asfarray(rgb_np, dtype='float') / 255
            depth_np = transform(depth_np)
            return rgb_np, depth_np

        def test_transform(self, rgb, depth):
            depth_np = depth
            transform = transforms.Compose([
                transforms.Resize(240.0 / iheight),
                transforms.CenterCrop(self.output_size),
            ])
            rgb_np = transform(rgb)
            rgb_np = np.asfarray(rgb_np, dtype='float') / 255
            depth_np = transform(depth_np)
            return rgb_np, depth_np
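    Relative to the stock repo, the only change in this file is the added test_transform, which reuses the val pipeline (resize by 240/iheight, then center crop), because no augmentation should be applied at inference time.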

            dataloader.py

    import os
    import os.path
    import numpy as np
    import torch.utils.data as data
    import h5py
    import dataloaders.transforms as transforms

    IMG_EXTENSIONS = ['.h5', ]

    def is_image_file(filename):
        return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

    def find_classes(dir):
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def make_dataset(dir, class_to_idx):
        images = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        images.append(item)
        return images

    def h5_loader(path):
        h5f = h5py.File(path, "r")
        rgb = np.array(h5f['rgb'])
        rgb = np.transpose(rgb, (1, 2, 0))
        depth = np.array(h5f['depth'])
        return rgb, depth

    # def rgb2grayscale(rgb):
    #     return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114

    to_tensor = transforms.ToTensor()

    class MyDataloader(data.Dataset):
        modality_names = ['rgb', 'rgbd', 'd']  # , 'g', 'gd'
        color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)

        def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
            classes, class_to_idx = find_classes(root)
            imgs = make_dataset(root, class_to_idx)
            assert len(imgs) > 0, "Found 0 images in subfolders of: " + root + "\n"
            print("Found {} images in {} folder.".format(len(imgs), type))
            self.root = root
            self.imgs = imgs
            self.classes = classes
            self.class_to_idx = class_to_idx
            if type == 'train':
                self.transform = self.train_transform
            elif type == 'val':
                self.transform = self.val_transform
            elif type == 'test':
                self.transform = self.test_transform
            else:
                raise (RuntimeError("Invalid dataset type: " + type + "\n"
                                    "Supported dataset types are: train, val, test"))
            self.loader = loader
            self.sparsifier = sparsifier
            assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
                "Supported dataset types are: " + ', '.join(self.modality_names)
            self.modality = modality
            self.mark = type  # remember the split so __getitem__ can special-case 'test'

        def train_transform(self, rgb, depth):
            raise (RuntimeError("train_transform() is not implemented."))

        def val_transform(self, rgb, depth):
            raise (RuntimeError("val_transform() is not implemented."))

        def test_transform(self, rgb, depth):
            raise (RuntimeError("test_transform() is not implemented."))

        def create_sparse_depth(self, rgb, depth):
            if self.sparsifier is None:
                return depth
            else:
                mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
                sparse_depth = np.zeros(depth.shape)
                sparse_depth[mask_keep] = depth[mask_keep]
                return sparse_depth

        def create_rgbd(self, rgb, depth):
            sparse_depth = self.create_sparse_depth(rgb, depth)
            rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
            return rgbd

        def __getraw__(self, index):
            """
            Args:
                index (int): Index
            Returns:
                tuple: (rgb, depth, name) the raw data.
            """
            path, target = self.imgs[index]
            rgb, depth = self.loader(path)
            _, name = os.path.split(path)
            name = name.split('.')[0]
            return rgb, depth, name

        def __getitem__(self, index):
            rgb, depth, name = self.__getraw__(index)
            if self.transform is not None:
                rgb_np, depth_np = self.transform(rgb, depth)
            else:
                raise (RuntimeError("transform not defined"))

            # color normalization
            # rgb_tensor = normalize_rgb(rgb_tensor)
            # rgb_np = normalize_np(rgb_np)

            if self.modality == 'rgb':
                input_np = rgb_np
            elif self.modality == 'rgbd':
                input_np = self.create_rgbd(rgb_np, depth_np)
            elif self.modality == 'd':
                input_np = self.create_sparse_depth(rgb_np, depth_np)

            input_tensor = to_tensor(input_np)
            while input_tensor.dim() < 3:
                input_tensor = input_tensor.unsqueeze(0)
            if self.mark == 'test':
                # no ground truth at test time: return the file name instead of a depth tensor
                depth_tensor = name
            else:
                depth_tensor = to_tensor(depth_np)
                depth_tensor = depth_tensor.unsqueeze(0)
            return input_tensor, depth_tensor

        def __len__(self):
            return len(self.imgs)
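    The substantive changes here are: the 'test' branch in __init__ that selects test_transform, the self.mark = type bookkeeping, __getraw__ returning the file name alongside the raw arrays, and __getitem__ returning that name in place of a ground-truth depth tensor when the split is 'test', which lets the test loop name its output files without needing labels.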

    (2) Parameters

            util.py

    import os
    import cv2
    import torch
    import shutil
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image

    cmap = plt.cm.viridis

    def parse_command():
        model_names = ['resnet18', 'resnet50']
        loss_names = ['l1', 'l2']
        data_names = ['nyudepthv2', 'kitti']
        from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
        sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
        from models import Decoder
        decoder_names = Decoder.names
        from dataloaders.dataloader import MyDataloader
        modality_names = MyDataloader.modality_names

        import argparse
        parser = argparse.ArgumentParser(description='Sparse-to-Dense')
        parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', choices=model_names,
                            help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
        parser.add_argument('--data', metavar='DATA', default='nyudepthv2',
                            choices=data_names,
                            help='dataset: ' + ' | '.join(data_names) + ' (default: nyudepthv2)')
        parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb', choices=modality_names,
                            help='modality: ' + ' | '.join(modality_names) + ' (default: rgb)')
        parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
                            help='number of sparse depth samples (default: 0)')
        parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
                            help='cut-off depth of sparsifier, negative values means infinity (default: inf [m])')
        parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name, choices=sparsifier_names,
                            help='sparsifier: ' + ' | '.join(sparsifier_names) + ' (default: ' + UniformSampling.name + ')')
        parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2', choices=decoder_names,
                            help='decoder: ' + ' | '.join(decoder_names) + ' (default: deconv2)')
        parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                            help='number of data loading workers (default: 0)')
        parser.add_argument('--epochs', default=15, type=int, metavar='N',
                            help='number of total epochs to run (default: 15)')
        parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1', choices=loss_names,
                            help='loss function: ' + ' | '.join(loss_names) + ' (default: l1)')
        parser.add_argument('-b', '--batch-size', default=2, type=int, help='mini-batch size (default: 2)')
        parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                            metavar='LR', help='initial learning rate (default 0.01)')
        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                            metavar='W', help='weight decay (default: 1e-4)')
        parser.add_argument('--print-freq', '-p', default=10, type=int,
                            metavar='N', help='print frequency (default: 10)')
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')
        parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, default='',
                            help='evaluate model on validation set')
        parser.add_argument('-t', '--test', dest='test', type=str, default='',
                            help='test model on test set')
        parser.add_argument('--no-pretrain', dest='pretrained', action='store_false',
                            help='not to use ImageNet pre-trained weights')
        parser.set_defaults(pretrained=True)
        args = parser.parse_args()
        if args.modality == 'rgb' and args.num_samples != 0:
            print("number of samples is forced to be 0 when input modality is rgb")
            args.num_samples = 0
        if args.modality == 'rgb' and args.max_depth != 0.0:
            print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
            args.max_depth = 0.0
        return args

    def save_checkpoint(state, is_best, epoch, output_directory):
        checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
        torch.save(state, checkpoint_filename)
        if is_best:
            best_filename = os.path.join(output_directory, 'model_best.pth.tar')
            shutil.copyfile(checkpoint_filename, best_filename)
        if epoch > 0:
            prev_checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
            if os.path.exists(prev_checkpoint_filename):
                os.remove(prev_checkpoint_filename)

    def adjust_learning_rate(optimizer, epoch, lr_init):
        """Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
        lr = lr_init * (0.1 ** (epoch // 5))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def get_output_directory(args):
        output_directory = os.path.join('results',
            '{}.sparsifier={}.samples={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}.pretrained={}'.
            format(args.data, args.sparsifier, args.num_samples, args.modality,
                   args.arch, args.decoder, args.criterion, args.lr, args.batch_size,
                   args.pretrained))
        return output_directory

    def colored_depthmap(depth, d_min=None, d_max=None):
        if d_min is None:
            d_min = np.min(depth)
        if d_max is None:
            d_max = np.max(depth)
        depth_relative = (depth - d_min) / (d_max - d_min)
        return 255 * cmap(depth_relative)[:, :, :3]  # H, W, C

    def merge_into_row(input, depth_target, depth_pred):
        rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
        depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
        depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
        d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
        d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
        depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
        depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
        img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
        return img_merge

    def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
        rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
        depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
        depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
        depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
        d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
        d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
        depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
        depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
        depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
        img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
        return img_merge

    def add_row(img_merge, row):
        return np.vstack([img_merge, row])

    def save_image(img_merge, filename):
        img_merge = Image.fromarray(img_merge.astype('uint8'))
        img_merge.save(filename)

    def strentch_img(pred):
        # upscale the predicted depth to 1280x720 and colorize it for visualization
        depth_pred_cpu = np.squeeze(pred.data.cpu().numpy())
        d_min = np.min(depth_pred_cpu)
        d_max = np.max(depth_pred_cpu)
        depth_pred_cpu = cv2.resize(depth_pred_cpu, (1280, 720))
        depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
        return depth_pred_col
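    The relevant additions in this file are the -t/--test argument, which takes a checkpoint path and mirrors --evaluate, and the strentch_img helper, which resizes the predicted depth map to 1280x720 and colorizes it with the viridis colormap for visualization.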

    (3) Main function

            main.py

    import os
    import time
    import csv
    import numpy as np

    import torch
    import torch.backends.cudnn as cudnn
    import torch.optim
    cudnn.benchmark = True

    from models import ResNet
    from metrics import AverageMeter, Result
    from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
    import criteria
    import utils
    from PIL import Image

    torch.nn.Module.dump_patches = True

    args = utils.parse_command()
    print(args)

    fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
                  'delta1', 'delta2', 'delta3',
                  'data_time', 'gpu_time']
    best_result = Result()
    best_result.set_to_worst()

    def create_data_loaders(args):
        # Data loading code
        print("=> creating data loaders ...")
        traindir = os.path.join('data', args.data, 'train')
        valdir = os.path.join('data', args.data, 'val')
        testdir = os.path.join('data', args.data, 'test')
        train_loader = None
        val_loader = None
        test_loader = None

        # sparsifier is a class for generating random sparse depth input from the ground truth
        sparsifier = None
        max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
        if args.sparsifier == UniformSampling.name:
            sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
        elif args.sparsifier == SimulatedStereo.name:
            sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)

        if args.data == 'nyudepthv2':
            from dataloaders.nyu_dataloader import NYUDataset
            if args.evaluate:
                val_dataset = NYUDataset(valdir, type='val',
                                         modality=args.modality, sparsifier=sparsifier)
                # set batch size to be 1 for validation
                val_loader = torch.utils.data.DataLoader(val_dataset,
                    batch_size=1, shuffle=False, num_workers=args.workers,
                    pin_memory=True)
            elif args.test:
                test_dataset = NYUDataset(testdir, type='test',
                                          modality=args.modality, sparsifier=sparsifier)
                test_loader = torch.utils.data.DataLoader(test_dataset,
                    batch_size=1, shuffle=False, num_workers=args.workers,
                    pin_memory=True)
            else:
                train_dataset = NYUDataset(traindir, type='train',
                                           modality=args.modality, sparsifier=sparsifier)
        elif args.data == 'kitti':
            from dataloaders.kitti_dataloader import KITTIDataset
            if not args.evaluate:
                train_dataset = KITTIDataset(traindir, type='train',
                                             modality=args.modality, sparsifier=sparsifier)
            val_dataset = KITTIDataset(valdir, type='val',
                                       modality=args.modality, sparsifier=sparsifier)
            # set batch size to be 1 for validation
            val_loader = torch.utils.data.DataLoader(val_dataset,
                batch_size=1, shuffle=False, num_workers=args.workers,
                pin_memory=True)
        else:
            raise RuntimeError('Dataset not found. ' +
                               'The dataset must be either of nyudepthv2 or kitti.')

        # put construction of train loader here, for those who are interested in testing only
        if not args.evaluate and not args.test:
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True, sampler=None,
                worker_init_fn=lambda work_id: np.random.seed(work_id))
            # worker_init_fn ensures different sampling patterns for each data loading thread

        print("=> data loaders created.")
        return train_loader, val_loader, test_loader

    test_save_path = './results/'

    def main():
        global args, best_result, output_directory, train_csv, test_csv

        # evaluation mode
        start_epoch = 0
        if args.evaluate:
            assert os.path.isfile(args.evaluate), \
                "=> no best model found at '{}'".format(args.evaluate)
            print("=> loading best model '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate)
            output_directory = os.path.dirname(args.evaluate)
            args = checkpoint['args']
            start_epoch = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
            args.test = ''
            args.evaluate = True
            _, val_loader, _ = create_data_loaders(args)
            validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
            return

        # test mode: run inference on unlabeled data
        elif args.test:
            assert os.path.isfile(args.test), \
                "=> no best model found at '{}'".format(args.test)
            print("=> loading best model '{}'".format(args.test))
            checkpoint = torch.load(args.test)
            output_directory = os.path.dirname(args.test)
            args = checkpoint['args']
            start_epoch = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
            args.test = True
            _, _, test_loader = create_data_loaders(args)
            test(test_loader, model, test_save_path)
            return

        # optionally resume from a checkpoint
        elif args.resume:
            chkpt_path = args.resume
            assert os.path.isfile(chkpt_path), \
                "=> no checkpoint found at '{}'".format(chkpt_path)
            print("=> loading checkpoint '{}'".format(chkpt_path))
            checkpoint = torch.load(chkpt_path)
            args = checkpoint['args']
            start_epoch = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            output_directory = os.path.dirname(os.path.abspath(chkpt_path))
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            train_loader, val_loader, test_loader = create_data_loaders(args)
            args.resume = True

        # create new model
        else:
            train_loader, val_loader, test_loader = create_data_loaders(args)
            print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
            in_channels = len(args.modality)
            if args.arch == 'resnet50':
                model = ResNet(layers=50, decoder=args.decoder, output_size=train_loader.dataset.output_size,
                               in_channels=in_channels, pretrained=args.pretrained)
            elif args.arch == 'resnet18':
                model = ResNet(layers=18, decoder=args.decoder, output_size=train_loader.dataset.output_size,
                               in_channels=in_channels, pretrained=args.pretrained)
            print("=> model created.")
            optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                        momentum=args.momentum, weight_decay=args.weight_decay)
            # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
            model = model.cuda()

        # define loss function (criterion) and optimizer
        if args.criterion == 'l2':
            criterion = criteria.MaskedMSELoss().cuda()
        elif args.criterion == 'l1':
            criterion = criteria.MaskedL1Loss().cuda()

        # create results folder, if not already exists
        output_directory = utils.get_output_directory(args)
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        train_csv = os.path.join(output_directory, 'train.csv')
        test_csv = os.path.join(output_directory, 'test.csv')
        best_txt = os.path.join(output_directory, 'best.txt')

        # create new csv files with only header
        if not args.resume:
            with open(train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(test_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()

        for epoch in range(start_epoch, args.epochs):
            utils.adjust_learning_rate(optimizer, epoch, args.lr)
            train(train_loader, model, criterion, optimizer, epoch)  # train for one epoch
            result, img_merge = validate(val_loader, model, epoch)  # evaluate on validation set

            # remember best rmse and save checkpoint
            is_best = result.rmse < best_result.rmse
            if is_best:
                best_result = result
                with open(best_txt, 'w') as txtfile:
                    txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n".
                                  format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time))
                if img_merge is not None:
                    img_filename = output_directory + '/comparison_best.png'
                    utils.save_image(img_merge, img_filename)

            utils.save_checkpoint({
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

    def train(train_loader, model, criterion, optimizer, epoch):
        average_meter = AverageMeter()
        model.train()  # switch to train mode
        end = time.time()
        for i, (input, target) in enumerate(train_loader):
            input, target = input.cuda(), target.cuda()
            torch.cuda.synchronize()
            data_time = time.time() - end

            # compute pred
            end = time.time()
            pred = model(input)
            loss = criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()  # compute gradient and do SGD step
            optimizer.step()
            torch.cuda.synchronize()
            gpu_time = time.time() - end

            # measure accuracy and record loss
            result = Result()
            result.evaluate(pred.data, target.data)
            average_meter.update(result, gpu_time, data_time, input.size(0))
            end = time.time()

            if (i + 1) % args.print_freq == 0:
                print('=> output: {}'.format(output_directory))
                print('Train Epoch: {0} [{1}/{2}]\t'
                      't_Data={data_time:.3f}({average.data_time:.3f}) '
                      't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                      'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                      'MAE={result.mae:.2f}({average.mae:.2f}) '
                      'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                      'REL={result.absrel:.3f}({average.absrel:.3f}) '
                      'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
                          epoch, i + 1, len(train_loader), data_time=data_time,
                          gpu_time=gpu_time, result=result, average=average_meter.average()))

        avg = average_meter.average()
        with open(train_csv, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
                             'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
                             'gpu_time': avg.gpu_time, 'data_time': avg.data_time})

    def validate(val_loader, model, epoch, write_to_file=True):
        average_meter = AverageMeter()
        model.eval()  # switch to evaluate mode
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input, target = input.cuda(), target.cuda()
            torch.cuda.synchronize()
            data_time = time.time() - end

            # compute output
            end = time.time()
            with torch.no_grad():
                pred = model(input)
            torch.cuda.synchronize()
            gpu_time = time.time() - end

            # measure accuracy and record loss
            result = Result()
            result.evaluate(pred.data, target.data)
            average_meter.update(result, gpu_time, data_time, input.size(0))
            end = time.time()

            # save 8 images for visualization
            skip = 50
            if args.modality == 'd':
                img_merge = None
            else:
                if args.modality == 'rgb':
                    rgb = input
                elif args.modality == 'rgbd':
                    rgb = input[:, :3, :, :]
                    depth = input[:, 3:, :, :]

                if i == 0:
                    if args.modality == 'rgbd':
                        img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred)
                    else:
                        img_merge = utils.merge_into_row(rgb, target, pred)
                elif (i < 8 * skip) and (i % skip == 0):
                    if args.modality == 'rgbd':
                        row = utils.merge_into_row_with_gt(rgb, depth, target, pred)
                    else:
                        row = utils.merge_into_row(rgb, target, pred)
                    img_merge = utils.add_row(img_merge, row)
                elif i == 8 * skip:
                    filename = output_directory + '/comparison_' + str(epoch) + '.png'
                    utils.save_image(img_merge, filename)

            if (i + 1) % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                      'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                      'MAE={result.mae:.2f}({average.mae:.2f}) '
                      'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                      'REL={result.absrel:.3f}({average.absrel:.3f}) '
                      'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
                          i + 1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))

        avg = average_meter.average()
        print('\n*\n'
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}\n'.format(
                  average=avg, time=avg.gpu_time))

        if write_to_file:
            with open(test_csv, 'a') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
                                 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
                                 'data_time': avg.data_time, 'gpu_time': avg.gpu_time})
        return avg, img_merge

    def test(test_loader, model, save_path):
        model.eval()  # switch to evaluate mode
        for i, (input, target) in enumerate(test_loader):
            # in test mode the dataloader returns the file name instead of a depth tensor
            input, name = input.cuda(), target
            torch.cuda.synchronize()

            # compute output
            with torch.no_grad():
                pred = model(input)
            torch.cuda.synchronize()

            # save a colorized visualization ...
            pred1 = utils.strentch_img(pred)
            save_to_file = os.path.join(save_path, name[0] + '.png')
            utils.save_image(pred1, save_to_file)

            # ... and the raw predicted depth as a single-channel TIFF
            save_to_tif = os.path.join(save_path, name[0] + '_ori.tiff')
            depth_pred_cpu = np.squeeze(pred.data.cpu().numpy())
            img = Image.fromarray(depth_pred_cpu)
            img = img.resize((1280, 720))
            img.save(save_to_tif)

    if __name__ == '__main__':
        main()
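    The additions in main.py mirror the evaluate path: create_data_loaders builds a test loader from data/<dataset>/test when args.test is set, main() loads the checkpoint and dispatches to test(), and test() writes two files per input: a colorized PNG visualization and the raw floating-point prediction as a TIFF, both upscaled to 1280x720.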

    (4) Testing

    After making the changes above, create a test folder and put your data in it, as sketched below.
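    Based on the hard-coded testdir in create_data_loaders and the subfolder scan in dataloader.py, the expected location is something like the following (the subfolder name is just an example). Note that each .h5 file still needs both rgb and depth datasets, since h5_loader reads both and rgbd mode builds its sparse input channel from the depth data:

    data/
      nyudepthv2/
        test/
          my_scenes/*.h5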

    Then run the following command:

    python main.py --test model_best.pth

    The white (grayscale) image is the raw result, and the color image is its visualization. The save location is set by test_save_path = './results/' in main.py.
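    Since the TIFF stores the raw floating-point prediction, it can be read back for downstream use; a minimal sketch (the file name is assumed from test()'s naming scheme and will differ per input):

    import numpy as np
    from PIL import Image

    pred = np.array(Image.open('results/00001_ori.tiff'))  # hypothetical file name
    print(pred.shape, pred.dtype)  # (720, 1280), float32
    print(pred.min(), pred.max())  # range of the predicted depth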

  • Original article: https://blog.csdn.net/qq_20373723/article/details/126815756