I got a device RuntimeError on the last assertion in the test. As part of the exercise, I’ve set everything to device='cuda', i.e. the real and fake input tensors as well as the rand tensors.
May I ask what this last assertion is doing, and whether it’s a bug that will always raise the RuntimeError?
If I set all my input and target tensors to 'cuda', won’t this assertion always fail, given that test_reals is always on the CPU?
You definitely don’t want to hard-code device='cuda' in your code. It’s OK as an experiment, of course, but it isn’t good practice for your final code. You want the code to be flexible, so it works with whatever device is being used. This is also important because the autograder runs in a CPU-only environment, while you’ll want to use a GPU (for speed) when you run the exercise yourself.
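A common way to avoid hard-coding the device is to pick it once at startup and pass it around. The sketch below is a minimal illustration, not code from the exercise; the tensor names real and fake and their shapes are placeholders chosen for the example.

```python
import torch

# Use the GPU when one is available, otherwise fall back to the CPU, so the
# same code runs on a local GPU and in a CPU-only grading environment.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder batches standing in for real and fake images (shapes are arbitrary).
real = torch.randn(8, 1, 28, 28, device=device)
fake = torch.randn(8, 1, 28, 28, device=device)

print(real.device, fake.device)
```

On a CPU-only machine this prints `cpu cpu`; with a GPU it prints `cuda:0 cuda:0`, with no code changes.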
This is actually the purpose of the unit test you pointed out: it checks that combine_sample returns a result on the same device as its inputs. If the inputs are on 'cuda', the result should be too; likewise, if the inputs are on 'cpu', so should the result.
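The idea behind such a device check can be sketched as a small helper. This is a hypothetical function written for illustration; it is not the course’s actual test code, and the name check_same_device is an assumption.

```python
import torch

def check_same_device(fn, *inputs):
    """Hypothetical helper: call fn on the inputs and verify that the result
    lives on the same device as the first input."""
    result = fn(*inputs)
    expected = inputs[0].device
    if result.device != expected:
        raise ValueError(f"result is on {result.device}, inputs are on {expected}")
    return result

# Example usage: run the check on whatever device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
out = check_same_device(lambda a, b: a + b, x, y)
```

Because the check compares against the inputs’ device rather than a fixed string, the same test passes on a CPU-only grader and on a GPU machine, as long as the function under test keeps its outputs on the input device.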
As your experiment showed, it can be hard to make sure the right device is being used. PyTorch is tricky that way: functions that build a new tensor from scratch, such as torch.rand(), create it on the CPU by default, unless you use a function that copies the input tensor’s device, like torch.clone(), torch.zeros_like(), or torch.ones_like().
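The difference between a fresh factory call and the device-copying functions is easy to see directly. This minimal sketch assumes the global default device has not been changed (it is normally the CPU).

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, device=device)

# A plain factory call knows nothing about x: it uses the default device (the CPU).
fresh = torch.rand(4)

# These copy x's device (and dtype) automatically.
copied = x.clone()
zeros = torch.zeros_like(x)
ones = torch.ones_like(x)

# Alternatively, pass x's device explicitly to the factory function.
explicit = torch.rand(4, device=x.device)

for name, t in [("fresh", fresh), ("copied", copied), ("zeros", zeros),
                ("ones", ones), ("explicit", explicit)]:
    print(name, t.device)
```

On a GPU machine only `fresh` ends up on `cpu`; everything else follows `x` onto `cuda:0`.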
Check back over your code for combine_sample() and make sure that wherever you create a tensor, you either use a function like torch.clone() that creates it on the same device as the input, or explicitly set the device yourself to match the device of the inputs passed to combine_sample(). It may also help to temporarily add some print statements that show the device of each value, to narrow down where the problem is.
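Putting both suggestions together, a device-safe version might look like the sketch below. It is an illustration under assumptions, not the assignment’s reference solution: the signature combine_sample(real, fake, p_real) and the p_real parameter (probability of keeping each real sample) are assumed for this example.

```python
import torch

def combine_sample(real, fake, p_real):
    """Hypothetical sketch: build a batch where each item comes from `real`
    with probability p_real and from `fake` otherwise. The signature and the
    meaning of p_real are assumptions for illustration."""
    # Draw the random mask on the inputs' device rather than the default
    # (CPU), so every tensor in the computation lives on one device.
    make_fake = torch.rand(len(real), device=real.device) > p_real
    # clone() keeps real's device, dtype, and shape.
    combined = real.clone()
    combined[make_fake] = fake[make_fake]
    return combined

device = "cuda" if torch.cuda.is_available() else "cpu"
real = torch.zeros(16, 3, device=device)
fake = torch.ones(16, 3, device=device)
combined = combine_sample(real, fake, 0.5)

# Temporary debugging aid: print where each tensor lives.
print("real:", real.device, "fake:", fake.device, "combined:", combined.device)
```

Using zeros for `real` and ones for `fake` makes the result easy to inspect: each row of `combined` should be all zeros or all ones, and it should sit on the same device as the inputs.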
Thanks Wendy, that’s a perfect explanation! This error definitely helped my general understanding of device settings and the use of the .clone() function. Jason