• 中秋佳节,基于华为云AI制作属于自己的月亮!


    一、前言

    农历八月十五处于仲秋,节近秋分,故称“中秋节”,周代也有“秋分祭月”的习俗,中秋节自古便与月亮结下了不解之缘,故又称月夕、拜月节等。时值中秋,举家齐团圆,共望皎洁月,可曾想到中秋与那天边月有哪些历史的秘密呢?

    中秋节之名源自于节令,吴子牧云:“八月十五日中秋节,此日三秋恰半,故谓之中秋。此夜月色倍明于常时,又谓之月夕。”中秋一词最早可见于《周礼》,于隋唐时正式定为节日。大约从宋代开始,中秋节越来越流行,终于到明清时期开始与春节齐名,成了我国传统节日中最重要的两个节日之一。

    二、结果展示

    三、模型简介

    Github有一个开源项目SkyAR修改而成,可以自动识别天空,然后将天空从图片中切割出来,再将天空替换成目标天空,从而实现魔法换天。

    参考文献中的论文提出了一种基于视觉的视频天空替换和协调方法,该方法可以在具有可控风格的视频中自动生成逼真的天空背景。与以前的天空编辑方法不同,该方法专注于静态照片或需要集成在智能手机中的惯性测量装置拍摄视频,该方法完全基于视觉,对捕获设备没有任何要求,并且可以很好地应用于在线或离线处理场景。

    算法流程大致可以分为三个步骤:

    1. 天空抠图:这一步主要是通过对蒙版数据集进行训练,将图片中的天空和其它物体进行像素级的划分,将天空部分从图片中分离。
    2. 运动估计:对图片中物体的位移情况进行分析,预估相机的移动方向,使替换后的天空和之前的天空位移一致。
    3. 图像混合:将去掉天空的原视频和要替换后的天空视频进行融合,同时对非天空的部分采用色彩叠加,是天空和其它物体的视觉效果相近,是视频效果更加逼真。

    四、实验环境

    本案例使用AI框架:PyTorch-1.4,在CPU和GPU下面均可运行,CPU环境运行预计花费9分钟,GPU环境运行预计花费2分钟

    CPU:8核
    内存:64GB
    GPU:nvidia-p100(32GB) * 1
    架构:x86_64
    规格:modelarts.vm.gpu

    五、实验步骤

    代码基于 ModelArts jupyterLab 运行

    1、导入依赖包

    import os
    import moxing as mox
    
    file_name = 'SkyAR'
    if not os.path.exists(file_name):
        mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
        os.system('unzip SkyAR.zip')
        os.system('rm SkyAR.zip')
        mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    !pip uninstall opencv-python -y
    !pip uninstall opencv-contrib-python -y
    
    • 1
    • 2
    !pip install opencv-contrib-python==4.5.3.56
    !pip install ipywidgets==7.7.1
    
    • 1
    • 2
    cd SkyAR/
    
    • 1
    import time
    import json
    import base64
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import argparse
    from networks import *
    from skyboxengine import *
    import utils
    import torch
    from IPython.display import clear_output, Image, display, HTML
    
    
    %matplotlib inline
    
    # 如果存在GPU则在GPU上面运行
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    2、参数设置

    parameter = {
      "net_G": "coord_resnet50",
      "ckptdir": "./checkpoints_G_coord_resnet50",
    
      "input_mode": "video",
      "datadir": "./test_videos/preview.mp4",  # 待处理的原视频路径
      "skybox": "supermoon.jpg",  # 要替换的天空图片路径
    
      "in_size_w": 384,
      "in_size_h": 384,
      "out_size_w": 845,
      "out_size_h": 480,
    
      "skybox_center_crop": 0.5,
      "auto_light_matching": False,
      "relighting_factor": 0.8,
      "recoloring_factor": 0.5,
      "halo_effect": True,
    
      "output_dir": "./jpg_output",
      "save_jpgs": False
    }
    
    str_json = json.dumps(parameter)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    其中,

    • skybox_center_crop:天空体中心偏移
    • auto_light_matching: 是否自动亮度匹配
    • relighting_factor: 补光
    • recoloring_factor: 重新着色
    • halo_effect: 是否开启光环效应
    • datadir: 待处理的原视频和要
    • skybox: 替换的天空图片

    3、调用视频和图片

    video_name = parameter['datadir']
    
    
    def arrayShow(img):
        img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
        _,ret = cv2.imencode('.jpg', img)
        return Image(data=ret)
    
    # 打开一个视频流
    cap = cv2.VideoCapture(video_name)
    
    frame_id = 0
    while True:
        try:
            clear_output(wait=True) # 清除之前的显示
            ret, frame = cap.read() # 读取一帧图片
            if ret:
                frame_id += 1
                if frame_id > 200:
                    break
                cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # 画frame_id
                tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式
                img = arrayShow(frame)
                display(img) # 显示图片
                time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片
            else:
                break
        except KeyboardInterrupt:
            cap.release()
    cap.release()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    img= cv2.imread(os.path.join('./skybox', parameter['skybox']))
    img2 = img[:, :, ::-1]
    plt.imshow(img2)
    
    • 1
    • 2
    • 3

    4、定义SkyFilter类

    class Struct:
        def __init__(self, **entries):
            self.__dict__.update(entries)  
    def parse_config():
        data = json.loads(str_json)
        args = Struct(**data)
    
        return args
    args = parse_config()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    class SkyFilter():
    
        def __init__(self, args):
    
            self.ckptdir = args.ckptdir
            self.datadir = args.datadir
            self.input_mode = args.input_mode
    
            self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
            self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h
    
            self.skyboxengine = SkyBox(args)
    
            self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
            self.load_model()
    
            self.video_writer = cv2.VideoWriter('out.avi',
                                                cv2.VideoWriter_fourcc(*'MJPG'),
                                                20.0,
                                                (args.out_size_w, args.out_size_h))
            self.video_writer_cat = cv2.VideoWriter('compare.avi',
                                                    cv2.VideoWriter_fourcc(*'MJPG'),
                                                    20.0,
                                                    (2*args.out_size_w, args.out_size_h))
    
            if os.path.exists(args.output_dir) is False:
                os.mkdir(args.output_dir)
    
            self.output_img_list = []
    
            self.save_jpgs = args.save_jpgs
    
    
        def load_model(self):
            # 加载预训练的天空抠图模型
            print('loading the best checkpoint...')
            checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                    map_location=device)
            self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
            self.net_G.to(device)
            self.net_G.eval()
    
    
        def write_video(self, img_HD, syneth):
    
            frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
            self.video_writer.write(frame)
    
            frame_cat = np.concatenate([img_HD, syneth], axis=1)
            frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
            self.video_writer_cat.write(frame_cat)
    
            # 定义结果缓冲区
            self.output_img_list.append(frame_cat)
    
    
        def synthesize(self, img_HD, img_HD_prev):
    
            h, w, c = img_HD.shape
    
            img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
    
            img = np.array(img, dtype=np.float32)
            img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)
    
            with torch.no_grad():
                G_pred = self.net_G(img.to(device))
                G_pred = torch.nn.functional.interpolate(G_pred,
                                                         (h, w),
                                                         mode='bicubic',
                                                         align_corners=False)
                G_pred = G_pred[0, :].permute([1, 2, 0])
                G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
                G_pred = np.array(G_pred.detach().cpu())
                G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)
    
            skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
    
            syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)
    
            return syneth, G_pred, skymask
    
    
        def cvtcolor_and_resize(self, img_HD):
    
            img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
            img_HD = np.array(img_HD / 255., dtype=np.float32)
            img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
    
            return img_HD
            
    
        def process_video(self):
            # 逐帧处理视频
            cap = cv2.VideoCapture(self.datadir)
            m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            img_HD_prev = None
    
            for idx in range(m_frames):
                ret, frame = cap.read()
                if ret:
                    img_HD = self.cvtcolor_and_resize(frame)
    
                    if img_HD_prev is None:
                        img_HD_prev = img_HD
    
                    syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)
    
                    self.write_video(img_HD, syneth)
    
                    img_HD_prev = img_HD
    
                    if (idx + 1) % 50 == 0:
                        print(f'processing video, frame {idx + 1} / {m_frames} ... ')
    
                else:  # 如果到达最后一帧
                    break
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117

    5、处理视频并与原视频对比

    sf = SkyFilter(args)
    sf.process_video()
    
    • 1
    • 2
    video_name = "compare.avi"
    
    
    def arrayShow(img):
        _,ret = cv2.imencode('.jpg', img)
        return Image(data=ret)
    
    # 打开一个视频流
    cap = cv2.VideoCapture(video_name)
    
    frame_id = 0
    while True:
        try:
            clear_output(wait=True) # 清除之前的显示
            ret, frame = cap.read() # 读取一帧图片
            if ret:
                frame_id += 1
                cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # 画frame_id
                tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式
                img = arrayShow(frame)
                display(img) # 显示图片
                time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片
            else:
                break
        except KeyboardInterrupt:
            cap.release()
    cap.release()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    六、生成自己的换天视频

    1. 在自己本地电脑上准备好一个待处理的mp4视频文件和一张天空特效图片,注意视频文件必须满足白天拍摄、有蓝天白云天空背景、镜头水平缓慢移动、横屏四个条件,否则天空换背景的效果不佳;
    2. 将视频文件和图片文件拖动到左侧文件浏览窗口,分别上传到ModelArts JupyterLab的SkyAR/test_videos目录和SkyAR/skybox目录下;
    3. 修改 “设定算法参数” 中datadir 和 skybox 两个参数的路径为你刚上传的视频和图片路径;
    4. 重新运行步骤2~5。

    七、参考文献

    ModelArts开发者:魔幻黑科技,可换天造物,秒变科幻大片!

    Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos.

  • 相关阅读:
    停车系统源码
    C++ 图片完整性校验
    发烧友实测 | OKA40i-C开发板SATA硬盘挂载及读写速率测试
    h5 plus 无法下载base64格式图片解决办法
    25、业务层标准开发(也就是service)
    DataTable导出Excel
    TypeScript学习 | 泛型
    CUDA优化之PReLU性能调优
    创邻科技Galaxybase—激活数据要素的核心引擎
    仅作笔记用:Windows 11 通过 VBS 打开 IE 浏览器
  • 原文地址:https://blog.csdn.net/zhu_rui/article/details/126793615