398. Random Pick Index

Problem

Given an array, find a target value in a single pass and output its index. If the target occurs more than once, pick each of its indices with equal probability.

Solution

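This is a pure reservoir sampling problem with m = 1: we keep a single index and sample only among the positions that hold the target.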

Algorithm

Reservoir sampling

Given an array of unknown length, go through it only once and pick m elements from it, each with equal probability.

We keep a set of size m. As we go through the first m elements, we just add them to the set.

When we arrive at the ith (i > m) element, we don't know the length, so this one could be the last element in the array. If it is the last, it must end up in the set with probability m/i, which is the probability each of the i elements should have. Since we can't tell, we treat every element as if it were the last and pick it with probability m/i.

If it is not picked, we just move to the next one. If it is picked, it replaces one of the m elements in the set, chosen uniformly at random (each with probability 1/m).

We keep doing this until we get to the end of the array. No matter where we stop, every element seen so far has the same probability of being in the set.
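The steps above can be sketched roughly as follows. This is a minimal illustration rather than part of the original solution: the function name `reservoir_sample` and the optional `rng` parameter are assumptions introduced here, and a list stands in for the "set" of size m.

```python
import random
from typing import Iterable, List, Optional, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], m: int,
                     rng: Optional[random.Random] = None) -> List[T]:
    """Pick up to m items from a stream of unknown length in one pass.

    Hypothetical helper for illustration; `rng` lets callers pass a seeded
    generator for reproducible runs.
    """
    rng = rng or random.Random()
    reservoir: List[T] = []
    for i, item in enumerate(stream, start=1):  # i = items seen so far
        if i <= m:
            # The first m items always go into the reservoir.
            reservoir.append(item)
        elif rng.random() < m / i:
            # Pick the ith item with probability m/i, then evict one of the
            # m current items uniformly at random.
            reservoir[rng.randrange(m)] = item
    return reservoir
```

For example, `reservoir_sample(range(100), 5)` returns 5 distinct values from 0 to 99. If the stream has fewer than m items, all of them are returned.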

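Proof by mathematical induction: suppose that after going through the first (i-1) elements, each of them is in the set with probability m/(i-1). When the ith element arrives, an element already in the set stays there either because the new element is not picked, or because it is picked but replaces a different element. The probability of that is (1-m/i) + (m/i)*(1-1/m) = (i-1)/i. So each earlier element is in the set after i elements with probability (m/(i-1))*((i-1)/i) = m/i, and the ith element itself is in the set with probability m/i by construction. The condition holds for i=m, where every element is in the set with probability 1 = m/m, so it holds for every i >= m. Proved.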

A variation of the problem: we can only pick elements that satisfy a certain condition, for example, picking only prime numbers from an array of integers with equal probability. We just change the definition of i from the number of elements we have gone through to the number of elements we have gone through that satisfy the condition.
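As a rough sketch of this variation with m = 1, the counter below advances only for elements that satisfy the condition. The names `pick_matching` and `is_prime` are hypothetical helpers introduced for illustration, and trial division is used only to keep the example short.

```python
import random
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


def pick_matching(stream: Iterable[T], predicate: Callable[[T], bool],
                  rng: Optional[random.Random] = None) -> Optional[T]:
    """Return one item satisfying `predicate`, uniformly at random, in one pass.

    Returns None when no item matches. Hypothetical helper for illustration.
    """
    rng = rng or random.Random()
    chosen: Optional[T] = None
    matches = 0  # "i" now counts only the items that satisfy the condition
    for item in stream:
        if predicate(item):
            matches += 1
            # Replace the current choice with probability 1/matches.
            if rng.randrange(matches) == 0:
                chosen = item
    return chosen


def is_prime(n: int) -> bool:
    """Simple trial-division primality check (fine for small inputs)."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True
```

Calling `pick_matching(range(1, 11), is_prime)` returns one of 2, 3, 5, or 7, each with probability 1/4 under this scheme.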

Code

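import random
from typing import List


class Solution:

    def __init__(self, nums: List[int]):
        self.nums = nums

    def pick(self, target: int) -> int:
        count = 0  # number of occurrences of target seen so far
        ret = -1   # stays -1 only if target never appears
        for index, value in enumerate(self.nums):
            if value == target:
                count += 1
                # Reservoir of size 1: keep the new index with probability 1/count.
                if random.random() < 1.0 / count:
                    ret = index
        return ret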


# Your Solution object will be instantiated and called as such:
# obj = Solution(nums)
# param_1 = obj.pick(target)