python 官方标准库 random 中,有个函数 random.choices(population, weights=None, *, cum_weights=None, k=1)
,比起常用的 random.choice(seq)
,这个函数可以指定概率权重和选择次数。
因为刷题的时候用到了这个函数,题目又对时间复杂度有限制,我就很好奇,然后来分析一下这个函数的时间复杂度。
def choices(self, population, weights=None, *, cum_weights=None, k=1):
"""Return a k sized list of population elements chosen with replacement.
If the relative weights or cumulative weights are not specified,
the selections are made with equal probability.
"""
random = self.random
n = len(population)
if cum_weights is None:
if weights is None:
floor = _floor
n += 0.0 # convert to float for a small speed improvement
return [population[floor(random() * n)] for i in _repeat(None, k)]
cum_weights = list(_accumulate(weights))
elif weights is not None:
raise TypeError(‘Cannot specify both weights and cumulative weights‘)
if len(cum_weights) != n:
raise ValueError(‘The number of weights does not match the population‘)
total = cum_weights[-1] + 0.0 # convert to float
if total <= 0.0:
raise ValueError(‘Total of weights must be greater than zero‘)
bisect = _bisect
hi = n - 1
return [population[bisect(cum_weights, random() * total, 0, hi)]
for i in _repeat(None, k)]
population
: 输入的待选取序列weights
: 权重序列cum_weights
: 累加的权重序列,相当于 weights
的前缀和数组k
: 选取的次数,该函数会返回一个长度为 k
的列表参考官方文档可知,这个函数通过权重随机选取数字,比如 choices([1, 2], weights=[3, 2])
,相当于使用 choice([1, 1, 1, 2, 2])
,也可以写成 choices([1, 2], cum_weights=[3, 5])
假设给出了权重(weights
)但是没有累加权重(cum_weights
):
cum_weights = list(_accumulate(weights))
;random()
函数输出一个 [0.0, 1.0)
区间的数,乘上所有权重的累加和,作为生成的随机数。权重的累加和也是 cum_weights
数组最后一个元素值;cum_weights
中找到随机数的位置,输出该位置的数据。函数共有 2 个出口:
weights
和 cum_weights
均为 None
的情况:
return [population[floor(random() * n)] for i in _repeat(None, k)]
时间复杂度:O(k) ,因为 k
为常数,所以也可以认为时间复杂度为 O(1)
这种情况和直接使用 choice
没有差别,所以我就不考虑在最终结果里了。
weights
不为 None
的情况:
return [population[bisect(cum_weights, random() * total, 0, hi)] for i in _repeat(None, k)]
时间复杂度:O(klog(n)),因为 k
为常数,所以也可以认为时间复杂度为 O(log(n)) (注:log(n) 来自二分查找)
cum_weights
为 None
,还需要执行 cum_weights = list(_accumulate(weights))
,_accumulate
类似于 itertools.accumulate()
,时间复杂度:O(n),与上面的 O(log(n)) 叠加,总时间复杂度为:O(n)所以结论在于用户有没有给出累加权重,也就是 cum_weights
数组:
cum_weights
:O(log(n)) ,精确一点就是 O(klog(n)) ,这个 k
就是那个参数 k
,是个常数。所以呢,如果数据规模特别大,还是要谨慎使用这个函数的,尤其是没有提供 cum_weights
参数的时候。
原文:https://www.cnblogs.com/adjwang/p/13908093.html