KerasLearner(
model_compiler,
callback_initializer,
model_load_custom_objects,
)
Bases: Learner
Default learner for Keras models.
Constructor for KerasLearner
Args:
model_compiler: model compiler for keras
callback_initializer: callback initializer for keras
model_load_custom_objects: custom objects to load the model
Source code in niceml/dlframeworks/keras/learners/keraslearner.py
def __init__(
    self,
    model_compiler: ModelCompiler,
    callback_initializer: CallbackInitializer,
    model_load_custom_objects: ModelCustomLoadObjects,
):
    """
    Store the collaborators this learner is configured with.

    Args:
        model_compiler: model compiler for keras
        callback_initializer: callback initializer for keras
        model_load_custom_objects: custom objects to load the model
    """
    # The parameter annotations above already carry the types, so the
    # attributes are assigned without repeating them.
    self.model_compiler = model_compiler
    self.callback_initializer = callback_initializer
    self.model_load_custom_objects = model_load_custom_objects
Functions
run_training
run_training(
exp_context,
model_factory,
train_set,
validation_set,
train_params,
data_description,
)
Runs the training of the compiled Keras model.
Source code in niceml/dlframeworks/keras/learners/keraslearner.py
def run_training(  # noqa: PLR0913
    self,
    exp_context: ExperimentContext,
    model_factory: ModelFactory,
    train_set: Dataset,
    validation_set: Dataset,
    train_params: TrainParams,
    data_description: DataDescription,
):
    """
    Compile the model and fit it on the given datasets.

    Args:
        exp_context: passed to the callback initializer to build the callbacks
        model_factory: passed to the model compiler together with
            ``data_description``
        train_set: data the model is fitted on
        validation_set: data used as ``validation_data`` during fitting
        train_params: supplies ``epochs``, ``steps_per_epoch`` and
            ``validation_steps``
        data_description: passed to the model compiler together with
            ``model_factory``

    Returns:
        Whatever the model's ``fit`` call returns.
    """
    mlflow.keras.autolog()
    fit_callbacks = self.callback_initializer(exp_context)
    bundle: ModelBundle = self.model_compiler.compile(model_factory, data_description)
    keras_model: Model = bundle.model

    # A configured step count is capped at the length of its dataset;
    # ``None`` is forwarded to ``fit`` unchanged.
    validation_steps = (
        None
        if train_params.validation_steps is None
        else min(train_params.validation_steps, len(validation_set))
    )
    steps_per_epoch = (
        None
        if train_params.steps_per_epoch is None
        else min(train_params.steps_per_epoch, len(train_set))
    )

    # Fitting runs inside the custom-object scope supplied by the learner.
    custom_objects = self.model_load_custom_objects()
    with tf.keras.utils.custom_object_scope(custom_objects):
        return keras_model.fit(
            train_set,
            epochs=train_params.epochs,
            validation_data=validation_set,
            callbacks=fit_callbacks,
            validation_steps=validation_steps,
            steps_per_epoch=steps_per_epoch,
        )