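import torch

a = torch.zeros([5, 5])
# index is a tuple (row indices, column indices); paired element-wise it addresses (0, 1) and (1, 2)
index = (torch.LongTensor([0, 1]), torch.LongTensor([1, 2]))
# Method 1: index_put_ writes the values into those positions in place
a.index_put_(index, torch.Tensor([1, 1]))
# Method 2: indexing assignment overwrites the same positions, this time with 4
a[index] = torch.Tensor([4, 4])
print(a)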
tensor([[0., 4., 0., 0., 0.],
        [0., 0., 4., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
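Both assignments above write to the coordinates obtained by pairing the row and column index tensors element-wise. The following is a minimal sketch, assuming a reasonably recent PyTorch, that checks this against an explicit Python loop and also shows the `accumulate=True` option of `index_put_`, which adds values instead of overwriting them. The variable names (`rows`, `cols`, `b`, `c`, `d`) are illustrative only.

```python
import torch

# Row and column indices pair up element-wise: (0, 1) and (1, 2).
rows = torch.tensor([0, 1])
cols = torch.tensor([1, 2])

# index_put_ with the default accumulate=False overwrites the targets.
a = torch.zeros(5, 5)
a.index_put_((rows, cols), torch.tensor([1.0, 1.0]))

# Advanced-indexing assignment writes the same positions.
b = torch.zeros(5, 5)
b[rows, cols] = torch.tensor([1.0, 1.0])

# Equivalent explicit loop over the (row, col) pairs.
c = torch.zeros(5, 5)
for r, col in zip(rows.tolist(), cols.tolist()):
    c[r, col] = 1.0

# With accumulate=True the values are added, so the repeated
# coordinate (0, 1) receives 2.0 + 3.0.
d = torch.zeros(5, 5)
d.index_put_((torch.tensor([0, 0]), torch.tensor([1, 1])),
             torch.tensor([2.0, 3.0]), accumulate=True)

print(torch.equal(a, b), torch.equal(b, c))  # True True
print(d[0, 1].item())  # 5.0
```

Without `accumulate=True`, writing to a repeated coordinate is not guaranteed to keep any particular one of the values, so accumulation is the safer choice when indices may repeat.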
PyTorch tensor assignment by index: three methods [https://blog.csdn.net/qq_41368074/article/details/106986753]
Original post: https://www.cnblogs.com/yanghailin/p/13206418.html