@op(
    config_schema={
        "prediction_handler": HydraInitField(PredictionHandler),
        "datasets": HydraMapField(Dataset),
        "prediction_steps": Field(
            Noneable(int),
            default_value=None,
            description=(
                "If None the whole datasets are processed. "
                "Otherwise only `prediction_steps` are evaluated."
            ),
        ),
        "model_loader": HydraInitField(ModelLoader),
        "prediction_function": HydraInitField(PredictionFunction),
        "remove_key_list": Field(
            list,
            default_value=DEFAULT_REMOVE_CONFIG_KEYS,
            description="These key are removed from any config recursively before it is saved.",
        ),
    },
    # NOTE: the insertion order of these outputs must match the order of the
    # returned tuple below, since the tuple is mapped onto them positionally.
    out={"expcontext": Out(), "datasets": Out(), "filelock_dict": Out()},
    required_resource_keys={"mlflow"},
)
def prediction(
    context: OpExecutionContext,
    exp_context: ExperimentContext,
    filelock_dict: Dict[str, FileLock],
) -> Tuple[ExperimentContext, Dict[str, Dataset], Dict[str, FileLock]]:
    """Run the stored model over every configured dataset.

    Steps:
      1. Persist this op's config via ``write_op_config`` (passing the
         configured ``remove_key_list``).
      2. Instantiate the Hydra-style config objects (loader, handler,
         prediction function, datasets).
      3. Load the model from the experiment's file-system location.
      4. For each dataset: initialize it, save its statistics, and call
         ``predict_dataset`` with the dataset key as the output filename.

    Args:
        context: Dagster op execution context (provides config and logging).
        exp_context: Experiment context; forwarded unchanged as an output.
        filelock_dict: File locks; forwarded unchanged as an output.

    Returns:
        ``(exp_context, datasets, filelock_dict)`` where ``datasets`` maps
        each configured dataset key to its instantiated ``Dataset``.
    """
    # A JSON round-trip yields a plain, independent copy of the op config.
    raw_config = json.loads(json.dumps(context.op_config))
    write_op_config(
        raw_config,
        exp_context,
        OpNames.OP_PREDICTION.value,
        raw_config["remove_key_list"],
    )
    resolved = instantiate(raw_config, _convert_=ConvertMode.ALL)

    description: DataDescription = exp_context.instantiate_datadescription_from_yaml()
    experiment_data: ExperimentData = create_expdata_from_expcontext(exp_context)
    model_rel_path: str = experiment_data.get_model_path(relative_path=True)

    loader: ModelLoader = resolved["model_loader"]
    with open_location(exp_context.fs_config) as (file_system, root_dir):
        full_model_path = join_fs_path(file_system, root_dir, model_rel_path)
        model = loader(full_model_path, file_system=file_system)

    prediction_sets: Dict[str, Dataset] = resolved["datasets"]
    for set_name, pred_set in prediction_sets.items():
        context.log.info(f"Predict dataset: {set_name}")
        pred_set.initialize(description, exp_context)
        save_exp_data_stats(pred_set, exp_context, ExperimentFilenames.STATS_PRED)
        predict_dataset(
            data_description=description,
            prediction_steps=resolved["prediction_steps"],
            model=model,
            prediction_set=pred_set,
            prediction_handler=resolved["prediction_handler"],
            exp_context=exp_context,
            filename=set_name,
            prediction_function=resolved["prediction_function"],
        )

    return exp_context, prediction_sets, filelock_dict