使用整型 list/array/tensor 进行索引

索引类型可以是 int listint_? ndarray(任意 numpy int 类型)、long torch.tensor (int64)。大多数函数返回的都是这种类型,例如 torch.topk()np.where() 等。

  1. 基本使用,单个数组的 index:返回一个跟 index 形状相同的数组,从 a 的 first dim 上取对应位置的数据
1
2
3
4
5
6
7
8
>>> a = torch.arange(9) ** 2
>>> a
tensor([ 0, 1, 4, 9, 16, 25, 36, 49, 64])
>>> index = torch.tensor([1,3,1,5,8], dtype=torch.long)
>>> index
tensor([1, 3, 1, 5, 8])
>>> a[index]
tensor([ 1, 9, 1, 25, 64])

索引数组中可以有重复值

  1. 上述情况下 a.dim() != index.dim() 时同理,只要求 max(index) <= len(a)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> index = torch.tensor([[8,1], [2,1]], dtype=torch.long)
>>> index
tensor([[8, 1],
[2, 1]])
>>> a[index]
tensor([[64, 1],
[ 4, 1]])

>>> a = a.view(3, 3)
>>> index = torch.tensor([2,0])
>>> a[index]
tensor([[6, 7, 8],
[0, 1, 2]])
>>> index = torch.tensor([5,0])
>>> a[index]
RuntimeError: "index 5 is out of bounds for dim with size 3"
  1. 多个数组的索引
1
2
3
4
5
6
7
8
9
10
>>> a = torch.arange(12).view(3,4)
>>> a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> i = torch.tensor([[0,1], [1,2]])
>>> j = torch.tensor([[2,1], [3,3]])
>>> a[i, j]
tensor([[ 2, 5],
[ 7, 11]])

a[i,j] 的机制是数组 i 和数组 j 相同位置的对应数字两两组成一对索引,然后用这对索引在 a 中进行取值。比如 i[0,0] == 0, j[0,0] == 2,它们组成的索引对是 (0,2),在数组 a 中对应的值 a[0,2] == 2

  1. 上述情况下,i.shape != j.shape 时,会先进行 broadcast
1
2
3
4
>>> j = torch.tensor([2,1])
>>> a[i,j]
tensor([[2, 5],
[6, 9]])

shape 较小的先 broadcast 到较大的 shape,j -> [[2,1], [2,1]]。具体可参考 PYTORCH BROADCASTING SEMANTICS

i j index
0 1 2 1 [0, 2] [1, 1]
1 2 2 1 [1, 2] [1, 1]
  1. 切片
1
2
3
4
5
6
>>> a = torch.arange(12).view(3,4)
>>> a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> i = torch.tensor([[0,1], [1,2]])

result = a[i, :] 则是在 i.shape = (2, 2) 的 2×2 数组的每个位置内,放入 i 相同位置的元素在 a 中索引出的值(a 的行 i 的所有列元素 :),结果 result.shape = (2, 2, 4)

  1. 赋值(In-place)

如果索引列表有重复值,赋值的话也会多次赋值,以最后一次赋值为准。但如果使用 Python 的 += 运算,只会自增一次。

1
2
3
4
5
6
7
8
>>> a = torch.arange(5)
>>> a[[0,0,2]] = torch.tensor([1,2,3])
>>> a
tensor([2, 1, 3, 3, 4])

>>> a[[0,0,2]] += 1
>>> a
tensor([3, 1, 4, 3, 4])

使用布尔 list/array/tensor 进行索引

索引类型可以是 bool list、bool np.array、byte torch.tensor(uint8)。>== 等操作返回的是这种类型。

比较简单,如果是 True/1 就表示选中,False/0 表示不选中。

1
2
3
4
5
>>> a = np.arange(12).reshape(3,4)
>>> index = np.array([False,True,True])
>>> a[index, :]
array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])