Skip to content

train

train

Module for the train op.

Classes

Functions

train

train(context, exp_context, filelock_dict)

Dagster op that trains the model.

Source code in niceml/dagster/ops/train.py
def _validation_stats_filename() -> str:
    """Return the experiment-relative file name for validation data statistics.

    Uses ``ExperimentFilenames.STATS_VALIDATION`` when it is defined.
    Otherwise it derives a sibling of ``ExperimentFilenames.STATS_TRAIN`` with a
    ``_validation`` suffix on the stem. Either way, the validation statistics
    never overwrite the training statistics.
    """
    from pathlib import PurePosixPath  # local import: file-level imports untouched

    configured = getattr(ExperimentFilenames, "STATS_VALIDATION", None)
    if configured:
        return configured
    # NOTE(review): assumes STATS_TRAIN is a relative file path string;
    # prefer adding STATS_VALIDATION to ExperimentFilenames explicitly.
    train_path = PurePosixPath(str(ExperimentFilenames.STATS_TRAIN))
    return str(
        train_path.with_name(f"{train_path.stem}_validation{train_path.suffix}")
    )


@op(
    config_schema=train_config,
    out={"expcontext": Out(), "filelock_dict": Out()},
    required_resource_keys={"mlflow"},
)
def train(
    context: OpExecutionContext,
    exp_context: ExperimentContext,
    filelock_dict: Dict[str, FileLock],
) -> Tuple[ExperimentContext, Dict[str, FileLock]]:
    """DagsterOp that trains the model.

    Steps:
        1. Write the op config to the experiment via ``write_op_config``,
           passing the config's ``remove_key_list`` entry, then instantiate
           the config (``_convert_=ConvertMode.ALL``).
        2. Initialize the training and validation data with the data
           description and experiment context, and save statistics for each
           dataset to its own file.
        3. Run the configured ``exp_initializer`` on the experiment context.
        4. Fit the model with ``fit_generator`` using the configured learner,
           model, data, and train parameters.

    The ``mlflow`` resource is declared as required, but this body does not
    access it directly.

    Args:
        context: Dagster execution context; its ``op_config`` supplies the
            data, learner, model, and train parameter configuration.
        exp_context: Context of the experiment being trained.
        filelock_dict: Mapping of file locks; it is not used here and is
            returned unchanged so downstream ops receive it.

    Returns:
        Tuple of the same ``exp_context`` and ``filelock_dict`` objects.
    """
    # Round-trip through JSON to get a plain dict/list copy of the op config.
    op_config = json.loads(json.dumps(context.op_config))
    write_op_config(
        op_config, exp_context, OpNames.OP_TRAIN.value, op_config["remove_key_list"]
    )
    instantiated_op_config = instantiate(op_config, _convert_=ConvertMode.ALL)

    data_train = instantiated_op_config["data_train"]
    data_valid = instantiated_op_config["data_validation"]
    data_description = instantiated_op_config["data_description"]

    data_train.initialize(data_description, exp_context)
    data_valid.initialize(data_description, exp_context)

    # Each dataset gets its own stats file. Previously both calls used
    # STATS_TRAIN, so the validation stats overwrote the training stats.
    save_exp_data_stats(data_train, exp_context, ExperimentFilenames.STATS_TRAIN)
    save_exp_data_stats(data_valid, exp_context, _validation_stats_filename())

    instantiated_op_config["exp_initializer"](exp_context)

    fit_generator(
        exp_context,
        instantiated_op_config["learner"],
        instantiated_op_config["model"],
        data_train,
        data_valid,
        instantiated_op_config["train_params"],
        data_description,
    )
    return exp_context, filelock_dict