Skip to content

keraslearner

keraslearner

Module for default learner

Classes

KerasLearner

KerasLearner(
    model_compiler,
    callback_initializer,
    model_load_custom_objects,
)

Bases: Learner

Default learner for Keras models

Constructor for KerasLearner. Args: model_compiler: model compiler for keras; callback_initializer: callback initializer for keras; model_load_custom_objects: custom objects to load the model.

Source code in niceml/dlframeworks/keras/learners/keraslearner.py
def __init__(
    self,
    model_compiler: ModelCompiler,
    callback_initializer: CallbackInitializer,
    model_load_custom_objects: ModelCustomLoadObjects,
):
    """
    Set up a KerasLearner with its three collaborators.

    Args:
        model_compiler: object whose ``compile`` method is used in
            ``run_training`` to obtain the model bundle
        callback_initializer: callable invoked with the experiment
            context to obtain the callbacks passed to ``fit``
        model_load_custom_objects: callable whose result is handed to
            the keras custom object scope during training
    """
    # The arguments are stored unchanged; nothing is invoked at construction.
    self.model_load_custom_objects: ModelCustomLoadObjects = model_load_custom_objects
    self.callback_initializer: CallbackInitializer = callback_initializer
    self.model_compiler: ModelCompiler = model_compiler
Functions
run_training
run_training(
    exp_context,
    model_factory,
    train_set,
    validation_set,
    train_params,
    data_description,
)

runs the training

Source code in niceml/dlframeworks/keras/learners/keraslearner.py
def run_training(  # noqa: PLR0913
    self,
    exp_context: ExperimentContext,
    model_factory: ModelFactory,
    train_set: Dataset,
    validation_set: Dataset,
    train_params: TrainParams,
    data_description: DataDescription,
):
    """
    Compile the model and fit it on the given datasets.

    Args:
        exp_context: experiment context passed to the callback initializer
        model_factory: factory handed to the model compiler
        train_set: dataset used for fitting; its length caps
            ``steps_per_epoch``
        validation_set: dataset used for validation; its length caps
            ``validation_steps``
        train_params: supplies ``epochs``, ``steps_per_epoch`` and
            ``validation_steps``
        data_description: description handed to the model compiler

    Returns:
        Whatever the model's ``fit`` call returns.
    """
    mlflow.keras.autolog()
    fit_callbacks = self.callback_initializer(exp_context)
    bundle: ModelBundle = self.model_compiler.compile(model_factory, data_description)
    keras_model: Model = bundle.model

    # A configured step count is never allowed to exceed the dataset length;
    # ``None`` is passed through unchanged so ``fit`` applies its own default.
    val_step_cap = (
        None
        if train_params.validation_steps is None
        else min(train_params.validation_steps, len(validation_set))
    )
    train_step_cap = (
        None
        if train_params.steps_per_epoch is None
        else min(train_params.steps_per_epoch, len(train_set))
    )

    # Fitting runs inside the custom object scope built from the configured
    # custom objects.
    with tf.keras.utils.custom_object_scope(self.model_load_custom_objects()):
        return keras_model.fit(
            train_set,
            epochs=train_params.epochs,
            validation_data=validation_set,
            callbacks=fit_callbacks,
            validation_steps=val_step_cap,
            steps_per_epoch=train_step_cap,
        )