1. Clipping
torch.clamp(a, -0.1, 0.1) == tf.clip_by_value(a, -0.1, 0.1)
2. Tiling (replication)
tf.tile(a, [1, 2, 1]) == a.repeat(1, 2, 1)
3. Adding a dimension
a.unsqueeze(0) == tf.expand_dims(a, axis=0)
4. L2 normalization
    import torch

    def normalize(x, axis=-1):
        # divide by the L2 norm along `axis`; clamp_min avoids division by zero
        return x / torch.norm(x, p=2, dim=axis, keepdim=True).clamp_min(1e-12)

    l2_normal_a = normalize(a)  # == tf.math.l2_normalize(a, axis=-1)
5. Permuting dimensions
b.permute(1, 0) == tf.transpose(b, [1, 0])
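6. Type conversion
- tf.cast(a, tf.float16) == a.half()
- tf.cast(a, tf.float32) == a.float()
- tf.cast(a, tf.float64) == a.double()
- tf.cast(a, tf.int32) == a.int()
- tf.cast(a, tf.bool) == a.bool()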
7. Constants
torch.tensor([[1, 2, 3], [4, 5, 6.0]]) == tf.constant([[1, 2, 3], [4, 5, 6.0]])
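8. torch.zeros_like == tf.zeros_like, torch.ones_like == tf.ones_like, torch.where == tf.where (three-argument form; the one-argument forms return indices in different layouts)
torch.mean(a, dim=d, keepdim=k) == tf.reduce_mean(a, axis=d, keepdims=k)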
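9. torch.Tensor.detach() == tf.stop_gradient(): both pass the value through unchanged and block gradient flow (detach() additionally shares memory with the original tensor)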
10. Log-softmax: torch.nn.LogSoftmax and tf.nn.log_softmax()
torch.nn.LogSoftmax(dim=-1)(a) == torch.nn.functional.log_softmax(a, dim=-1) == tf.nn.log_softmax(a, axis=-1)