Skip to content

genericdataset

genericdataset

Module containing the generic dataset implementation.

Classes

GenericDataset

GenericDataset(
    set_name,
    datainfo_listing,
    data_loader,
    target_transformer,
    input_transformer,
    shuffle,
    data_shuffler=None,
    stats_generator=None,
    augmentator=None,
    net_data_logger=None,
)

Bases: Dataset, ABC

Generic dataset implementation. This is a flexible dataset for multiple use cases: it can be used for classification, segmentation, object detection, etc. For specific frameworks there are subclasses of this class, e.g. KerasGenericDataset.

Constructor of the GenericDataset. Args: set_name: name of the subset, e.g. train; datainfo_listing: how to list the data; data_loader: how to load the data; target_transformer: how to transform the target of the model (e.g. one-hot encoding); input_transformer: how to transform the input of the model; shuffle: whether the data should be shuffled; data_shuffler: how to shuffle the data (e.g. random, sampled); stats_generator: writes dataset stats; augmentator: augments the data on the fly; net_data_logger: stores the data in the way it is presented to the model.

Source code in niceml/data/datasets/genericdataset.py
def __init__(  # noqa: PLR0913
    self,
    set_name: str,
    datainfo_listing: DataInfoListing,
    data_loader: DataLoader,
    target_transformer: NetTargetTransformer,
    input_transformer: NetInputTransformer,
    shuffle: bool,
    data_shuffler: Optional[DataShuffler] = None,
    stats_generator: Optional[DataStatsGenerator] = None,
    augmentator: Optional[AugmentationProcessor] = None,
    net_data_logger: Optional[NetDataLogger] = None,
):
    """
    Constructor of the GenericDataset

    Only stores the configured components. ``initialize`` must be called
    before the dataset is used, because that is where ``data_info_list``,
    ``index_list`` and ``data_info_dict`` are created.

    Args:
        set_name: Name of the subset e.g. train
        datainfo_listing: How to list the data
        data_loader: How to load the data
        target_transformer: How to transform the
            target of the model (e.g. one-hot encoding)
        input_transformer: How to transform the input of the model
        shuffle: bool if the data should be shuffled
        data_shuffler: A way of shuffling the data (e.g. random, sampled).
            A ``DefaultDataShuffler`` is used when None.
        stats_generator: Write dataset stats. A ``DefaultStatsGenerator``
            is used when None.
        augmentator: Augment the data on the fly (optional)
        net_data_logger: Stores the data in the way it is presented to the
            model (optional)
    """
    super().__init__()
    self.net_data_logger = net_data_logger
    self.set_name = set_name
    self.datainfo_listing: DataInfoListing = datainfo_listing
    self.data_loader: DataLoader = data_loader
    self.shuffle = shuffle
    # Explicit None checks: ``x or Default()`` would also discard a
    # caller-supplied component that happens to evaluate as falsy
    # (e.g. one that defines ``__len__`` or ``__bool__``).
    self.data_shuffler: DataShuffler = (
        data_shuffler if data_shuffler is not None else DefaultDataShuffler()
    )
    self.target_transformer: NetTargetTransformer = target_transformer
    self.input_transformer: NetInputTransformer = input_transformer
    self.augmentator: Optional[AugmentationProcessor] = augmentator

    self.data_stats_generator: DataStatsGenerator = (
        stats_generator
        if stats_generator is not None
        else DefaultStatsGenerator()
    )
Functions
__getitem__
__getitem__(item_index)

Returns the network inputs and targets of the item at the given index

Source code in niceml/data/datasets/genericdataset.py
def __getitem__(self, item_index: int):
    """Load and transform the item at position ``item_index``.

    The position is first mapped through ``index_list`` (which
    ``on_epoch_end`` may reorder) to the underlying data info. The loaded
    item is augmented if an augmentator is configured, then converted into
    network inputs and targets as a one-element batch. When a net data
    logger is configured, the result is logged together with its data info.

    Returns:
        Tuple ``(net_inputs, net_targets)``.
    """
    position = self.index_list[item_index]
    data_info = self.data_info_list[position]
    item = self.data_loader.load_data(data_info)
    item = item if self.augmentator is None else self.augmentator(item)

    inputs = self.input_transformer.get_net_inputs([item])
    targets = self.target_transformer.get_net_targets([item])

    logger = self.net_data_logger
    if logger is not None:
        logger.log_data(
            net_inputs=inputs,
            net_targets=targets,
            data_info_list=[data_info],
        )
    return inputs, targets
__len__
__len__()

Returns the number of items per epoch

Source code in niceml/data/datasets/genericdataset.py
def __len__(self):
    """Return the number of items served per epoch.

    This is ``get_items_per_epoch()``, i.e. the length of ``index_list`` and
    therefore the number of valid positions for ``__getitem__``. Each
    position yields a single item, so this is an item count, not a batch
    count.
    """
    return self.get_items_per_epoch()
get_data_by_key
get_data_by_key(data_key)

Returns the data by the key (identifier of the data)

Source code in niceml/data/datasets/genericdataset.py
def get_data_by_key(self, data_key):
    """Load the raw item identified by ``data_key``.

    The key is looked up in ``data_info_dict`` (built in ``initialize`` from
    each data info's ``get_identifier()``); an unknown key raises
    ``KeyError``. The item is returned exactly as the data loader produces
    it, with no augmentation or net transformation applied.
    """
    info = self.data_info_dict[data_key]
    loaded_item = self.data_loader.load_data(info)
    return loaded_item
get_dataset_stats
get_dataset_stats()

Returns the dataset stats

Source code in niceml/data/datasets/genericdataset.py
def get_dataset_stats(self) -> dict:
    """Compute statistics for this dataset.

    Delegates to the configured stats generator, handing it the full data
    info list and the current (possibly shuffled) index list.
    """
    generator = self.data_stats_generator
    return generator.generate_stats(self.data_info_list, self.index_list)
get_item_count
get_item_count()

Returns the current count of items in the dataset

Source code in niceml/data/datasets/genericdataset.py
def get_item_count(self) -> int:
    """Return how many data infos are currently listed for this dataset.

    This counts ``data_info_list`` directly. ``get_items_per_epoch``
    instead counts the index list produced by the data shuffler.
    """
    item_total = len(self.data_info_list)
    return item_total
get_items_per_epoch
get_items_per_epoch()

Returns the items per epoch

Source code in niceml/data/datasets/genericdataset.py
def get_items_per_epoch(self) -> int:
    """Return the number of item positions served in one epoch.

    Equal to the length of ``index_list``, which ``initialize`` creates and
    ``on_epoch_end`` may replace.
    """
    epoch_indices = self.index_list
    return len(epoch_indices)
get_set_name
get_set_name()

Returns the name of the set e.g. train

Source code in niceml/data/datasets/genericdataset.py
def get_set_name(self) -> str:
    """Return the subset name given at construction time (e.g. ``train``)."""
    name: str = self.set_name
    return name
initialize
initialize(data_description, exp_context)

Initializes the dataset with the data description and context

Source code in niceml/data/datasets/genericdataset.py
def initialize(
    self, data_description: DataDescription, exp_context: ExperimentContext
):
    """Prepare the dataset for use with ``data_description``.

    Initializes the loader, shuffler, target transformer and input
    transformer (in that order) with the description. It then lists the
    data infos, builds the identity index list and the identifier-to-info
    lookup dict, and initializes the net data logger if one is set.
    Finally it calls ``on_epoch_end`` so the first epoch's order is applied.

    Args:
        data_description: Description passed to every component.
        exp_context: Experiment context; only forwarded to the net data
            logger.
    """
    self.data_description = data_description

    # Order matters only insofar as it is kept identical to the original:
    # loader, shuffler, target transformer, input transformer.
    for component in (
        self.data_loader,
        self.data_shuffler,
        self.target_transformer,
        self.input_transformer,
    ):
        component.initialize(data_description)

    infos = self.datainfo_listing.list(data_description)
    self.data_info_list: List[DataInfo] = infos
    self.index_list: List[int] = list(range(len(infos)))
    self.data_info_dict: Dict[str, DataInfo] = {
        info.get_identifier(): info for info in infos
    }

    logger = self.net_data_logger
    if logger is not None:
        logger.initialize(self.data_description, exp_context, self.set_name)

    # Apply the first (optional) shuffle so index_list is ready for epoch one.
    self.on_epoch_end()
on_epoch_end
on_epoch_end()

Shuffles the item order if shuffling is enabled

Source code in niceml/data/datasets/genericdataset.py
def on_epoch_end(self):
    """Reorder items for the next epoch when shuffling is enabled.

    With ``shuffle`` set, ``index_list`` is replaced by whatever the data
    shuffler returns for the current data info list; otherwise the existing
    order is left untouched.
    """
    if not self.shuffle:
        return
    self.index_list = self.data_shuffler.shuffle(self.data_info_list)