👉torch.multinomial的源码见https://github.com/dongjinkun/PyTorch/tree/main/torch
一、前言
torch.multinomial()方法多出现在需要采样的场景中,如强化学习。具体讲,当使用强化学习解决旅行商问题时,针对某一个instance,trained model会逐步推理得到下一步要经过的城市,直至所有的城市都访问过。在某个时刻t,都会从下一个访问城市的概率分布中均匀采样得到下一个要访问的城市。OK! 上述例子不理解没关系,言归正传,正式开始介绍torch.multinomial()方法。
二、方法解析
首先,查阅PyTorch官方对multinomial方法的介绍:
- 从一个概率分布中采样num_samples次,返回一个包含采样结果索引的tensor。
- 输入:input<=>概率分布,nums_samples<=>采样次数
- 输出:返回值是包含采样结果索引的tensor,注意,返回的是索引,并不是采样结果。
其次,值得关注的几点如下:
- 1.input中的元素和不需要为1,这种情况下,input中的元素被视为权重系数,但必须是非负的、有限且元素和不等于0的。
- 2.input的类型必须为float型。
- 3.如果input是一个vector,那么返回的tensor的size为采样的次数。
- 4.如果input是一个shape为(m,-1)的tensor,那么返回的tensor的size为(m,num_samples)。
- 5.如果replacement=True,表示采样可以放回抽样,即一个元素可能被重复抽中;如果replacement=Fasle,表示采样是不放回抽样,一个元素只能被抽中一次,在这种情况下,num_sample要小于等于input中非0元素的个数。
铺垫了这么多,下面通过几个案例详细的介绍一下torch.multinomial()方法的使用。
三、案例分析
- 3.1 input为vector, replacement=False
- 3.2 input的shape为(m, -1), replacement=False
- 3.3 replacement=True