Batch size problem in Exercise 6

Hello. In Exercise 6, I’m getting the following error from the unit tests:

ValueError: Inputs batch size (1) does not match batch_size arg (0.

I don’t really understand how the batch size can be zero. The relevant code raising the exception is:

raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match '
                 f'batch_size arg ({batch_size}.')

I tried printing the variables and I have the input sequence ‘test’, which is tokenized to [2074] with shape (1,) and then reshaped to [[2074]] with shape (1,1), which I think is the correct way to do that.
I do the reshaping as follows:

    # Create input tokens using the tokenize function
    input_tokens = tokenize(start_sentence, vocab_file, vocab_dir)

    # Add batch dimension to array. Convert from (n,) to (1, n)
    input_tokens_with_batch = np.reshape(input_tokens, (1,-1))    
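
As a standalone check of the reshape step, here is a minimal sketch that hard-codes the token array [2074] in place of the tokenize call. That substitution is an assumption, made only so the snippet runs without the course's vocabulary files.

```python
import numpy as np

# Assumed stand-in for the output of tokenize('test', ...): one token id.
input_tokens = np.array([2074])

# Add a batch dimension: (n,) becomes (1, n).
input_tokens_with_batch = np.reshape(input_tokens, (1, -1))

print(input_tokens.shape)             # (1,)
print(input_tokens_with_batch.shape)  # (1, 1)
```

The leading dimension is 1, which matches the "Inputs batch size (1)" part of the error, so the reshape itself does not appear to be the source of the mismatch.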

Could someone please help me understand what I am doing wrong? Thanks!


Hi Jan,

I believe what you did for reshaping is correct.

I had the same issue, but later figured out it was caused by how I implemented output_gen. Did you pass input_tokens_with_batch and temperature positionally, rather than as inputs=input_tokens_with_batch and temperature=temperature? Using the keyword form solved the issue for me.

Hope this helps.

Cheers,
Weimin


Agreed with @Weimin - the way the arguments are laid out in UNQ_6 is misleading. Looking at the documentation for autoregressive_sample_stream() you can see the order of arguments is:

model, inputs=None, batch_size=1, temperature=1.0

so by feeding in the suggested order in the code (# model, # inputs, and # temperature, according to the code comments), you’d be setting batch_size=temperature, which causes the error posted.

As suggested in the other answer, to avoid this issue just assign values explicitly to their parameters:


trax.supervised.decoding.autoregressive_sample_stream( 
        # model
        model=...,
        # inputs will be the tokens with batch dimension
        inputs=...,
        # temperature
        temperature=...
    )
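
To see the mechanism in isolation, here is a minimal sketch using a hypothetical stand-in function, sample_stream_stub. It is not part of Trax and does no decoding; it only copies the parameter order quoted above and reports which parameter each value was bound to. The temperature value 0.0 and the "model" placeholder string are assumptions for illustration.

```python
def sample_stream_stub(model, inputs=None, batch_size=1, temperature=1.0):
    """Hypothetical stand-in: same parameter order as quoted above, no decoding."""
    return {"inputs": inputs, "batch_size": batch_size, "temperature": temperature}


tokens = [[2074]]  # shape (1, 1), as in the first post

# Positional call: the third value is bound to batch_size, not temperature.
positional = sample_stream_stub("model", tokens, 0.0)

# Keyword call: each value reaches the intended parameter.
keyword = sample_stream_stub("model", inputs=tokens, temperature=0.0)

print(positional)  # batch_size is 0.0, temperature keeps its default of 1.0
print(keyword)     # batch_size keeps its default of 1, temperature is 0.0
```

In the positional call the inputs still have a batch dimension of 1 while batch_size has silently become the temperature value, which is the kind of mismatch the ValueError reports.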

+1

This explains the decimal batch size argument in the first post. Got me, too.

It would be nice if the assignment at least noted that the parameters are not listed in positional order. It was really unclear.

I ran into this issue as well. This post helped me fix it, thanks!