@op(
    config_schema=train_config,
    out={"expcontext": Out(), "filelock_dict": Out()},
    required_resource_keys={"mlflow"},
)
def train(
    context: OpExecutionContext,
    exp_context: ExperimentContext,
    filelock_dict: Dict[str, FileLock],
) -> Tuple[ExperimentContext, Dict[str, FileLock]]:
    """DagsterOp that trains the model.

    Steps:
      1. Copy the op config and record it for the experiment via
         ``write_op_config`` (keys listed in ``remove_key_list`` are passed
         along to that helper).
      2. Instantiate the config objects (``instantiate`` with
         ``ConvertMode.ALL``).
      3. Initialize the training and validation datasets with the data
         description and experiment context, then save per-split statistics.
      4. Run the configured experiment initializer and fit the model with
         ``fit_generator``.

    Args:
        context: Dagster op execution context; ``context.op_config`` supplies
            the training configuration.
        exp_context: Experiment context shared across ops.
        filelock_dict: File locks from upstream ops, forwarded unchanged.

    Returns:
        ``(exp_context, filelock_dict)``, emitted as the ``expcontext`` and
        ``filelock_dict`` outputs respectively.
    """
    # The JSON round-trip yields a plain, detached deep copy of the config
    # (presumably so later mutation cannot touch Dagster's own config object).
    op_config = json.loads(json.dumps(context.op_config))
    write_op_config(
        op_config, exp_context, OpNames.OP_TRAIN.value, op_config["remove_key_list"]
    )
    instantiated_op_config = instantiate(op_config, _convert_=ConvertMode.ALL)

    data_train = instantiated_op_config["data_train"]
    data_valid = instantiated_op_config["data_validation"]
    data_description = instantiated_op_config["data_description"]

    data_train.initialize(data_description, exp_context)
    data_valid.initialize(data_description, exp_context)

    # Each split gets its own stats file. Previously the validation stats
    # were written to STATS_TRAIN and overwrote the training stats.
    # NOTE(review): confirm the enum member name for validation stats.
    save_exp_data_stats(data_train, exp_context, ExperimentFilenames.STATS_TRAIN)
    save_exp_data_stats(data_valid, exp_context, ExperimentFilenames.STATS_VALID)

    instantiated_op_config["exp_initializer"](exp_context)
    fit_generator(
        exp_context,
        instantiated_op_config["learner"],
        instantiated_op_config["model"],
        data_train,
        data_valid,
        instantiated_op_config["train_params"],
        data_description,
    )
    # The tuple order matches the declaration order of `out`.
    return exp_context, filelock_dict