ObjDetPredictionHandler(
prediction_filter,
prediction_prefix="pred",
pred_identifier="image_location",
detection_idx_col=DETECTION_INDEX_COLUMN_NAME,
apply_sigmoid=True,
)
Bases: PredictionHandler
Prediction handler for object detection predictions (BoundingBox, class prediction)
Initializes the ObjDetPredictionHandler
Source code in niceml/mlcomponents/predictionhandlers/objdetpredictionhandler.py
def __init__(  # noqa: PLR0913
    self,
    prediction_filter: PredictionFilter,
    prediction_prefix: str = "pred",
    pred_identifier: str = "image_location",
    detection_idx_col: str = DETECTION_INDEX_COLUMN_NAME,
    apply_sigmoid: bool = True,
):
    """Store the handler configuration and create the empty prediction state.

    Args:
        prediction_filter: Filter object kept on the handler for later use.
        prediction_prefix: Prefix stored for building prediction column names.
        pred_identifier: First entry of `data_columns`.
        detection_idx_col: Second entry of `data_columns`.
        apply_sigmoid: Flag stored on the handler as `self.apply_sigmoid`.
    """
    super().__init__()

    # Configuration handed in by the caller.
    self.prediction_filter = prediction_filter
    self.prediction_prefix = prediction_prefix
    self.pred_identifier = pred_identifier
    self.detection_idx_col = detection_idx_col
    self.apply_sigmoid = apply_sigmoid

    # Column layout: identifier, detection index, then one column per
    # BoundingBox field (iterating the dict yields its keys in order).
    bounding_box_fields = asdict(BoundingBox(0, 0, 0, 0))
    self.data_columns = [pred_identifier, detection_idx_col, *bounding_box_fields]

    # State that is populated later; nothing is collected yet.
    self.data = None
    self.anchor_generator = AnchorGenerator()
    self.anchors = None
    self.anchor_array = None
|
Functions
__enter__
Initializes `self.data` after the context is entered.
Source code in niceml/mlcomponents/predictionhandlers/objdetpredictionhandler.py
def __enter__(self):
    """Reset `self.data` and register the per-class prediction columns.

    `self.data` is set to a fresh empty list. If the data description is an
    `OutputObjDetDataDescription`, one column name of the form
    `<prediction_prefix>_<class index, zero-padded to 4 digits>` is added to
    `self.data_columns` for every output class.

    Returns:
        The handler itself, so it can be bound with `with ... as handler`.
    """
    self.data = []
    if isinstance(self.data_description, OutputObjDetDataDescription):
        for class_idx in range(self.data_description.get_output_class_count()):
            column_name = f"{self.prediction_prefix}_{class_idx:04d}"
            # Guard against re-entering the context: `data_columns` lives on
            # the instance, so appending unconditionally would add the same
            # class columns again on every entry.
            if column_name not in self.data_columns:
                self.data_columns.append(column_name)
    return self
|
__exit__
__exit__(exc_type, exc_value, exc_traceback)
Saves the data in `self.data` as a parquet file.
Source code in niceml/mlcomponents/predictionhandlers/objdetpredictionhandler.py
def __exit__(self, exc_type, exc_value, exc_traceback):
    """Write the collected predictions to a parquet file on context exit.

    The rows in `self.data` are turned into a DataFrame and written through
    `self.exp_context` to `<PREDICTION_FOLDER>/<filename>.parq`. If
    `self.data` is `None`, only a warning is logged and nothing is written.

    Args:
        exc_type: Exception type passed by the `with` statement (unused).
        exc_value: Exception instance passed by the `with` statement (unused).
        exc_traceback: Traceback passed by the `with` statement (unused).

    Returns:
        None, so an exception raised inside the `with` body is not suppressed.
    """
    if self.data is not None:
        predictions: pd.DataFrame = pd.DataFrame(self.data)
        self.exp_context.write_parquet(
            predictions,
            join(ExperimentFilenames.PREDICTION_FOLDER, self.filename + ".parq"),
        )
        return

    # Nothing was collected: report it instead of writing a file.
    logger = logging.getLogger(__name__)
    logger.warning("PredictionHandler: %s has no data to write!", self.filename)
|
add_prediction
add_prediction(data_info_list, prediction_batch)
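Gets the results of an object detection model, decodes the box predictions, optionally applies a sigmoid to the class predictions, filters them and records every remaining detection.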
Source code in niceml/mlcomponents/predictionhandlers/objdetpredictionhandler.py
def add_prediction(
    self, data_info_list: List[ObjDetDataInfo], prediction_batch: np.ndarray
):
    """Decode, filter and record the model output of one batch.

    For each (prediction, data info) pair the box predictions are decoded, a
    sigmoid is applied to the class predictions when `self.apply_sigmoid` is
    set, and the result is passed through `self.prediction_filter`. Every
    prediction that survives the filter is handed to `_add_data` with its
    position in the filtered result as detection index. If nothing survives,
    a single all-zero prediction is handed over with
    `NO_PREDICTIONS_DETECTION_VALUE` as detection index.

    Args:
        data_info_list: Data infos belonging to the entries of
            `prediction_batch`; the two are iterated in lockstep.
        prediction_batch: Raw model output, one entry per data info.
    """
    output_dd: OutputObjDetDataDescription = check_instance(
        self.data_description, OutputObjDetDataDescription
    )

    for raw_prediction, data_info in zip(prediction_batch, data_info_list):
        boxes = self._decode_box_predictions(box_predictions=raw_prediction)
        if self.apply_sigmoid:
            boxes = apply_sigmoid_on_cls_predictions(
                boxes,
                output_dd.get_coordinates_count(),
            )

        kept_predictions = self.prediction_filter.filter(boxes)

        if len(kept_predictions) == 0:
            # Nothing passed the filter: still hand over one placeholder
            # prediction (coordinates + class scores, all zero).
            placeholder_size = (
                output_dd.get_coordinates_count()
                + output_dd.get_output_class_count()
            )
            placeholder = np.zeros((placeholder_size,))
            self._add_data(
                data_info.get_identifier(),
                prediction=placeholder,
                detection_index=NO_PREDICTIONS_DETECTION_VALUE,
            )
            continue

        for detection_idx, kept_prediction in enumerate(kept_predictions):
            self._add_data(
                identifier=data_info.get_identifier(),
                prediction=kept_prediction,
                detection_index=detection_idx,
            )
|
initialize
Initializes the prediction handler
Source code in niceml/mlcomponents/predictionhandlers/objdetpredictionhandler.py
def initialize(self):
    """Prepare the anchors and the prediction filter for the data description.

    Generates the anchor boxes for `self.data_description`, stores them in
    `self.anchors`, stores their `get_absolute_xywh()` values as a numpy
    array in `self.anchor_array`, and initializes `self.prediction_filter`
    with the same data description.
    """
    description = self.data_description

    self.anchors: List[BoundingBox] = self.anchor_generator.generate_anchors(
        data_description=description
    )

    # One row per anchor, built from each box's absolute xywh values.
    anchor_rows = [anchor.get_absolute_xywh() for anchor in self.anchors]
    self.anchor_array = np.array(anchor_rows)

    self.prediction_filter.initialize(data_description=description)
|