torch.distributions.Categorical()
功能:根据概率分布来产生sample,产生的sample是输入tensor的index
如:
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample() # equal probability of 0, 1, 2, 3
tensor(3)
Pytorch中的强化学习
原文:https://www.cnblogs.com/sbj123456789/p/9692711.html