PyTorch中的高效张量操作库Einops

PyTorch中的高效张量操作库Einops

einops 是一个功能强大且灵活的库,用于以可读和可靠的方式操作张量。它的名字来源于 “Einstein Operations”(爱因斯坦操作),其设计的灵感来自于爱因斯坦求和约定。einops 并非 PyTorch 的一部分,而是一个独立的库,但它与 PyTorch、NumPy, TensorFlow, JAX 等深度学习框架无缝集成。

使用 einops 的主要目的是让张量操作(如变形、重排、聚合)变得更加直观和不容易出错。

einops 的核心操作

einops 主要围绕三个核心函数展开:rearrangereducerepeat

1. rearrange - 重排和变形

这是 einops 中最常用的操作,用于改变张量的维度顺序和形状。它通过一种非常直观的字符串模式来定义操作。

语法: rearrange(tensor, 'pattern_from -> pattern_to')

示例:

  • 维度换位: 假设你有一个形状为 (batch, channels, height, width) 的图像张量,并想将其转换为 (batch, height, width, channels)

    1
    2
    3
    4
    5
    6
    7
    8
    import torch
    from einops import rearrange

    images = torch.randn(10, 3, 32, 32) # (batch, channels, height, width)
    # 使用 rearrange 进行维度重排
    rearranged_images = rearrange(images, 'b c h w -> b h w c')
    print(rearranged_images.shape)
    # 输出: torch.Size([10, 32, 32, 3])
  • 合并维度: 将高度和宽度合并为一个维度。

    1
    2
    3
    4
    # 将 h 和 w 合并
    merged_dims = rearrange(images, 'b c h w -> b c (h w)')
    print(merged_dims.shape)
    # 输出: torch.Size([10, 3, 1024])

2. reduce - 聚合操作

reduce 函数结合了重排和降维(聚合)操作,例如求和、取平均值或最大值。

语法: reduce(tensor, 'pattern_from -> pattern_to', 'reduction_op')

示例:

  • 全局平均池化: 对每个通道的高度和宽度维度进行平均池化。

    1
    2
    3
    4
    5
    6
    7
    from einops import reduce

    # 假设 images 是 (10, 3, 32, 32) 的张量
    # 对 h 和 w 维度求平均值
    global_avg_pool = reduce(images, 'b c h w -> b c', 'mean')
    print(global_avg_pool.shape)
    # 输出: torch.Size([10, 3])

3. repeat - 复制和扩展维度

repeat 用于在指定维度上复制张量的数据。

语法: repeat(tensor, 'pattern_from -> pattern_to', **axis_lengths)

示例:

  • 扩展维度: 将一个张量在新的维度上进行复制。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    import torch
    from einops import repeat

    single_image = torch.randn(3, 32, 32) # (channels, height, width)

    # 在新的 batch 维度上复制 10 次
    batch_of_images = repeat(single_image, 'c h w -> b c h w', b=10)
    print(batch_of_images.shape)
    # 输出: torch.Size([10, 3, 32, 32])

为什么在 PyTorch 中使用 einops

einops 提供了许多优于传统 PyTorch 张量操作(如 view, reshape, permute, transpose)的优势:

  • 可读性强: einops 的字符串模式清晰地描述了输入和输出的维度关系,使得代码更易于理解和维护。
  • 减少错误: 在使用 viewreshape 时,如果维度顺序不正确,可能会导致数据混乱而不报错。einops 的显式模式有助于在操作不匹配时立即引发错误。
  • 代码简洁: 复杂的多步张量操作,如 “unflattening”、转置和 “flattening” 的组合,通常可以通过一行 einops 代码完成。
  • 框架无关性: einops 的语法在不同的深度学习框架中保持一致,这使得代码的迁移和复用变得更加容易。
  • 强大的灵活性: 它能够轻松实现复杂的张量操作,例如将图像分割成块 (patch),这在 Vision Transformer 等模型中非常常见。

如何在 PyTorch 项目中集成 einops

首先,你需要安装 einops

1
pip install einops

然后,你可以像上面的例子一样,在你的 Python 脚本中导入并使用 rearrange, reduce, 和 repeateinops 还可以与 torch.nn.Module 结合使用,创建自定义的神经网络层。

总而言之,einops 是一个非常有用的工具,它可以显著提高 PyTorch 开发者的生产力和代码质量。通过提供一种更直观、更可靠的方式来处理张量操作,它让开发者可以更专注于模型架构和逻辑本身。


整理自: Einops For Tensor Operations


PyTorch中的高效张量操作库Einops
http://zl1bks.github.io/2025/09/20/PyTorch中的高效张量操作库Einops/
作者
zl1bks
发布于
2025年9月20日
许可协议