Numpy 和 PyTorch 中的索引
使用整型 list/array/tensor 进行索引
索引类型可以是 int list、int_? ndarray(任意 numpy int 类型)、long torch.tensor (int64)。大多数函数返回的都是这种类型,例如 torch.topk()、np.where() 等。
- 基本使用,单个数组的 index:返回一个跟 index 形状相同的数组,从 a 的 first dim 上取对应位置的数据
1 | a = torch.arange(9) ** 2 |
索引数组中可以有重复值
- 上述情况下
a.dim() != index.dim()时同理,只要求max(index) <= len(a)
1 | index = torch.tensor([[8,1], [2,1]], dtype=torch.long) |
- 多个数组的索引
1 | a = torch.arange(12).view(3,4) |
a[i,j] 的机制是数组 i 和数组 j 相同位置的对应数字两两组成一对索引,然后用这对索引在 a 中进行取值。比如 i[0,0] == 0, j[0,0] == 2,它们组成的索引对是 (0,2),在数组 a 中对应的值 a[0,2] == 2。
- 上述情况下,
i.shape != j.shape时,会先进行 broadcast
1 | j = torch.tensor([2,1]) |
shape 较小的先 broadcast 到较大的 shape,j -> [[2,1], [2,1]]。具体可参考 PYTORCH BROADCASTING SEMANTICS。
| i | j | index | |||||
|---|---|---|---|---|---|---|---|
| 0 | 1 | 2 | 1 | [0, 2] | [1, 1] | ||
| 1 | 2 | 2 | 1 | [1, 2] | [1, 1] | ||
- 切片
1 | a = torch.arange(12).view(3,4) |
如 result = a[i, :] 则是在 i.shape = (2, 2) 的 2×2 数组的每个位置内,放入 i 相同位置的元素在 a 中索引出的值(a 的行 i 的所有列元素 :),结果 result.shape = (2, 2, 4)。
- 赋值(In-place)
如果索引列表有重复值,赋值的话也会多次赋值,以最后一次赋值为准。但如果使用 Python 的 += 运算,只会自增一次。
1 | a = torch.arange(5) |
使用布尔 list/array/tensor 进行索引
索引类型可以是 bool list、bool np.array、byte torch.tensor(uint8)。>、== 等操作返回的是这种类型。
比较简单,如果是 True/1 就表示选中,False/0 表示不选中。
1 | a = np.arange(12).reshape(3,4) |