einsum全称Einstein summation convention(爱因斯坦求和约定),又称为爱因斯坦标记法,是爱因斯坦1916年提出的一种标记约定,简单的说就是省去求和式中的求和符号,例如下面的公式:
$c=\sum_ia_ib_i$
以einsum的写法就是:
$c=a_ib_i$
后者将$\sum$符号给省去了,显得更加简洁;再比如:
$c_i=\sum_ja_{ij}b_j$ (1)
$c_{klmn}=\sum_i\sum_ja_{ijkl}b_{ijmn}$ (2)
上面两个例子换成einsum的写法就变成:
$c_i=a_{ij}b_j$ (1)
$c_{klmn}=a_{ijkl}b_{ijmn}$ (2)
在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,简单的一些还好,有时碰到例如矩阵转置、矩阵乘法、求迹、张量乘法、数组求和等等,若是以分别以transopse、sum、trace、tensordot等函数实现的话,不但复杂,还容易出错。
现在,这些问题统统可以用einsum函数搞定,einsum函数就是根据上面的标记法实现的一种函数,可以根据给定的表达式进行运算,可以替代但不限于以下函数:
矩阵求迹:trace
求矩阵对角线:diag
张量(沿轴)求和:sum
张量转置:transopose
矩阵乘法:dot
张量乘法:tensordot
向量内积:inner
外积:outer
该函数在numpy、tensorflow、pytorch上都有实现,用法基本一样,定义如下:
einsum(equation, *operands)
equation是字符串的表达式,operands是操作数,是一个元组参数,并不是只能有两个,所以只要是能够通过einsum标记法表示的乘法求和公式,都可以用一个einsum解决,下面以numpy举几个例子:
元素求和

矩阵乘法

张量乘法

pytorch版本
# trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) # batch matrix multiplication >>> As = torch.randn(3,2,5) >>> Bs = torch.randn(3,5,4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3,5,4) >>> l = torch.randn(2,5) >>> r = torch.randn(2,4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])