• 【隐私计算】SIRNN: A Math Library for Secure RNN Inference


    刚开始学隐私计算,读到SIRNN,感觉真的好难好难,门槛比deep learning高好多,先尽量啃一啃(捂脸.jpg)。


    1 文章及代码

    Paper: SIRNN: A Math Library for Secure RNN Inference
    Code: https://github.com/mpc-msri/EzPC.

    2 主要贡献

    • 为数学函数(指数、sigmoid、tanh、平方根倒数)提出了全新的密码学友好的新近似。
    • 为不均匀(混合)的bitwidth提供2PC协议,实现高效的数学函数。
    • SIRNN首次为RNN和CNN提供了安全推理库,在延迟、通信等方面达到SOTA,并拥有高数值准确率。

    3 概览

    3.1 Scale和bitwidth

    2PC在整数上运算比在浮点数上运算更高效,在定点运算数中, ⌊ r 2 s ⌋ m o d    2 l \lfloor r2^s\rfloor \mod 2^l r2smod2l,其中 l l l就是bitwidth, s s s是scale。

    3.2 对数学函数的近似

    • 首先用lookup table (LUT)得到一个不错的初始化近似,然后用迭代算法提升这个近似。
    • 更大的LUT近似结果更准确,但是通信开销线性增长。
    • 对于指数和负数输入,分解输入x到更小的子串(digit decomposition)。
    • 为了让迭代算法更高效,本文采用定点(fixed-point)算数及不均匀的混合bitwidth。

    3.3 SIRNN协议

    安全参数: λ = 128 \lambda=128 λ=128
    基于4种构造块:
    (1)Extension(扩展)
    Z 2 m → Z 2 n ( m < n ) \mathbb{Z}_{2^m} \rightarrow \mathbb{Z}_{2^n} (mZ2mZ2n(m<n)
    GC需要的通信开销(重构和重共享)为: λ ( 4 m + 2 n ) \lambda(4m+2n) λ(4m+2n) bits,SIRNN需要的通信开销仅为: λ m \lambda m λm,大约比GC快6x。
    (2)Truncation(截断)
    常用于乘法之后减小规模,对于 l l l-bit截断了 s s s-bit有四种截断操作:

    • 逻辑右移(保留位宽)
    • 算数右移(保留位宽)
    • 截断且减小(输出截断值 Z 2 l − s \mathbb Z_{2^{l-s}} Z2ls
    • 除以 2 s 2^s 2s

    目前最好的算数右移通信大约是: λ ( l + s ) \lambda(l+s) λ(l+s),本文提出的逻辑/算数右移协议大约是 λ l \lambda l λl,大多数数学函数都只需要截断且减小去减小scale和bitwidth,SIRNN只需要 λ ( s + 1 ) \lambda (s+1) λ(s+1)通信。
    (3)Multiplication(乘法)
    m m m-bit整数和 n n n-bit整数相乘得到 l = ( m + n ) l=(m+n) l=(m+n)-bit输出, l l l的选择保证了没有溢出。
    (4)Digit Decomposition(数位分解)
    l l l-bit的值分解为 c = l / d c=l/d c=l/d d d d-bits,可以用GC实现,通信量为 λ ( 6 l − 2 c − 2 ) \lambda (6l-2c-2) λ(6l2c2) bits。本文进一步优化,通信量为 λ ( c − 1 ) ( d + 2 ) \lambda (c-1)(d+2) λ(c1)(d+2) bits,大约比GC低5x。

    4 前提知识

    4.1 ULP误差(units in last place)

    ULP是真实数据和函数输出值之间的可表示数值的数量。

    4.2 威胁模型

    • 两方安全计算(2PC)
    • 静态的半诚实攻击:遵循协议,但是会学习额外信息

    4.3 符号表示

    符号意义
    x ∈ Z 2 l x\in \mathbb Z_{2^l} xZ2lpower-of-2 rings, x x x的环为 Z 2 l \mathbb Z_{2^l} Z2l,即以 2 l 2^l 2l为模
    B B Bring Z 2 \mathbb Z_2 Z2,即以2为模
    λ \lambda λ计算安全系数
    ⊕ \oplus 异或门
    ζ l , ζ l , m ( m > l ) \zeta_l, \zeta_{l,m} (m>l) ζl,ζl,m(m>l)无损lifting操作,映射 Z L → Z \mathbb Z_L\rightarrow \mathbb Z ZLZ,映射 Z L → Z M \mathbb Z_L\rightarrow \mathbb Z_M ZLZM
    L , M , N L,M,N L,M,N 2 l , 2 m , 2 n 2^l, 2^m, 2^n 2l,2m,2n
    [ k ] [k] [k] 0 , 1 , . . , k − 1 {0, 1, .., k-1} 0,1,..,k1
    1 { b } 1\{b\} 1{b} b = t r u e b=true b=true时为1,反之为0
    i n t ( x ) int(x) int(x) u i n t ( x ) uint(x) uint(x)对于 x ∈ Z l x\in \mathbb Z^l xZl,分别代表有符号和无符号值,int(x)=uint(x)−MSB(x)L
    MSB(x)MSB(x) = 1 { x ≥ 2 l − 1 } =1\{x\geq 2^{l-1}\} =1{x2l1},表示最有效高位
    F M i l l l ( x , y ) F_{Mill}^l(x, y) FMilll(x,y) F M i l l l ( x , y ) = ⟨ z ⟩ B = 1 { x < y } F_{Mill}^l(x, y)=\langle z\rangle^B=1\{xFMilll(x,y)=zB=1{x<y}在这里插入图片描述
    F w r a p l F_{wrap}^l Fwrapl F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L1x,y):w=wrap(x,y,L)=1{x+yL}
    e e e e = 1 { ( x + y m o d    L ) = L − 1 } e=1\{(x+y \mod L)=L-1\} e=1{(x+ymodL)=L1},判断是否全是1
    F w r a p & a l l 1 s l F_{wrap\&all1s}^l Fwrap&all1sl F w r a p & a l l 1 s l ( x , y ) = ( ⟨ w ⟩ B ∣ ∣ ⟨ e ⟩ B ) F_{wrap\&all1s}^l(x,y)=(\langle w\rangle^B||\langle e\rangle^B) Fwrap&all1sl(x,y)=(wBeB),至多一项是1
    ∗ m *_m m x ∗ m y = x y m o d    M x*_m y=xy\mod M xmy=xymodM,从 Z × Z → Z M \mathbb Z \times \mathbb Z \rightarrow \mathbb Z_M Z×ZZM
    l l lbitwidth
    s s sscale
    l − s l-s ls整数部分的bitwidth
    F i x ( x , l , s ) Fix(x, l, s) Fix(x,l,s) F i x ( x , l , s ) = x 2 s m o d    L Fix(x, l, s)=x2^s \mod L Fix(x,l,s)=x2smodL,从实数转到定点数表示
    u r t ( l , s ) ( a ) urt_{(l,s)}(a) urt(l,s)(a)对于无符号数, u r t ( l , s ) ( a ) = u i n t ( a ) / 2 s urt_{(l,s)}(a)=uint(a)/2^s urt(l,s)(a)=uint(a)/2s,从定点数转到实数表示
    s r t ( l , s ) ( a ) srt_{(l,s)}(a) srt(l,s)(a)对于有符号数, s r t ( l , s ) ( a ) = i n t ( a ) / 2 s srt_{(l,s)}(a)=int(a)/2^s srt(l,s)(a)=int(a)/2s,从定点数转到实数表示
    > > L , > > A >>_L, >>_A >>L,>>A逻辑右移和算术右移

    4.4 密码学基础

    • 秘密共享(SS)
      2-out-of-2加性秘密共享: x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d    L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=x0l+x1lmodL
    • 不经意传输(OT)
      1-out-of-k OT,用OT Extension (OTE)实现,并用了Correlated OT (COT)。

    4.5 2PC基本函数

    • 百万富翁/wrap
      F M i l l l = 1 { x < y } F_{Mill}^l=1\{xFMilll=1{x<y},CrypTFlow2中通信量低于 λ l + 14 l \lambda l+14l λl+14l bits和 log ⁡ l \log l logl rounds。
      F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L1x,y):w=wrap(x,y,L)=1{x+yL}
    • AND
      输入 ⟨ x ⟩ B , ⟨ y ⟩ B \langle x\rangle^B, \langle y\rangle^B xB,yB,输出 ⟨ x ∧ y ⟩ B \langle x \land y\rangle^B xyB,用Beaver bit-triples实现,CrypTFlow2中通信量为 λ + 20 \lambda+20 λ+20
    • Boolean to Arithmetic (B2A)
      输入boolean share,输出相同值的算术share,采用COT协议实现,通信量为 λ + l \lambda+l λ+l bits。
    • Multiplexer (MUX)
      ⟨ x ⟩ B \langle x\rangle^B xB ⟨ y ⟩ l \langle y\rangle^l yl作为输入,输出 ⟨ z ⟩ l \langle z\rangle^l zl,如果 x = 1 x=1 x=1,则 z = y z=y z=y,反之同理。本文提出的协议将通信量从 2 ( λ + 2 l ) 2(\lambda+2l) 2(λ+2l)(CrypTFlow2)降到 2 ( λ + l ) 2(\lambda+l) 2(λ+l)
    • Lookup Table (LUT)
      对于表 T T T M M M个入口,每个 n n n-bits,输入 ⟨ x ⟩ m \langle x\rangle^m xm ⟨ z ⟩ n \langle z\rangle^n zn,满足 z = T [ x ] z=T[x] z=T[x]。可以用1-out-of-m OT实现,通信量为 2 λ + M n 2\lambda +Mn 2λ+Mn bits。这是个查表的操作,输入和输出的位数是不同的。

    5 构建块协议

    5.1 零扩展和有符号扩展

    对于 m m m-bit的数 x ∈ Z M x\in \mathbb Z_M xZM,将其转换为 n n n-bit的数( n > m n>m n>m),这个过程就称为扩展(extension)。零扩展和有符号扩展分别用于扩展无符号数和有符号数的位宽。
    零扩展(Zero Extension)
    P 0 P_0 P0 P 1 P_1 P1两方输入 ⟨ x ⟩ m \langle x\rangle^m xm,扩展输出 ⟨ y ⟩ n \langle y\rangle^n yn,要求满足 u n i t ( x ) = u i n t ( y ) unit(x)=uint(y) unit(x)=uint(y)。对于 x m ∈ Z M x^m\in \mathbb Z_M xmZM,可以得到 【问:这个等式在后面广泛使用,没太理解怎么来的】【答:其实 − w M -wM wM就是实现的 m o d    M \mod M modM计算过程,防止求和在 Z 2 m \mathbb Z_{2^m} Z2m环上溢出】
    x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=x0m+x1mwM
    其中, w = w r a p ( ⟨ x ⟩ 0 m , ⟨ x ⟩ 1 m , M ) w=wrap(\langle x \rangle_0^m, \langle x \rangle_1^m, M) w=wrap(x0m,x1m,M),这是个boolean share,需要转换为算术share。这里考虑在 n − m n-m nm环上转换,原因就是下面的模约减步骤会使通信量大大降低。
    F B 2 A n − m ( ⟨ w ⟩ B ) = ⟨ w ⟩ n − m ∈ Z 2 n − m F_{B2A}^{n-m}(\langle w\rangle^B)=\langle w\rangle^{n-m}\in \mathbb Z_{2^{n-m}} FB2Anm(wB)=wnmZ2nm

    w = ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m w = \langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}-wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m} w=w0nm+w1nmwrap(w0nm,w1nm,Z2nm)2nm

    M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m} - wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m}) Mnw=Mn(w0nm+w1nmwrap(w0nm,w1nm,Z2nm)2nm)

    其中, M ∗ n w r a p ( ⋅ ) 2 n − m = M w r a p ( ⋅ ) 2 n − m m o d    N = w r a p ( ⋅ ) 2 n m o d    N = 0 M_{*n}wrap(\cdot)2^{n-m}=Mwrap(\cdot)2^{n-m} \mod N=wrap(\cdot)2^{n} \mod N=0 Mnwrap()2nm=Mwrap()2nmmodN=wrap()2nmodN=0(这一步称作“模约减”,modulo-reduce),所以上式子转换为:
    M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}) Mnw=Mn(w0nm+w1nm
    于是:
    y = ∑ b = 0 1 ( ⟨ x ⟩ b m − M ⟨ w ⟩ b n − m ) m o d    N y = \sum_{b=0}^1(\langle x\rangle_b^m-M\langle w\rangle_b^{n-m}) \mod N y=b=01(xbmMwbnm)modN
    这里是在 P 0 P_0 P0 P 1 P_1 P1上分别计算,然后求和取模,得到扩展后的结果。其中, x m o d    N = y x \mod N=y xmodN=y
    算法如下:
    在这里插入图片描述
    需要 log ⁡ ( m + 2 ) \log(m+2) log(m+2) rounds和少于 λ ( m + 1 ) + 13 m + n \lambda(m+1)+13m+n λ(m+1)+13m+n bits的通信量。作为对比,用GC实现零扩展和有符号扩展需要 λ ( 4 m + 2 n − 4 ) \lambda(4m+2n-4) λ(4m+2n4) bits的通信量,大约是SIRNN的6倍。

    有符号扩展(Signed Extension)
    有符号扩展可以基于以下等式,通过转换无符号扩展得到,在环 Z \mathbb Z Z上:
    i n t ( x ) = x ′ − 2 m − 1 , x ′ = x + 2 m − 1 m o d    M int(x)=x'-2^{m-1}, x'=x+2^{m-1} \mod M int(x)=x2m1,x=x+2m1modM
    证明如下:
    在这里插入图片描述
    于是:
    S E x t ( x , m , n ) = Z E x t ( x , m , n ) − 2 m − 1 SExt(x, m, n)=ZExt(x, m, n)-2^{m-1} SExt(x,m,n)=ZExt(x,m,n)2m1

    相比零扩展,没有额外的通信开销。

    5.2 截断

    首先,规定 > > L , > > A >>_L,>>_A >>L,>>A分别表示逻辑右移和算术右移,它们的输入和输出都是在 Z L \mathbb Z_L ZL环上。
    T R ( x , s ) TR(x, s) TR(x,s)表示截断且减小(truncate & reduce),将 x ∈ Z L x\in \mathbb Z_L xZL截断且减小 s s s-bits,最终得到的 x x x在更小的 Z 2 l − s \mathbb Z_{2^{l-s}} Z2ls环上。
    逻辑右移
    Toy example: x = 101001 x=101001 x=101001逻辑右移3位,则 x ′ = 000101 x'=000101 x=000101(右侧截掉,左侧补0)。
    对于 x ∈ Z L x\in \mathbb Z_L xZL,则 x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d    L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=x0l+x1lmodL,记 ⟨ x ⟩ b l = u b ∣ ∣ v b \langle x\rangle_b^l=u_b||v_b xbl=ubvb u b u_b ub是高位, v b v_b vb是低位),其中 u b ∈ { 0 , 1 } l − s , v b ∈ { 0 , 1 } s u_b\in\{0, 1\}^{l-s}, v_b\in\{0, 1\}^{s} ub{0,1}ls,vb{0,1}s。如下图:
    在这里插入图片描述
    根据前面提到的公式:
    x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=x0m+x1mwM
    可以得到:
    x > > L s = u 0 + u 1 − 2 l − s w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) + w r a p ( v 0 , v 1 , 2 s ) x>>_Ls=u_0+u_1-2^{l-s} wrap (\langle x\rangle_0^l, \langle x\rangle_1^l, L) + wrap(v_0, v_1, 2^s) x>>Ls=u0+u12lswrap(x0l,x1l,L)+wrap(v0,v1,2s)
    上式中, w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)这一项是考虑了进位。我们知道,加性秘密共享时, v v v部分可能会存在1位进位的情况,所以 w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)就是判断 v 0 + v 1 v_0+v_1 v0+v1是否大于 2 s 2^s 2s,如果是,则会进1,如果不是,则为0。
    常规做法是计算两个 w r a p ( ⋅ ) wrap(\cdot) wrap()值即可,但是SIRNN提出了一种优化,避开直接计算位宽是 l l l的那一项。文章中的Lemma 1即是这个引理:
    在这里插入图片描述
    通信开销低于 λ ( l + 3 ) + 15 + s + 20 \lambda(l+3)+15+s+20 λ(l+3)+15+s+20,并需要 log ⁡ l + 3 \log l+3 logl+3 rounds。
    原文证明如下:
    在这里插入图片描述

    算法如下:
    在这里插入图片描述

    算术右移
    对于无符号数,直接采用逻辑右移,对于有符号数,则需要采用算术右移。从前面零扩展到有符号扩展可以知道: i n t ( x ) = x ′ − 2 l − 1 , x ′ = x + 2 l − 1 m o d    L int(x)=x'-2^{l-1}, x'=x+2^{l-1} \mod L int(x)=x2l1,x=x+2l1modL,于是:
    x > > A s = x > > L s − 2 l − s − 1 x>>_As = x>>_Ls-2^{l-s-1} x>>As=x>>Ls2ls1

    截断且减小
    Toy example: x = 101001 x=101001 x=101001截断且减小3位,则 x ′ = 101 x'=101 x=101
    因为 2 l − s ∗ l w m o d    2 l − s = 0 2^{l-s}{*_l} w \mod 2^{l-s}=0 2lslwmod2ls=0(模约减),所以:
    ⟨ T R ( x , s ) ⟩ l − s = u 0 + u 1 + w r a p ( v 0 , v 1 , 2 s ) \langle TR(x, s)\rangle^{l-s}=u_0+u_1+wrap(v_0, v_1, 2^s) TR(x,s)ls=u0+u1+wrap(v0,v1,2s)

    除以power-of-2
    z < 0 , z = ⌈ i n t ( x ) / 2 s ⌉ m o d    L ; z ≥ 0 , z = ⌊ i n t ( x ) / 2 s ⌋ m o d    L z<0, z=\lceil int(x)/2^s\rceil \mod L; z\geq0, z=\lfloor int(x)/2^s\rfloor \mod L z<0,z=int(x)/2smodL;z0,z=int(x)/2smodL
    实际上 i n t ( x ) / 2 s m o d    L int(x)/2^s \mod L int(x)/2smodL就是做 > > A >>_A >>A,取整括号即是将值往0靠近。令 m x = 1 { x ≥ 2 l − 1 } m_x=1\{x\geq 2^{l-1}\} mx=1{x2l1}判断 x x x的正负性, c = 1 { x m o d    2 s = 0 } c=1\{x\mod 2^s=0\} c=1{xmod2s=0}
    m x = 1 m_x=1 mx=1,则 z < 0 , ⌈ z ⌉ z<0, \lceil z\rceil z<0,z;反之, ⌊ z ⌋ \lfloor z\rfloor z。所以有:
    D i v P o w 2 ( x , s ) = ( x > > A s ) + m x ∧ c DivPow2(x, s)=(x>>_As)+m_x\land c DivPow2(x,s)=(x>>As)+mxc

    5.3 混合位宽乘法

    以前做乘法通常是用Beaver Triplet三元组实现,SIRNN中不能用了,因为加法和乘法的数bitwidth不一致。
    无符号乘法
    输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n xm,yn,输出 ⟨ z ⟩ l , z = x ∗ l y , l = n + m \langle z\rangle^l, z=x*_l y, l=n+m zl,z=xly,l=n+m
    对于 x , y x,y x,y,在 Z \mathbb Z Z上有:
    u i n t ( x ) ⋅ u i n t ( y ) = ( x 0 + x 1 − 2 m w x ) ⋅ ( y 0 + y 1 − 2 n w y ) = x 0 y 0 + x 0 y 1 + x 1 y 0 + x 1 y 1 − 2 m w x y − 2 n w y x + 2 l w x w y uint(x)\cdot uint(y)=(x_0+x_1-2^mw_x)\cdot(y_0+y_1-2^nw_y)\\=x_0y_0+x_0y_1+x_1y_0+x_1y_1-2^mw_xy-2^nw_yx+2^lw_xw_y uint(x)uint(y)=(x0+x12mwx)(y0+y12nwy)=x0y0+x0y1+x1y0+x1y12mwxy2nwyx+2lwxwy
    观察上式, x 0 y 0 , x 1 y 1 x_0y_0,x_1y_1 x0y0,x1y1都是可以本地计算的【本地计算为什么不管位宽是否一致?】 2 l w x w y 2^lw_xw_y 2lwxwy可以在 m o d    L \mod L modL时被消掉(模约减), w x y , x y x w_xy, x_yx wxy,xyx是boolean share和算术share的计算,本质上是MUX,可用直接用OT实现。最难的一项是交叉项 x 0 y 1 , x 1 y 0 x_0y_1, x_1y_0 x0y1,x1y0,SIRNN采用COT实现。
    巧妙的一点在于:选择比特位短的一方作为receiver,比特位长的一方作为sender,这样在做OT的取数时,round数就会更少。
    交叉项算法如下:

    无符号乘法算法如下:
    在这里插入图片描述
    SIRNN利用1-out-of-2的COT来实现这个过程,将短的数按位拆解,每一位非0即1,然后做二选一的COT,每一位计算完成后,在本地累加起来。
    通信开销大约是: λ ( 3 μ + v ) + μ ( μ + 2 v ) + 16 ( m + n ) \lambda(3\mu + v) + \mu(\mu + 2v) + 16(m + n) λ(3μ+v)+μ(μ+2v)+16(m+n),其中 μ = min ⁡ ( m , n ) , ν = max ⁡ ( m , n ) \mu = \min(m, n), ν = \max(m, n) μ=min(m,n),ν=max(m,n)。普通的扩展位数然后相乘的开销是: 3 λ ( μ + v ) + ( m + n ) 2 + 15 ( m + n ) 3\lambda(\mu+v)+(m+n)^2+15(m + n) 3λ(μ+v)+(m+n)2+15(m+n),大约是SIRNN的1.5x。

    有符号乘法
    布尔分享转换为算术分享:
    ⟨ x ⟩ A = ⟨ x ⟩ 0 B + ⟨ x ⟩ 1 B − 2 ⟨ x ⟩ 0 B ⟨ x ⟩ 1 B \langle x\rangle^A=\langle x\rangle_0^B+\langle x\rangle_1^B-2\langle x\rangle_0^B\langle x\rangle_1^B xA=x0B+x1B2x0Bx1B
    基于前面无符号数和有符号数的关系,可以得到:无符号数 x ′ = x + 2 m − 1 m o d    M , y ′ = y + 2 n − 1 m o d    N x'=x+2^{m-1}\mod M, y'=y+2^{n-1}\mod N x=x+2m1modM,y=y+2n1modN。由秘密共享, x ′ = x 0 ′ + x 1 ′ m o d    M , y ′ = y 0 ′ + y 1 ′ m o d    N x'=x_0'+x_1' \mod M, y'=y_0'+y_1' \mod N x=x0+x1modM,y=y0+y1modN。有符号数 i n t ( x ) = x ′ − 2 m − 1 , i n t ( y ) = y ′ − 2 n − 1 int(x)=x'-2^{m-1}, int(y)=y'-2^{n-1} int(x)=x2m1,int(y)=y2n1。因此,在 Z \mathbb Z Z环上:
    在这里插入图片描述
    在这里插入图片描述

    x ′ y ′ x'y' xy是无符号数的乘法,可以用algorithm 3计算, 2 m − 1 y b ′ , 2 n − 1 x b ′ 2^{m-1}y_b', 2^{n-1}x_b' 2m1yb,2n1xb也都可以在本地计算出来。难点是wrap项应该如何计算。
    2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B − 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B-2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B) 2m+n1wx=2l1wx=2l1(wx0B+wx1B2wx0Bwx1B)
    其中, 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B 2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B 2wx0Bwx1B 2 l − 1 2^{l-1} 2l1相乘再 m o d    L \mod L modL后会被消除掉,所以无需计算。因此,上式变为:
    2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B) 2m+n1wx=2l1wx=2l1(wx0B+wx1B)
    有符号的乘法相比无符号的乘法,也没有额外的开销。

    矩阵乘法和卷积
    矩阵乘法和卷积是很常见的(实际上可以展开为普通乘法做elment-wise乘和加),两个矩阵 A ∈ Z M d 1 × d 2 , A ∈ Z N d 2 × d 3 A\in \mathbb Z_M^{d1\times d2}, A\in \mathbb Z_N^{d2\times d3} AZMd1×d2,AZNd2×d3,输出矩阵乘法结果 A ∈ Z L d 1 × d 3 A\in \mathbb Z_L^{d1\times d3} AZLd1×d3,其中 l = m + n l=m+n l=m+n。做矩阵乘法需要 d 2 d_2 d2次乘以及 d 2 − 1 d_2-1 d21次加。
    这个时候可能出现的问题是:加法导致溢出。一种解决方式是将element-wise乘后的结果扩展 e = ⌈ log ⁡ d 2 ⌉ e=\lceil \log d_2\rceil e=logd2-bits后,再做加法。但是,这样扩展开销很大,需要扩展 d 1 d 2 d 3 d_1d_2d_3 d1d2d3次。
    于是本文这样做:考虑到前面算交叉项(CrossTerm)时,通信round数取决于较小的bitwidth,所以本文将bitwidth较大的一项拿去扩展 e e e-bits,在不增加开销的情况下,扩大了环。
    通信开销大致为 λ ( 3 d 1 d 2 ( m + 2 ) + d 2 d 3 ( n + 2 ) ) + d 1 d 2 d 3 ( ( 2 m + 4 ) ( n + e ) + m 2 + 5 m ) \lambda(3d_1d_2(m+2)+d_2d_3(n+2))+d_1d_2d_3((2m+4)(n+e)+m^2+5m) λ(3d1d2(m+2)+d2d3(n+2))+d1d2d3((2m+4)(n+e)+m2+5m) bits。
    算法如下:
    在这里插入图片描述

    乘且截断
    首先调用有符号乘法,然后截断。输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n xm,yn,输出 ⟨ z ′ ⟩ l − s \langle z'\rangle^{l-s} zls z = i n t ( x ) ∗ l i n t ( y ) , z ′ = T R ( z , s ) z=int(x)*_l int(y), z'=TR(z, s) z=int(x)lint(y),z=TR(z,s)。其中 l = m + n l=m+n l=m+n

    5.4 数值分解和MSNZB (Most Significant Non-Zero Bit)

    数值分解
    l l l-bit的数分解为 c c c个长度为 d = l / c d=l/c d=l/c的子串或数值,使得 x = z c − 1 ∣ ∣ . . . ∣ ∣ z 0 x=z_{c-1}||...||z_0 x=zc1...z0
    算法如下:
    在这里插入图片描述

    MSNZB
    返回最高非零比特的索引:比如 x = 001010 x=001010 x=001010返回的就是3。
    算法如下:
    在这里插入图片描述

    5.5 MSB to Wrap Optimization

    本文大量依赖于 w = w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) w=wrap(\langle x\rangle_0^l, \langle x\rangle_1^l, L) w=wrap(x0l,x1l,L),一些情况下,我们能得到 m x = M S B ( x ) m_x=MSB(x) mx=MSB(x) ⟨ m x ⟩ B \langle m_x\rangle^B mxB,于是 w = ( ( 1 ⊕ m x ) ∧ ( m 0 ⊕ m 1 ) ⊕ ( m 0 ∧ m 1 ) ) w=((1\oplus m_x)\land (m_0\oplus m_1)\oplus(m_0\land m_1)) w=((1mx)(m0m1)(m0m1)),其中 m b = M S B ( ⟨ x ⟩ B l ) m_b=MSB(\langle x\rangle_B^l) mb=MSB(xBl)。当 m x m_x mx是秘密分享时,使用 ( 4 1 ) \binom{4}{1} (14)-OT;当 m x m_x mx是明文时,使用 ( 2 1 ) \binom{2}{1} (12)-OT。

    6 构建数学库

    6.1 指数

    r E x p ( z ) = e − z , z ∈ R + rExp(z)=e^{-z}, z\in \mathbb R^+ rExp(z)=ez,zR+的值,首先将输入 x x x分成 k k k段,然后每段在LUT (Look Up Table)进行查表,将得到的结果相乘。
    算法如下:
    在这里插入图片描述

    6.2 Sigmoid和Tanh

    s i g m o i d ( z ) = 1 1 + e − z sigmoid(z)=\frac{1}{1+e^{-z}} sigmoid(z)=1+ez1,可以表示如下:
    在这里插入图片描述
    其中, h ( z ) = 1 1 + r E x p ( z ) h(z)=\frac{1}{1+rExp(z)} h(z)=1+rExp(z)1的计算是先求 r E x p rExp rExp然后求倒数:
    在这里插入图片描述
    倒数则是采用Goldschmidt’s迭代近似算法实现,算法如下:
    在这里插入图片描述
    Tanh和sigmoid存在数学上的关系: T a n h ( z ) = e z − e − z e z + e − z = 2 s i g m o i d ( 2 z ) − 1 Tanh(z)=\frac{e^z-e^{-z}}{e^z+e^{-z}}=2sigmoid(2z)-1 Tanh(z)=ez+ezezez=2sigmoid(2z)1,所以可以用如上方式实现。

    6.3 平方根倒数

    计算 r s q r t ( x ) = 1 x rsqrt(x)=\frac{1}{\sqrt x} rsqrt(x)=x 1,为了防止分母为0,首先加上一个很小的 ϵ \epsilon ϵ r s q r t ( x ) = 1 x + ϵ rsqrt(x)=\frac{1}{\sqrt {x+\epsilon}} rsqrt(x)=x+ϵ 1
    首先,进行初始化,然后用Goldschmidt法进行迭代,
    算法如下:
    在这里插入图片描述
    在这里插入图片描述

    通信开销和轮次汇总

    在这里插入图片描述

    参考资料:
    SIRNN: A Math Library for Secure RNN Inference
    2021-10-02-SIRNN

  • 相关阅读:
    Spring依赖注入、循环依赖——三级缓存
    Druid连接池最小连接数设置失效问题
    机器人路径规划:基于Q-learning算法的移动机器人路径规划,可以自定义地图,修改起始点,提供MATLAB代码
    2023-06-13:统计高并发网站每个网页每天的 UV 数据,结合Redis你会如何实现?
    Maven的总结
    【已解决】 E45: ‘readonly‘ option is set (add ! to override)
    TypeScript(5)类、继承、多态
    Jmeter结构体系——Jmeter目录结构详解
    MotoGP Ignition:准备好参加 Spotlight 活动!
    全国见!飞桨星河社区五周年,邀你共赴大模型盛宴!
  • 原文地址:https://blog.csdn.net/qq_16763983/article/details/128117212