• 一文详解affine_grid 与 grid_sample以及与opencv坐标系的关系


    前言

    网上资料乱七八糟,本文通过坐标系和变换的角度,系统梳理两个操作的作用

    基本仿射变换

    二维仿射变换,我们可以综合为一个2x2的旋转矩阵R和一个2x1的平移矩阵t,[R,t]组合起来就是2x3的矩阵
    我们可以增广为3x3的矩阵,只需最后一行加上0 0 1即可。

    更细致描述请参阅games101

    在这里插入图片描述

    opencv 坐标系

    -------------------->x
    |
    |
    |
    |
    V y
    opencv的坐标系是从左上角开始(0,0) 其中x坐标向右侧延申,y坐标向下延申的。
    因此当我们按照想旋转和平移时,需要按照当前坐标系来理解,例如原先的正向现在为顺时针,向y轴平移现在就是向下平移。
    我们如下代码给出了先旋转45°然后向x轴方向平移50个像素的坐标

    import torch
    import cv2
    from skimage.transform._geometric import _umeyama as get_sym_mat
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    I = cv2.imread('test.png')
    print(I.size)
    rotate_theta = np.radians(45) # 转换为弧度
    rotate_matrix = np.array([
            [np.cos(rotate_theta), -np.sin(rotate_theta), 50],
            [np.sin(rotate_theta), np.cos(rotate_theta), 0],
            [0, 0, 1]
        ])
    print(rotate_matrix)
    M=rotate_matrix[:2, :]
    J = cv2.warpAffine(I, M,I.shape[:2])
    cv2.imshow('show', J)
    cv2.waitKey(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    在这里插入图片描述
    我们发现事实确实如我们所料

    pytorch坐标系

    pytorch的坐标是在-1到1的归一化坐标系下的值,例如左上角是(-1,-1),右上角是(1,1)
    图片的中心在图片正中间

    将已知的变换矩阵转换为pytorch的变换矩阵

    STN网络中最常用的两个函数是
    affine_gridgrid_sample,这两个函数组合使用可以完成类似的warp操作,而且支持批量操作。
    我们通常的使用方法是
    给定一个theta矩阵,通过theta矩阵变换一个网格grid,dist的每个点根据grid在src里插值填满整个图。

    grid = F.affine_grid(theta, [1, C, H, W])
    dist = F.grid_sample(src, grid)
    
    • 1
    • 2

    一个非常自然的想法就是,我们希望也能像opencv一样能利用我们学过的知识,手动控制变化。
    当我们获得了opencv坐标系下的一个仿射变换矩阵如下所示时,我们需要进行坐标转换
    [ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] \left[ a1a2txa3a4ty001

    a1a30a2a40txty1
    \right] a1a30a2a40txty1

    如果你学过线性代数或者矩阵论的相关知识,就知道我们可以用相似变换来解决这种坐标系下的变换问题。
    不了解的可以查看线性代数的本质-基变换
    对于任意一个原图pytorch坐标系下的点 ( u 1 , v 1 ) (u_1,v_1) (u1,v1)经过变换后,即可找到新图中 ( u 2 , v 2 ) (u_2,v_2) (u2,v2)的点
    [ u 2 v 2 1 ] = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 [ u 1 v 1 1 ] [u2v21]

    u2v21
    =[2W0102H1001]
    2W0002H0111
    [a1a2txa3a4ty001]
    a1a30a2a40txty1
    [2W0102H1001]
    2W0002H0111
    ^{-1}[u1v11]
    u2v21 = W2000H20111 a1a30a2a40txty1 W2000H20111 1 u1v11

    但是由于实际上theta控制着dst图去找src图,因此我们实际上的theta是上述变换的逆矩阵
    θ = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] − 1 [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 \theta = \left[ 2W0102H1001

    \right] \left[ a1a2txa3a4ty001
    \right]^{-1} \left[ 2W0102H1001
    \right]^{-1} θ= W2000H20111 a1a30a2a40txty1 1 W2000H20111 1

    实战

    给定一张宽高为W,H的图像,以及其中的某一块区域(x,y,w,h)意思是在原图x,y像素位置,有像素长度为w和h的框,请你将这个框截取出来并通过仿射变换为一个W和H大小的图片
    我们依次按照theta矩阵填入

    θ = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ W / w 0 − x W / w 0 H / h − y H / h 0 0 1 ] − 1 [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 \theta = \left[ 2W0102H1001

    \right] \left[ W/w0xW/w0H/hyH/h001
    \right]^{-1} \left[ 2W0102H1001
    \right]^{-1} θ= W2000H20111 W/w000H/h0xW/wyH/h1 1 W2000H20111 1
    中间变换矩阵的含义是先将小图 x,y 的源点位置移动到0,0,随后进行尺度缩放。
    θ = [ w / W 0 − 1 + 2 ∗ ( x + w / 2 ) / W 0 h / H − 1 + 2 ∗ ( y + h / 2 ) / H 0 0 1 ] \theta = \left[ w/W01+2(x+w/2)/W0h/H1+2(y+h/2)/H001
    \right]
    θ= w/W000h/H01+2(x+w/2)/W1+2(y+h/2)/H1

    import torch
    import cv2
    from skimage.transform._geometric import _umeyama as get_sym_mat
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    I = cv2.imread('test.png')
    H, W, C = I.shape
    
    It = torch.from_numpy(I).type(torch.float32).permute(2, 0, 1).unsqueeze(0)
    x, y, w, h = [80,70,100,100]
    scale_x = w / W
    scale_y = h / H
    translate_x = (-1 + 2 * (x +w/2)/ W) 
    translate_y =(-1 + 2 * (y +h/2)/ H) 
    # 构造 theta 矩阵
    theta = torch.tensor([[scale_x, 0, translate_x],
                        [0, scale_y, translate_y]], dtype=torch.float).unsqueeze(0)
    
    grid = F.affine_grid(theta, [1, C, H, W])
    Jt = F.grid_sample(It, grid)
    J = Jt.squeeze().permute(1, 2, 0).detach().numpy().astype('uint8')
    cv2.imshow('show', J)
    cv2.waitKey(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    在这里插入图片描述

    参考资料

    https://www.zhihu.com/question/294673086

  • 相关阅读:
    java毕业生设计重工教师职称管理系统计算机源码+系统+mysql+调试部署+lw
    尚硅谷_宋红康_IntelliJ IDEA 常用快捷键一览表
    基于Web的盾构机盾尾变形远程监测系统
    SpringSecurity
    Docker常见用法
    多线程的概念(多线程的代码实现)
    音容笑貌,两臻佳妙,人工智能AI换脸(deepfake)技术复刻《卡萨布兰卡》名场面(Python3.10)
    15.一种坍缩式的简单——组合模式详解
    Navicat:显示的行数与表中实际的行数不一致
    PythonAppium自动化测试环境搭建
  • 原文地址:https://blog.csdn.net/qq_26239785/article/details/138077501