C1W4A_Build_a_Conditional_GAN_get_input_dimensions implementation

https://www.coursera.org/learn/build-basic-generative-adversarial-networks-gans/programming/aFndv/conditional-gan/lab?path=%2Fnotebooks%2FC1W4A_Build_a_Conditional_GAN.ipynb
The docstring of this function

```python
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    '''
    Function for getting the size of the conditional input dimensions 
    from z_dim, the image shape, and number of classes.
    Parameters:
        z_dim: the dimension of the noise vector, a scalar
        mnist_shape: the shape of each MNIST image as (C, W, H), which is (1, 28, 28)
        n_classes: the total number of classes in the dataset, an integer scalar
                (10 for MNIST)
    Returns: 
        generator_input_dim: the input dimensionality of the conditional generator, 
                          which takes the noise and class vectors
        discriminator_im_chan: the number of input channels to the discriminator
                            (e.g. C x 28 x 28 for MNIST)
    '''
    # implementation omitted
```

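describes discriminator_im_chan as above, that is, as "the number of input channels to the discriminator" with the example "C x 28 x 28 for MNIST".

However, the following unit test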
```python
def test_input_dims():
    gen_dim, disc_dim = get_input_dimensions(23, (12, 23, 52), 9)
    assert gen_dim == 32
    assert disc_dim == 21
test_input_dims()
```

doesn't seem to match that for disc_dim. Specifically, the test passes mnist_shape = (12, 23, 52), so C x W x H would be 12 x 23 x 52, yet the test expects 21 (assert disc_dim == 21).

Am I missing something?

Here is the paragraph of instructions that preceded that function:

Then, you can initialize your generator, discriminator, and optimizers. To do this, you will need to update the input dimensions for both models. For the generator, you will need to calculate the size of the input vector; recall that for conditional GANs, the generator's input is the noise vector concatenated with the class vector. For the discriminator, you need to add a channel for every class.
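To make the instructions above concrete, here is a minimal sketch of the two concatenations, using NumPy arrays as stand-ins for the notebook's tensors. The batch size, z_dim, the label values, and the helper name `one_hot` are illustrative assumptions. The class label is one-hot encoded and appended to the noise vector for the generator; for the discriminator, each one-hot entry is expanded into a constant 28 x 28 plane and stacked onto the image's C channels.

```python
import numpy as np

def one_hot(labels, n_classes):
    # Hypothetical helper: row i is all zeros except a 1 at column labels[i]
    return np.eye(n_classes)[labels]

batch, z_dim, n_classes = 4, 64, 10   # illustrative values
image_shape = (1, 28, 28)             # (C, 28, 28) for MNIST

labels = np.array([3, 1, 4, 1])
noise = np.random.randn(batch, z_dim)
class_vec = one_hot(labels, n_classes)          # shape (batch, n_classes)

# Generator input: noise concatenated with the class vector along the feature axis
gen_input = np.concatenate([noise, class_vec], axis=1)
print(gen_input.shape)  # (4, 74)

# Discriminator input: broadcast each one-hot entry to a full 28 x 28 plane,
# then stack those n_classes planes after the image's C channels
images = np.random.randn(batch, *image_shape)
class_planes = np.broadcast_to(
    class_vec[:, :, None, None], (batch, n_classes, *image_shape[1:])
)
disc_input = np.concatenate([images, class_planes], axis=1)
print(disc_input.shape)  # (4, 11, 28, 28)
```

The channel count grows from C to C + n_classes while the spatial size stays 28 x 28, so the discriminator's input channel count is a sum, not a product of the image dimensions.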

The last sentence there could have been clearer, but what it is trying to say is that the number of discriminator input channels should be the sum of C and the number of classes (12 + 9 = 21 in that test case).
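A minimal sketch consistent with the instructions and the unit test above (not the official notebook solution): the generator's input length is z_dim plus one entry per class, and the discriminator's input channel count is the image's C plus one channel per class.

```python
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    # Generator input: the noise vector concatenated with a one-hot class vector
    generator_input_dim = z_dim + n_classes
    # Discriminator input: the image's C channels plus one extra channel per class;
    # the spatial dimensions are not part of this count
    discriminator_im_chan = mnist_shape[0] + n_classes
    return generator_input_dim, discriminator_im_chan

print(get_input_dimensions(23, (12, 23, 52), 9))  # (32, 21)
print(get_input_dimensions(64, (1, 28, 28), 10))  # (74, 11)
```

For MNIST (C = 1, 10 classes) this gives 11 discriminator channels of 28 x 28 pixels each; the docstring's "C x 28 x 28" example appears to describe the full image shape rather than the channel count the function returns.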

Thanks! I had guessed that answer from the test's expected value and finished the project. However, my question was about the information in the function's docstring.

Ah, OK, fair enough. The docstring is a bit misleading. I'll consider filing a bug about it, but the instructions I quoted above were clear enough.