排列后如何进行tensordot运算
How to make tensordot operations after permutation
我有 2 个张量,A 和 B:
A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
张量 D 来自操作“tensordot -> permute”。我如何实现一个新的操作 f() 以在 f() 之后进行张量点操作,如:
A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)
您是否考虑过使用非常灵活的 torch.einsum
?
D = torch.einsum('ijab,abkl->ikjl', A, B)
tensordot
的问题是它在 B
之前输出 A
的所有维度,而您正在寻找的(排列时)是从“交错”维度A
和 B
.
我有 2 个张量,A 和 B:
A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
张量 D 来自操作“tensordot -> permute”。我如何实现一个新的操作 f() 以在 f() 之后进行张量点操作,如:
A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)
您是否考虑过使用非常灵活的 torch.einsum
?
D = torch.einsum('ijab,abkl->ikjl', A, B)
tensordot
的问题是它在 B
之前输出 A
的所有维度,而您正在寻找的(排列时)是从“交错”维度A
和 B
.