-
Notifications
You must be signed in to change notification settings - Fork 32
Open
Description
I'm trying to add a custom metric that calculates mean AP during training, but it's not working: the model neither prints the metric nor calls its `update_state`.
import tensorflow as tf
from xcenternet.model.constants import MAX_OBJECTS, IOU_THRESHOD, MAP_SCORE_THRESHOLD, LABELS
from xcenternet.model.evaluation.mean_average_precision import MAP
class MeanAP(tf.keras.metrics.Metric):
    """Keras metric tracking a batch-size-weighted running mean of mAP.

    Every ``update_state`` call decodes the raw predictions, feeds them to a
    ``MAP`` helper, and adds ``helper.result() * batch_size`` to a running
    sum. ``result`` returns that sum divided by the number of samples seen.

    Args:
        name: Metric name shown in the Keras logs.
        decode_fn: Callable invoked as
            ``decode_fn(predictions, relative=False, k=MAX_OBJECTS)`` to turn
            raw network outputs into detections. If omitted, a ``decode``
            method defined by a subclass is used instead.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    NOTE(review): the ``MAP`` helper is project code that is not visible
    here. If it is not graph-compatible, the model presumably has to be
    compiled with ``run_eagerly=True`` for this metric to work -- confirm.
    """

    def __init__(self, name="custom_mean_ap", *, decode_fn=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.iou_threshold = IOU_THRESHOD
        self.score_threshold = MAP_SCORE_THRESHOLD
        self.decode_fn = decode_fn
        self.mean_average_precision = self._new_accumulator()
        self.map = self.add_weight(name="ctp", initializer="zeros")
        # Kept as a metric weight instead of a Python int: a plain attribute
        # is only mutated while the function is traced, not on every step of
        # a compiled train/test function.
        self.batch_count = self.add_weight(name="batch_count", initializer="zeros")

    def _new_accumulator(self):
        """Build a fresh ``MAP`` helper with this metric's thresholds."""
        return MAP(
            LABELS,
            iou_threshold=self.iou_threshold,
            score_threshold=self.score_threshold,
        )

    def update_state(self, training_data, predictions, sample_weight=None):
        """Accumulate the mAP contribution of one batch.

        Args:
            training_data: Mapping providing the ``"mask"``, ``"bboxes"`` and
                ``"labels"`` entries for the batch.
            predictions: Raw model outputs, passed to the decode function.
            sample_weight: Accepted because Keras passes it as a keyword
                argument to ``update_state``; it is not used.

        Raises:
            AttributeError: If no decode function is available.
        """
        mask = training_data["mask"]
        bboxes = training_data["bboxes"]
        labels = training_data["labels"]

        decode = self.decode_fn
        if decode is None:
            # Fall back to a ``decode`` method supplied by a subclass.
            decode = getattr(self, "decode", None)
        if decode is None:
            raise AttributeError(
                "MeanAP has no decode function: pass decode_fn=... "
                "or define decode() in a subclass."
            )

        decoded = decode(predictions, relative=False, k=MAX_OBJECTS)
        self.mean_average_precision.update_state_batch(decoded, bboxes, labels, mask)
        # NOTE(review): assumes MAP.result() returns a scalar -- confirm.
        result = self.mean_average_precision.result()

        # tf.shape works for symbolic tensors too, where len() raises.
        batch_size = tf.cast(tf.shape(labels)[0], self.map.dtype)
        self.map.assign_add(tf.cast(result, self.map.dtype) * batch_size)
        self.batch_count.assign_add(batch_size)

    def result(self):
        """Return the weighted mean mAP, or 0 when no samples were seen."""
        return tf.math.divide_no_nan(self.map, self.batch_count)

    def reset_state(self):
        """Clear the running sum, the sample count and the MAP helper."""
        self.batch_count.assign(0.0)
        self.map.assign(0.0)
        # Presumably MAP accumulates across update_state_batch calls;
        # rebuilding it keeps one epoch from leaking into the next.
        self.mean_average_precision = self._new_accumulator()


# I added the metric to compile
model.compile(optimizer=optimizer, loss=model.get_loss_funcs(), metrics=[MeanAP()])
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels