An assert statement in the tests says:

```
# Make sure that the order is maintained
assert torch.abs(test_combination - test_reals).sum() < 1e-4
```
I'd like to understand what order this unit test expects to be maintained.
I understand that the purpose is to avoid shuffling the labels, but I'm not sure how to achieve that. Say `p_real` is 30%, as in the example. Sampling 30% from the reals and 70% from the fakes, then concatenating them along axis 0 so that the 30% reals are followed by the 70% fakes, doesn't work.
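My reading of "order is maintained" (an assumption on my part, not something the test spells out) is that output position `i` must hold either `real[i]` or `fake[i]`, so no sample ever moves to a different index. A minimal sketch of that kind of per-index selection, using a hypothetical helper `combine_in_place` and made-up test tensors where each real row stores its own index:

```python
import torch


def combine_in_place(real: torch.Tensor, fake: torch.Tensor, p_real: float) -> torch.Tensor:
    """Hypothetical helper: keep real[i] with probability p_real, otherwise use fake[i].

    Each output row i is taken from index i of one of the inputs, so rows never move.
    """
    # One random draw per sample (along dim 0), not one per pixel.
    keep_real = torch.rand(real.shape[0], device=real.device) < p_real
    out = fake.clone()
    out[keep_real] = real[keep_real]
    return out


torch.manual_seed(0)
n = 1000
real = torch.arange(n, dtype=torch.float32)[:, None]  # row i holds the value i
fake = torch.full((n, 1), -1.0)                        # -1 marks a fake row
mixed = combine_in_place(real, fake, 0.3)

is_real = mixed[:, 0] >= 0
# Every real row is still at its original index.
print(torch.equal(mixed[is_real, 0], torch.arange(n, dtype=torch.float32)[is_real]))
print(is_real.float().mean())  # roughly 0.3
```

The key design choice here is drawing one random number per sample and overwriting rows in place, instead of selecting subsets and concatenating them, which is what shifts positions.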
- I tried taking the first 30% of the reals and the first 70% of the fakes. It didn't work; I got the `sample isn't biased` assertion.
- I tried permuting the indices, taking the first 30% of them, and sorting them to use as an index filter for reals and fakes. That also didn't work; I got the "maintain order" assertion.
- I tried Bernoulli sampling and masking so that I get a random 30% of reals and a random 70% of fakes. It also didn't work; I got the "maintain order" assertion.
- I also tried concatenating them in the opposite order, 70% fakes first and then 30% reals, but to no avail.
Here's my best attempt:

```
filt = torch.bernoulli(real, p_real)
reals = real[filt == 1]
fakes = fake[filt == 0]
target_images = torch.cat((fakes , reals), 0).reshape(real.shape)
```
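To see what this kind of masking produces, here is a small sketch on made-up tensors. It builds the per-element mask with `torch.bernoulli(torch.full_like(real, 0.3))` rather than the exact call above, which is my assumption about the intended behavior. When a boolean mask has the same shape as the tensor, `real[mask]` returns a flat 1-D tensor of the selected elements, so individual pixels rather than whole images are picked, and putting the fakes first moves every selected element away from its original position.

```python
import torch

torch.manual_seed(0)
real = torch.arange(24, dtype=torch.float32).reshape(4, 6)  # 4 "images" of 6 pixels
fake = -torch.ones(4, 6)                                     # -1 marks a fake pixel

# Per-element mask with the same shape as `real` (one draw per pixel).
filt = torch.bernoulli(torch.full_like(real, 0.3))

reals = real[filt == 1]  # 1-D tensor: selected pixels, flattened
fakes = fake[filt == 0]  # 1-D tensor: the remaining pixels, flattened
combined = torch.cat((fakes, reals), 0).reshape(real.shape)

print(reals.dim(), fakes.dim())
print(reals.numel() + fakes.numel() == real.numel())
# All fake pixels now come first in memory, regardless of where `filt` put them.
print(combined)
```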
So what order is required, or what am I missing here?
Thank you in advance for any responses.