objdetnetdatalogger

Module containing the ObjDetNetDataLogger

Classes

ObjDetNetDataLogger

ObjDetNetDataLogger(max_log=5)

Bases: NetDataLogger

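NetDataLogger for object detection. It saves input images with their positive anchor boxes drawn on them, up to max_log images per logger instance.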

Source code in niceml/data/netdataloggers/objdetnetdatalogger.py
def __init__(self, max_log: int = 5):
    super().__init__()
    self.max_log: int = max_log
    self.anchor_generator: AnchorGenerator = AnchorGenerator()
    self.log_count: int = 0
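The constructor sets up a logging budget: max_log caps the total number of images written and log_count records how many have been written so far. Because log_count is not reset in the code shown here, the budget is shared across all calls to log_data. The sketch below reproduces only that counting logic. LoggingBudget and log_batch are hypothetical names used for illustration, and appending an index stands in for the real image-drawing call.

```python
from typing import List


class LoggingBudget:
    """Hypothetical stand-in that mimics the max_log / log_count logic of log_data."""

    def __init__(self, max_log: int = 5):
        self.max_log = max_log
        self.log_count = 0

    def log_batch(self, batch_size: int) -> List[int]:
        """Return the indices of the samples in this batch that would be logged."""
        logged: List[int] = []
        if self.log_count >= self.max_log:
            return logged  # budget already exhausted: skip the whole batch
        for idx in range(batch_size):
            logged.append(idx)  # stands in for drawing and saving one image
            self.log_count += 1
            if self.log_count >= self.max_log:
                break  # stop mid-batch once the budget is used up
        return logged


budget = LoggingBudget(max_log=5)
print(budget.log_batch(3))  # [0, 1, 2]
print(budget.log_batch(3))  # [0, 1]
print(budget.log_batch(3))  # []
```

With max_log=5 and batches of three, the second batch is cut off after two images and every later batch is skipped.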
Functions
log_data
log_data(net_inputs, net_targets, data_info_list)

Saves up to self.max_log images in total, counted across all calls, into self.output_path. Each saved input image has its positively marked anchor boxes drawn onto it.

Parameters:

  • net_inputs (ndarray) –

    Batch of input images as np.ndarray

  • net_targets (ndarray) –

    Targets as np.ndarray with one row per anchor for each image: the four encoded anchor box coordinates (columns 0 to 3), the mask value (column 4), followed by the one-hot-encoded class vector

  • data_info_list (List[ObjDetDataInfo]) –

    Data information associated with each input/target pair, including extended information

Returns:

  • None
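The selection of "positively marked" anchors reduces to a boolean mask on column 4 of each image's target array. The following sketch assumes POSITIVE_MASK_VALUE is 1.0 and that the one-hot class vector starts at column 5. Both are assumptions for illustration: the real constant is defined in niceml, and the private helper _target_to_label is not shown on this page.

```python
import numpy as np

# Assumed value; the real POSITIVE_MASK_VALUE constant is defined in niceml.
POSITIVE_MASK_VALUE = 1.0


def positive_rows(net_target: np.ndarray) -> np.ndarray:
    """Keep only the anchor rows whose mask column (index 4) marks them positive."""
    return net_target[net_target[:, 4] == POSITIVE_MASK_VALUE]


def class_indices(positive: np.ndarray) -> np.ndarray:
    """Recover class indices from the one-hot vector assumed to start at column 5."""
    return np.argmax(positive[:, 5:], axis=1)


# One image, three anchors, two classes: 4 box values + mask + 2-class one-hot.
net_target = np.array(
    [
        [0.1, 0.2, 0.0, 0.0, 1.0, 0.0, 1.0],   # positive anchor, class 1
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],   # negative anchor
        [-0.3, 0.1, 0.2, 0.1, 1.0, 1.0, 0.0],  # positive anchor, class 0
    ]
)
positive = positive_rows(net_target)
print(positive.shape)           # (2, 7)
print(class_indices(positive))  # [1 0]
```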

Source code in niceml/data/netdataloggers/objdetnetdatalogger.py
def log_data(
    self,
    net_inputs: np.ndarray,
    net_targets: np.ndarray,
    data_info_list: List[ObjDetDataInfo],
):
    """
    Saves as many images with corresponding anchor boxes as defined in `self.max_log`.
    The images are saved into `self.output_path`. For each input image,
    the associated positively marked anchor boxes are added to the image.

    Args:
        net_inputs: Input images as `np.ndarray`
        net_targets: Targets as `np.ndarray` with the coded coordinates
            of the anchor boxes, the mask value and the one-hot-encoded class vector
        data_info_list: Associated data information of input and destination
            with extended information

    Returns:
        None
    """
    if self.log_count >= self.max_log:
        return

    output_data_description = check_instance(
        self.data_description, OutputObjDetDataDescription
    )

    anchors = self.anchor_generator.generate_anchors(
        data_description=output_data_description
    )

    if len(net_inputs) != len(net_targets):
        raise ValueError(
            f"Mismatching lengths of net_inputs "
            f"and net_targets ({len(net_inputs)}, {len(net_targets)})"
        )

    for net_input, net_target, data_info in zip(
        net_inputs, net_targets, data_info_list
    ):
        decoded_bboxes = [
            anchor.decode(
                predicted_values=target[:4],
                box_variance=output_data_description.get_box_variance(),
            ).get_absolute_ullr()
            for anchor, target in zip(anchors, net_target)
        ]
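        # Work on a copy so the caller's encoded targets are not overwritten
        net_target = net_target.copy()
        net_target[:, :4] = decoded_bboxes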
        positive_targets = net_target[net_target[:, 4] == POSITIVE_MASK_VALUE]
        labels = [
            self._target_to_label(target=target) for target in positive_targets
        ]
        self._draw_image(
            image=net_input, instance_labels=labels, data_info=data_info
        )

        self.log_count += 1

        if self.log_count >= self.max_log:
            break
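The loop above turns each anchor's encoded target into absolute upper-left/lower-right coordinates via anchor.decode(predicted_values=..., box_variance=...). The exact formula is not shown on this page. The sketch below assumes the common RetinaNet-style scheme, in which center offsets are scaled by the anchor size and width/height are log-scaled, all multiplied by a per-coordinate box variance. decode_anchor_offsets is a hypothetical helper, and the variance values are only an example.

```python
import numpy as np


def decode_anchor_offsets(
    anchor_cxcywh: np.ndarray,
    encoded: np.ndarray,
    box_variance: np.ndarray,
) -> np.ndarray:
    """Decode RetinaNet-style offsets into absolute (x1, y1, x2, y2) corners.

    Assumed encoding (not confirmed for niceml):
      cx = cx_a + tx * var_x * w_a,  cy = cy_a + ty * var_y * h_a
      w  = w_a * exp(tw * var_w),    h  = h_a * exp(th * var_h)
    """
    scaled = encoded * box_variance
    cx_a, cy_a, w_a, h_a = anchor_cxcywh
    cx = cx_a + scaled[0] * w_a
    cy = cy_a + scaled[1] * h_a
    w = w_a * np.exp(scaled[2])
    h = h_a * np.exp(scaled[3])
    # Convert center/size to upper-left / lower-right corners
    return np.array([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])


anchor = np.array([50.0, 50.0, 20.0, 10.0])  # cx, cy, w, h
variance = np.array([0.1, 0.1, 0.2, 0.2])
# Zero offsets decode back to the anchor itself
print(decode_anchor_offsets(anchor, np.zeros(4), variance))  # [40. 45. 60. 55.]
```

Decoding every anchor and then masking, as the source does, is simple. If the logging cost ever mattered, an alternative would be to filter the positive rows first and decode only those.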

Functions