DLS week 3 exercise 3, one_hot_matrix error

I implemented one_hot_matrix, producing the one-hot encoding with this line of code:

    one_hot = tf.reshape(tf.one_hot(label, depth, axis = 0), depth)
    return one_hot

and all tests passed successfully.

However, later on in this code block:

new_y_test = y_test.map(one_hot_matrix)
new_y_train = y_train.map(one_hot_matrix)

it raises the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-123-cefe2ef5b6b2> in <module>
----> 1 new_y_test = y_test.map(one_hot_matrix)
      2 new_y_train = y_train.map(one_hot_matrix)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
   1693     """
   1694     if num_parallel_calls is None:
-> 1695       return MapDataset(self, map_func, preserve_cardinality=True)
   1696     else:
   1697       return ParallelMapDataset(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   4043         self._transformation_name(),
   4044         dataset=input_dataset,
-> 4045         use_legacy_function=use_legacy_function)
   4046     variant_tensor = gen_dataset_ops.map_dataset(
   4047         input_dataset._variant_tensor,  # pylint: disable=protected-access

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   3369       with tracking.resource_tracker_scope(resource_tracker):
   3370         # TODO(b/141462134): Switch to using garbage collection.
-> 3371         self._function = wrapper_fn.get_concrete_function()
   3372         if add_to_graph:
   3373           self._function.add_to_graph(ops.get_default_graph())

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
   2937     """
   2938     graph_function = self._get_concrete_function_garbage_collected(
-> 2939         *args, **kwargs)
   2940     graph_function._garbage_collector.release()  # pylint: disable=protected-access
   2941     return graph_function

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   2904       args, kwargs = None, None
   2905     with self._lock:
-> 2906       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2907       seen_names = set()
   2908       captured = object_identity.ObjectIdentitySet(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   3362           attributes=defun_kwargs)
   3363       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 3364         ret = _wrapper_helper(*args)
   3365         ret = structure.to_tensor_list(self._output_structure, ret)
   3366         return [ops.convert_to_tensor(t) for t in ret]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   3297         nested_args = (nested_args,)
   3298 
-> 3299       ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
   3300       # If `func` returns a list of tensors, `nest.flatten()` and
   3301       # `ops.convert_to_tensor()` would conspire to attempt to stack

/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e, 'ag_error_metadata'):
--> 258           raise e.ag_error_metadata.to_exception(e)
    259         else:
    260           raise

ValueError: in user code:

    <ipython-input-72-4d1c2aa9c8ff>:17 one_hot_matrix  *
        one_hot = tf.reshape(tf.one_hot(label, depth, axis = 0), depth)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:195 reshape
        result = gen_array_ops.reshape(tensor, shape, name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py:8234 reshape
        "Reshape", tensor=tensor, shape=shape, name=name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:744 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:593 _create_op_internal
        compute_device)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3485 _create_op_internal
        op_def=op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1975 __init__
        control_input_ops, op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1815 _create_c_op
        raise ValueError(str(e))

    ValueError: Shape must be rank 1 but is rank 0 for '{{node Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32](one_hot, Reshape/shape)' with input shapes: [6], [].

I had to change the one_hot_matrix line of code to use a shape of (depth,) instead of just ‘depth’, and then it works.
My question is: why does this happen, even though the tests passed and the shapes were correct?

Following, as I had the same issue.
Thank you for the fix

Yes, same for me. Not sure why.

I don’t know why it fixes it, but enclose the depth passed to the reshape function in square brackets, like so: [depth]. That seems to fix the problem for me.