A fellow learner reached out and asked me about my YOLO implementation. Some of the code is a mess right now because I was trying to break it apart so that the classification network can be trained separately from localization (as suggested in one of the YOLO papers). The classification layers are a subset of the full architecture and produce a different output shape; the localization head is then trained using transfer learning techniques and produces yet another output shape. That means pretty significant changes to the data structures and code. The approach I used for defining a custom class for the complicated YOLO loss function remains unchanged, though, and might be interesting, so here it is.
The first step is to define a Python class that extends, or inherits from, tensorflow.keras.losses.Loss. It requires an __init__ method and a call method. I have omitted some computation and housekeeping details of the call method; hopefully what remains matches your understanding of what the YOLO loss function is designed to do.
from tensorflow.keras import losses
from tensorflow.keras import backend as K

class Yolo_Loss(losses.Loss):

    def __init__(self, true_object_locations_mask, matching_true_boxes, anchors, batch_size, name="yolo_loss"):
        super().__init__(name=name)
        self.true_object_locations_mask = true_object_locations_mask
        self.matching_true_boxes = matching_true_boxes
        self.anchors = anchors
        self.batch_size = batch_size

    def call(self, truth, predicted):
        ...
        # extract ground truth values
        truth_txy = truth[..., 1:3]                    # Ground Truth centers - use the sigmoid for classification loss!
        truth_twh = truth[..., 3:5]                    # Ground Truth shape
        truth_class_probs = K.softmax(truth[..., 5:])  # Ground Truth class(es)

        # extract predicted values from YOLO output object
        predicted_to = predicted[..., 0:1]             # predicts object is there or not - use sigmoid for confidence loss!
        predicted_txy = predicted[..., 1:3]            # predicts centers - use the sigmoid for classification loss!
        predicted_twh = predicted[..., 3:5]            # predicts shapes - direct prediction
        predicted_class_probs = K.cast(K.softmax(predicted[..., 5:]), 'float64')  # predicts class(es)
        ...
        # predicted_presence and the object/no-object weight tensors come from the omitted code above
        # (0. - predicted_presence) is the error when there is NOT an object in GT
        # (1. - predicted_presence) is the error when there IS an object in GT
        no_objects_loss = no_object_weights * K.cast(K.square(0. - predicted_presence), 'float64')
        objects_loss = has_object_weights * K.cast(K.square(1. - predicted_presence), 'float64')
        confidence_loss = objects_loss + no_objects_loss
        ...
        # classification loss for matching detections
        matching_classes = K.cast(matching_true_boxes_batch[..., 4:5], 'float64')  # GT class
        classification_weights = CLASS_LAMBDA * true_object_locations_mask_batch
        classification_loss = classification_weights * K.cast(K.square(matching_classes - predicted_class_probs), 'float64')
        ...
        # coordinates loss is only computed for true object locations
        coordinates_weights = COORDINATES_LAMBDA * true_object_locations_mask_batch
        coordinates_loss = coordinates_weights * K.cast(K.square(truth_t_boxes - predicted_t_boxes), 'float64')
        ...
        total_loss = 0.5 * (confidence_loss_sum + coordinates_loss_sum + classification_loss_sum)
        return total_loss
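If it helps to see the same pattern as a complete, runnable miniature: Keras calls any losses.Loss subclass as loss_fn(y_true, y_pred) and applies the reduction for you, so all the extra state goes through __init__ and the math goes in call. The WeightedSquaredError class and its weight argument below are just stand-ins for illustration, not part of the YOLO loss.

import tensorflow as tf
from tensorflow.keras import losses
from tensorflow.keras import backend as K

class WeightedSquaredError(losses.Loss):
    # minimal illustration of the pattern: extra state passed to __init__, the math in call()
    def __init__(self, weight, name="weighted_squared_error"):
        super().__init__(name=name)
        self.weight = weight

    def call(self, y_true, y_pred):
        return self.weight * K.square(y_true - y_pred)

toy_loss = WeightedSquaredError(weight=2.0)
print(toy_loss(tf.constant([[1.0, 2.0]]), tf.constant([[1.5, 1.5]])))  # reduces to a single scalar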
As noted, I omitted some of the matrix housekeeping fluff from Yolo_Loss. Once you have the class defined, you can instantiate it during the model definition process:
import tensorflow_addons as tfa

# define custom loss function object for the model
custom_loss_fn = Yolo_Loss(true_object_locations_mask, matching_true_boxes, use_anchors, TRAINING_BATCH_SIZE)
model = yolov2_full_detection()

# define optimizer per the YOLO9000 paper:
#   "We train the network ... for 160 epochs using stochastic gradient descent
#    with a starting learning rate of 0.1, polynomial rate decay
#    with a power of 4, weight decay of 0.0005 and momentum of 0.9"
opt = tfa.optimizers.SGDW(learning_rate=0.1, momentum=0.9, weight_decay=0.0005)

# compile model
model.compile(optimizer=opt, loss=custom_loss_fn, run_eagerly=True)
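One note on the optimizer: the paper quote mentions polynomial learning-rate decay with a power of 4, which the constant learning_rate=0.1 above does not implement. A minimal sketch of one way to approximate it with a Keras schedule, assuming x_train is the same training array used below and that the end_learning_rate floor (my own assumption, not from the paper) is acceptable:

import tensorflow as tf
import tensorflow_addons as tfa

# sketch: polynomial decay (power 4) fed to SGDW as a learning-rate schedule
total_steps = (len(x_train) // TRAINING_BATCH_SIZE) * 160   # steps over 160 epochs
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=total_steps,
    end_learning_rate=1e-4,   # assumed floor, not specified in the paper
    power=4.0)
opt = tfa.optimizers.SGDW(learning_rate=lr_schedule, momentum=0.9, weight_decay=0.0005)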
Afterwards, you train and run it just like any other model:
# train model
history = model.fit(x_train,
                    y_train,
                    batch_size=TRAINING_BATCH_SIZE,
                    epochs=160,
                    callbacks=[CustomTrainingCallbacks(),
                               tensorboard_callback])
...
import numpy as np
from PIL import Image

def predict(filename, model):
    # single-image batch; assumes the input image is already 416x416 RGB
    training_images = np.zeros((1, 416, 416, 3), dtype=float)
    image = Image.open(filename)
    training_images[0] = np.asarray(image) / 255.
    return model.predict(training_images)

# run model
predictions = predict(filename, model)
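model.predict returns the raw network output, still in the "t" parameterisation the loss works with (slot 0 is objectness, 1:3 the center offsets, 3:5 the shape, 5: the class scores). For completeness, here is a sketch of the standard YOLOv2 decode for a single anchor slot of a single grid cell; the decode_cell name and the assumption that the anchor prior is expressed in grid units are mine, not from the original code.

import numpy as np

def decode_cell(raw, col, row, anchor):
    # raw: 1-D slice [to, tx, ty, tw, th, class scores...] for one anchor slot
    # col, row: grid-cell indices; anchor: (width, height) prior in grid units
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    confidence = sigmoid(raw[0])           # objectness score
    bx = sigmoid(raw[1]) + col             # box center x, in grid units
    by = sigmoid(raw[2]) + row             # box center y, in grid units
    bw = anchor[0] * np.exp(raw[3])        # box width, in grid units
    bh = anchor[1] * np.exp(raw[4])        # box height, in grid units
    class_id = int(np.argmax(raw[5:]))     # most likely class
    return confidence, (bx, by, bw, bh), class_id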