tf.data.Datasets and TFX

Hello,

I know it was mentioned at some point that the ETL process for training data should be optimized with tf.data.Datasets so the CPU and GPU don't sit idle (something I've done before locally), but I haven't seen this covered anywhere in the course. The course also didn't seem to touch much on how to use the Trainer component of TFX. So my question is: is ETL optimization something that is handled automatically by TFX, or does it have to be handled manually, as in a local data pipeline built with tf.data.Datasets?
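
For context, what I've done locally is roughly the following (the file pattern and parse_fn are just placeholders, not code from the course):

import tensorflow as tf

files = tf.data.Dataset.list_files("train-*.tfrecord")
dataset = (tf.data.TFRecordDataset(files)
           .map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)  # parse_fn: placeholder tf.train.Example parser, run in parallel
           .shuffle(10_000)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))  # prepare the next batch while the GPU trains on the current one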

Thanks!

Hi Gage! Welcome to Discourse, and good question! The Transform component stores your files as TFRecords, and these are fed to the Tuner and Trainer components as one of their arguments. The _input_fn() function in the trainer.py of the Week 1 Lab 2 (TFX Tuner and Trainer) takes care of converting these into tf.data.Datasets before training starts.

def _input_fn(file_pattern,
              tf_transform_output,
              num_epochs=None,
              batch_size=32) -> tf.data.Dataset:
  '''Create batches of features and labels from TF Records

  Args:
    file_pattern - List of files or patterns of file paths containing Example records.
    tf_transform_output - transform output graph
    num_epochs - Integer specifying the number of times to read through the dataset. 
            If None, cycles through the dataset forever.
    batch_size - An int representing the number of records to combine in a single batch.

  Returns:
    A dataset of dict elements, (or a tuple of dict elements and label). 
    Each dict maps feature keys to Tensor or SparseTensor objects.
  '''
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())
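  # _gzip_reader_fn and LABEL_KEY below are defined elsewhere in the same
  # trainer module (not shown in this snippet).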
  
  dataset = tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=transformed_feature_spec,
      reader=_gzip_reader_fn,
      num_epochs=num_epochs,
      label_key=LABEL_KEY)
  
  return dataset

This helper function is called in the run_fn() of the same module to create your train and eval sets.

  # Create batches of data good for 10 epochs
  train_set = _input_fn(fn_args.train_files[0], tf_transform_output, 10)
  val_set = _input_fn(fn_args.eval_files[0], tf_transform_output, 10)
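
One note on your original question: as far as I can tell, make_batched_features_dataset already handles shuffling, batching, and prefetching for you (its prefetch buffer size auto-tunes by default), so a lot of the ETL optimization happens out of the box. That said, what _input_fn() returns is a regular tf.data.Dataset, so inside run_fn() you can still chain your own transformations onto it before passing it to model.fit(). A minimal sketch, continuing from the two lines above (the explicit prefetch() calls are my addition, not part of the lab code):

  # train_set / val_set are ordinary tf.data.Dataset objects, so the usual
  # methods compose with them. Prefetching lets the CPU prepare the next
  # batch while the accelerator is busy with the current one.
  train_set = train_set.prefetch(tf.data.AUTOTUNE)  # tf.data.experimental.AUTOTUNE on older TF versions
  val_set = val_set.prefetch(tf.data.AUTOTUNE)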

Hope this helps! Will take note of this so we can revise the markdown for clarity. Thanks!

Thanks, I’m going to try it out this week. I’ll post any insights I have!