Combine_sample assert

Hello! Seems like my combine_sample implementation fails because of this assert:

test_combination = combine_sample(
    torch.ones(n_test_samples, 10, 10), 
    torch.zeros(n_test_samples, 10, 10), 
    0.8
)
# Check that the shape is right
assert tuple(test_combination.shape) == (n_test_samples, 10, 10)
# Make sure that no mixing happened
assert torch.abs((test_combination.sum([1, 2]).median()) - 100) < 1e-5

I don’t understand why a median of 100 is expected when summing 100 Bernoulli variables with p = 0.8. My solution gives 80, which I assumed was correct.

Hi @Vanster, welcome to the community!

The mean of test_combination.sum([1, 2]) should be about 80, but the median should not.

combine_sample(
    torch.ones(n_test_samples, 10, 10), 
    torch.zeros(n_test_samples, 10, 10), 
    0.8
)

should produce an output of the same shape as its inputs, where each slice along the first dimension (i.e. each image) is either all ones (with probability p_real = 0.8) or all zeros (with probability 0.2).

This means that, for any given image, the sum of all the components will yield either 0 or the size of the image (in this case 10x10 = 100).

test_combination.sum([1, 2]) does exactly that sum over each image, and the resulting vector will be composed of 0s and 100s. Taking the mean of that should give about 80, because we set the probability p_real as 0.8. But the median should really be 100, because there are more 100s (~80%) than 0s (~20%), so the median of the distribution falls at 100.
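To make the mean-versus-median difference concrete, here is a small sketch using a hand-built vector of per-image sums. The name `sums` and the exact 8-to-2 split are assumptions chosen for illustration, not part of the assignment:

```python
import torch

# Hypothetical per-image sums: 80% of the images are all ones (sum 100),
# 20% are all zeros (sum 0).
sums = torch.tensor([100.0] * 8 + [0.0] * 2)

mean_value = sums.mean().item()      # average over all images
median_value = sums.median().item()  # middle value of the sorted sums

print(mean_value)    # 80.0
print(median_value)  # 100.0
```

The mean averages the 0s and 100s together, while the median only reports the value in the middle of the sorted list, which is 100 as long as more than half of the images are all ones.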

I believe your combine_sample implementation is sampling once per pixel, when it should sample once per image. Please double check your code and get back to me if you have any more questions.
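The difference between the two sampling schemes can be sketched as follows. This is only an illustration under stated assumptions, not the assignment's reference solution: the function names `combine_per_image` and `combine_per_pixel` are hypothetical, and the sketch assumes `real` and `fake` have the same shape with the image index in the first dimension.

```python
import torch

def combine_per_image(real, fake, p_real):
    # Hypothetical helper: one random draw per image (first dimension).
    keep_real = torch.rand(real.shape[0], device=real.device) < p_real
    # Reshape to (n, 1, 1, ...) so the per-image choice broadcasts over all pixels.
    mask = keep_real.view(-1, *([1] * (real.dim() - 1)))
    return torch.where(mask, real, fake)

def combine_per_pixel(real, fake, p_real):
    # Hypothetical helper: one random draw per element, which mixes
    # real and fake values inside a single image.
    keep_real = torch.rand(real.shape, device=real.device) < p_real
    return torch.where(keep_real, real, fake)

n_test_samples = 1000
real = torch.ones(n_test_samples, 10, 10)
fake = torch.zeros(n_test_samples, 10, 10)

per_image_sums = combine_per_image(real, fake, 0.8).sum([1, 2])
per_pixel_sums = combine_per_pixel(real, fake, 0.8).sum([1, 2])

# Per-image sums contain only 0s and 100s; per-pixel sums cluster around 80.
print(per_image_sums.unique())
print(per_pixel_sums.median())
```

With per-image sampling every sum is exactly 0 or 100, so the median check in the assert can pass. With per-pixel sampling the sums are spread around 80, which matches the result described in the question.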

Hope that made sense for you.

Hi @pedrorohde! Yeah, you’re right, I thought we needed to sample each pixel separately. Thank you very much!
