Quantization-aware training in the TF2 Object Detection API


I would like to know if there is any way to perform quantization-aware training (as well as pruning) for object detection models trained with the TensorFlow 2 Object Detection API. Also, how do you guys optimize your models for production? Do you use TFLite? ONNX?
Thank you and I wish you good fortune in the courses to come :slight_smile:

Hello @Vilabella ,
TensorFlow has a Model Optimization Toolkit (tfmot) that you can use to optimize your computer vision models. For quantization-aware training you need a Keras model (Sequential or Functional; subclassed models are not supported by this API), which you pass to tfmot.quantization.keras.quantize_model. This returns a new model with fake-quantization operations inserted, which you then compile and train as usual. If you only want to quantize some of the layers, mark them with tfmot.quantization.keras.quantize_annotate_layer and then call tfmot.quantization.keras.quantize_apply on the annotated model.

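Similarly, to perform pruning, you can wrap either the whole model or individual layers with tfmot.sparsity.keras.prune_low_magnitude. This creates a new model whose wrapped layers have their smallest-magnitude weights progressively set to zero during training. How much and how fast to prune is controlled by a pruning schedule such as tfmot.sparsity.keras.PolynomialDecay or tfmot.sparsity.keras.ConstantSparsity, passed through the pruning_schedule argument. After training, call tfmot.sparsity.keras.strip_pruning to remove the pruning wrappers before you export the model.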

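Once you have the quantized or pruned model, you train it like any other Keras model. For pruning you must add the tfmot.sparsity.keras.UpdatePruningStep callback (or call its hooks manually if you use a custom training loop), otherwise the pruning step is never updated. Quantization-aware training needs no extra callback; after training you convert the model with tf.lite.TFLiteConverter. Note that the TensorFlow 2 Object Detection API's training script (model_main_tf2.py) runs its own training loop and, as far as I know, does not apply tfmot for you, so you would have to integrate these steps yourself.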

Here are some links:

  1. TensorFlow Model Optimization Toolkit: TensorFlow Model Optimization
  2. Quantization Aware Training with TensorFlow Model Optimization Toolkit: Quantization aware training  |  TensorFlow Model Optimization
  3. Pruning with TensorFlow Model Optimization Toolkit: Trim insignificant weights  |  TensorFlow Model Optimization
  4. Example of Quantization Aware Training with Object Detection API: https://github.com/tensorflow/models/tree/master/research/object_detection/g3doc/quantization
  5. Example of Pruning with Object Detection API: https://github.com/tensorflow/models/tree/master/research/object_detection/g3doc/model_pruning