Skip to content

[index_fill]在给定维度填充指定val

PyTorch提供了函数index_fill_用于在张量的指定维度上填充指定val

定义

index_fill_(dim, index, val) → Tensor

  • dim:给定维度。0表示行,1表示列
  • indexLongTensor。给定维度下的指定下标
  • val:待填充值

示例

对于大小为\(3\times 4\)的张量

>>> import torch
>>> a = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> a
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

填充第1/3行,大小为33

>>> index=torch.LongTensor([0, 2])
>>> index
tensor([0, 2])
>>> a.index_fill(0, index, 33)
tensor([[33., 33., 33., 33.],
        [ 4.,  5.,  6.,  7.],
        [33., 33., 33., 33.]])

填充第2/3列,大小为-1

>>> index=torch.LongTensor([1,2])
>>> index
tensor([1, 2])
>>> a.index_fill(1, index, -1)
tensor([[ 0., -1., -1.,  3.],
        [ 4., -1., -1.,  7.],
        [ 8., -1., -1., 11.]])