Numpy 和 PyTorch 中的索引
使用整型 list/array/tensor 进行索引
索引类型可以是 int list
、int_? ndarray
(任意 numpy int
类型)、long torch.tensor (int64)
。大多数函数返回的都是这种类型,例如 torch.topk()
、np.where()
等。
- 基本使用,单个数组的 index:返回一个跟 index 形状相同的数组,从 a 的 first dim 上取对应位置的数据
1 | 9) ** 2 a = torch.arange( |
索引数组中可以有重复值
- 上述情况下
a.dim() != index.dim()
时同理,只要求max(index) <= len(a)
1 | 8,1], [2,1]], dtype=torch.long) index = torch.tensor([[ |
- 多个数组的索引
1 | 12).view(3,4) a = torch.arange( |
a[i,j]
的机制是数组 i 和数组 j 相同位置的对应数字两两组成一对索引,然后用这对索引在 a 中进行取值。比如 i[0,0] == 0, j[0,0] == 2
,它们组成的索引对是 (0,2),在数组 a 中对应的值 a[0,2] == 2
。
- 上述情况下,
i.shape != j.shape
时,会先进行 broadcast
1 | 2,1]) j = torch.tensor([ |
shape 较小的先 broadcast 到较大的 shape,j -> [[2,1], [2,1]]
。具体可参考 PYTORCH BROADCASTING SEMANTICS。
i | j | index | |||||
---|---|---|---|---|---|---|---|
0 | 1 | 2 | 1 | [0, 2] | [1, 1] | ||
1 | 2 | 2 | 1 | [1, 2] | [1, 1] |
- 切片
1 | 12).view(3,4) a = torch.arange( |
如 result = a[i, :]
则是在 i.shape = (2, 2)
的 2×2 数组的每个位置内,放入 i
相同位置的元素在 a
中索引出的值(a
的行 i
的所有列元素 :
),结果 result.shape = (2, 2, 4)
。
- 赋值(In-place)
如果索引列表有重复值,赋值的话也会多次赋值,以最后一次赋值为准。但如果使用 Python 的 +=
运算,只会自增一次。
1 | 5) a = torch.arange( |
使用布尔 list/array/tensor 进行索引
索引类型可以是 bool list、bool np.array、byte torch.tensor(uint8)。>
、==
等操作返回的是这种类型。
比较简单,如果是 True/1
就表示选中,False/0
表示不选中。
1 | 12).reshape(3,4) a = np.arange( |