torch.einsum
是 PyTorch 中一个强大且灵活的张量运算函数,基于爱因斯坦求和约定进行操作。它允许用户通过简单的字符串表达式来定义复杂的张量运算,代替显式的循环或多个矩阵乘法操作。
函数签名
torch.einsum(equation, *operands) → Tensor
参数
equation
: 一个字符串,描述了张量间的操作关系。它使用爱因斯坦求和约定,用逗号分隔不同张量的索引,使用箭头(->
)定义输出的形状。- 左侧部分是输入张量的维度索引,逗号分隔。
- 右侧是输出张量的维度索引。如果没有提供输出维度,函数默认对所有不重复的索引进行求和。
*operands
: 需要操作的张量。张量的维度必须与equation
中的描述匹配。
爱因斯坦求和约定
爱因斯坦求和约定是一种简化张量运算的表示方式。它假设对所有重复的索引进行求和。例如:
'ij,jk->ik'
表示矩阵乘法,其中i
和k
是保留下来的维度,j
是求和的维度。
示例
1. 矩阵乘法
矩阵乘法可以通过 torch.einsum
实现:
import torch
A = torch.tensor([[1, 2],
[3, 4]])
B = torch.tensor([[5, 6],
[7, 8]])
# 等同于 torch.matmul(A, B)
result = torch.einsum('ij,jk->ik', A, B)
print(result)
在这里:
A
的维度索引为ij
(行和列)。B
的维度索引为jk
。- 输出的维度索引为
ik
,表示保留A
的行和B
的列。
2. 向量内积
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# 等同于 torch.dot(x, y)
result = torch.einsum('i,i->', x, y)
print(result)
这里的 'i,i->'
表示对 i
这个索引求和,即两个向量的内积。
3. 批次矩阵乘法
对批次的矩阵执行乘法:
A = torch.randn(3, 2, 4) # 3 个 2x4 矩阵
B = torch.randn(3, 4, 5) # 3 个 4x5 矩阵
# 执行批次矩阵乘法
result = torch.einsum('bij,bjk->bik', A, B)
print(result.shape) # 输出 (3, 2, 5)
bij
表示A
的维度,其中b
是批次维度,i
是行,j
是列。bjk
表示B
的维度,其中b
是批次维度,j
是行,k
是列。- 输出
bik
保留批次b
维度和矩阵乘法后的行和列。
4. 广播加法
A = torch.randn(3, 4)
B = torch.randn(4)
# 广播加法
result = torch.einsum('ij,j->ij', A, B)
print(result.shape) # 输出 (3, 4)
这里,B
的维度被广播扩展到与 A
的第二维度匹配。
总结
torch.einsum
适合各种线性代数操作,例如矩阵乘法、点积、转置、广播等。- 使用
einsum
可以使代码更简洁,并避免繁琐的手动张量操作。 - 它可以处理多维张量,并通过显式地定义每个维度的映射关系,适用于很多高效的深度学习运算。
torch.einsum
是一个高效、简洁的工具,可以替代多个显式的 torch.matmul
或 torch.bmm
等函数,使代码更加直观和灵活。
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » pytorch torch.einsum函数介绍
发表评论 取消回复