ML Extensions

The ML Extensions specification defines pipeline-level attributes for AI/ML workloads. It complements the ML Resources extension (which handles compute infrastructure) with semantic metadata about the ML task itself.

| Extension | Prefix | Focus |
|---|---|---|
| ML Resources | `ext_ml_` | GPU/TPU/CPU requirements, scheduling, preemption |
| ML Extensions | `ext_mlx_` | Model metadata, training, inference, evaluation |

Describe the model a job operates on, including framework and I/O schemas.

| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_model_name` | string | Human-readable model name |
| `ext_mlx_model_version` | string | Semantic version of the model artifact |
| `ext_mlx_model_framework` | string | `pytorch`, `tensorflow`, `onnx`, `jax`, `sklearn`, `xgboost`, or `custom` |
| `ext_mlx_model_input_schema` | object | JSON Schema for expected model input |
| `ext_mlx_model_output_schema` | object | JSON Schema for expected model output |
| `ext_mlx_experiment_id` | string | Experiment tracker ID |
| `ext_mlx_run_id` | string | Run ID within the experiment |
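For illustration, the metadata attributes above might be combined as follows. The model name, version, IDs, and schema contents are hypothetical; the schemas use standard JSON Schema keywords:

```json
{
  "ext_mlx_model_name": "intent-classifier",
  "ext_mlx_model_version": "3.1.0",
  "ext_mlx_model_framework": "onnx",
  "ext_mlx_model_input_schema": {
    "type": "object",
    "properties": { "text": { "type": "string" } },
    "required": ["text"]
  },
  "ext_mlx_model_output_schema": {
    "type": "object",
    "properties": {
      "label": { "type": "string" },
      "score": { "type": "number" }
    }
  },
  "ext_mlx_experiment_id": "exp-intent-2024",
  "ext_mlx_run_id": "run-0042"
}
```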

Training attributes describe datasets, hyperparameters, and output locations.

```json
{
  "type": "ml.train",
  "ext_ml_gpu_type": "nvidia-a100",
  "ext_ml_gpu_count": 4,
  "ext_ml_precision": "bf16",
  "ext_mlx_model_name": "llama-3.1-8b-custom",
  "ext_mlx_model_framework": "pytorch",
  "ext_mlx_training_dataset_uri": "s3://datasets/train.parquet",
  "ext_mlx_training_hyperparameters": {
    "learning_rate": 2e-5,
    "batch_size": 16,
    "weight_decay": 0.01,
    "seed": 42
  },
  "ext_mlx_training_epochs": 3,
  "ext_mlx_training_output_model_uri": "s3://models/custom/v1.0.0/"
}
```
| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_training_dataset_uri` | string | Training dataset URI |
| `ext_mlx_training_validation_uri` | string | Validation dataset URI |
| `ext_mlx_training_hyperparameters` | object | Key-value map of hyperparameters |
| `ext_mlx_training_epochs` | integer | Number of training epochs |
| `ext_mlx_training_checkpoint_uri` | string | URI for saving checkpoints |
| `ext_mlx_training_resume_from` | string | Checkpoint URI to resume from |
| `ext_mlx_training_output_model_uri` | string | Output model artifact URI |
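The checkpointing attributes can be used together to resume an interrupted run. A minimal sketch with hypothetical URIs:

```json
{
  "type": "ml.train",
  "ext_mlx_training_checkpoint_uri": "s3://checkpoints/custom/",
  "ext_mlx_training_resume_from": "s3://checkpoints/custom/epoch-2/",
  "ext_mlx_training_output_model_uri": "s3://models/custom/v1.0.1/"
}
```

Setting `ext_mlx_training_resume_from` restarts from an existing checkpoint rather than from scratch, while `ext_mlx_training_checkpoint_uri` controls where new checkpoints are written.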

Configure how a trained model processes input data.

```json
{
  "type": "ml.inference",
  "ext_ml_gpu_type": "nvidia-l4",
  "ext_ml_gpu_count": 1,
  "ext_mlx_model_name": "intent-classifier",
  "ext_mlx_model_framework": "onnx",
  "ext_mlx_inference_model_uri": "s3://models/classifier/v3.1.0/model.onnx",
  "ext_mlx_inference_batch_size": 256,
  "ext_mlx_inference_input_uri": "s3://data/input.jsonl",
  "ext_mlx_inference_output_uri": "s3://data/output.jsonl",
  "ext_mlx_inference_mode": "batch"
}
```
| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_inference_model_uri` | string | Model artifact URI |
| `ext_mlx_inference_batch_size` | integer | Inputs per batch |
| `ext_mlx_inference_timeout_ms` | integer | Per-request timeout (ms) |
| `ext_mlx_inference_input_uri` | string | Input data URI (batch mode) |
| `ext_mlx_inference_output_uri` | string | Output data URI |
| `ext_mlx_inference_mode` | string | `batch`, `streaming`, or `realtime` |
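A realtime deployment would drop the batch input/output URIs in favor of a per-request timeout. An illustrative sketch (values are hypothetical):

```json
{
  "type": "ml.inference",
  "ext_mlx_inference_model_uri": "s3://models/classifier/v3.1.0/model.onnx",
  "ext_mlx_inference_mode": "realtime",
  "ext_mlx_inference_batch_size": 1,
  "ext_mlx_inference_timeout_ms": 50
}
```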

Describe data transformations that prepare raw data for training or inference.

| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_preprocess_input_uri` | string | Raw input data URI |
| `ext_mlx_preprocess_output_uri` | string | Processed output URI |
| `ext_mlx_preprocess_input_format` | string | `csv`, `parquet`, `json`, `jsonl`, `tfrecord`, `arrow`, or `custom` |
| `ext_mlx_preprocess_output_format` | string | Output format (same enum) |
| `ext_mlx_preprocess_transformations` | array | Ordered list of transformations |
| `ext_mlx_preprocess_split_ratios` | object | Train/validation/test split ratios |

Transformation types: `tokenize`, `normalize`, `augment`, `filter`, `sample`, `deduplicate`, `encode`.
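Putting the preprocessing attributes together, a job might look like the following sketch. The URIs and split-ratio key names (`train`, `validation`, `test`) are illustrative assumptions:

```json
{
  "type": "ml.preprocess",
  "ext_mlx_preprocess_input_uri": "s3://raw/reviews.jsonl",
  "ext_mlx_preprocess_output_uri": "s3://processed/reviews-v1/",
  "ext_mlx_preprocess_input_format": "jsonl",
  "ext_mlx_preprocess_output_format": "parquet",
  "ext_mlx_preprocess_transformations": ["deduplicate", "normalize", "tokenize"],
  "ext_mlx_preprocess_split_ratios": { "train": 0.8, "validation": 0.1, "test": 0.1 }
}
```

The transformations array is ordered, so here deduplication runs before normalization and tokenization.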

Assess model quality with metrics and thresholds that can gate deployment.

```json
{
  "type": "ml.evaluate",
  "ext_mlx_eval_dataset_uri": "s3://data/test.parquet",
  "ext_mlx_eval_model_uri": "s3://models/v1.0.0/",
  "ext_mlx_eval_metrics": ["accuracy", "f1", "latency_p99"],
  "ext_mlx_eval_thresholds": {
    "accuracy": { "min": 0.92 },
    "f1": { "min": 0.88 },
    "latency_p99": { "max": 100 }
  }
}
```
| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_eval_dataset_uri` | string | Evaluation dataset URI |
| `ext_mlx_eval_model_uri` | string | Model to evaluate |
| `ext_mlx_eval_metrics` | array | Metrics to compute |
| `ext_mlx_eval_thresholds` | object | Min/max thresholds per metric |
| `ext_mlx_eval_output_uri` | string | Results output URI |
| `ext_mlx_eval_baseline_run_id` | string | Baseline run for comparison |

Well-known metrics: `accuracy`, `f1`, `precision`, `recall`, `auc_roc`, `loss`, `perplexity`, `bleu`, `rouge_l`, `mae`, `rmse`, `latency_p50`, `latency_p99`, `throughput`.

Unified artifact references for models, datasets, and checkpoints.

| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_artifact_model_uri` | string | Primary model artifact URI |
| `ext_mlx_artifact_dataset_uri` | string | Primary dataset URI |
| `ext_mlx_artifact_checkpoint_uri` | string | Checkpoint artifact URI |
| `ext_mlx_artifact_output_uri` | string | Primary output artifact URI |
| `ext_mlx_artifact_registry` | string | `s3`, `gcs`, `azure_blob`, `mlflow`, `wandb`, `huggingface`, or `local` |

Supported URI schemes: `s3://`, `gs://`, `az://`, `hdfs://`, `mlflow://`, `wandb://`, `hf://`, `file://`.
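As a sketch, a job that reads a dataset from one registry and writes artifacts to another could combine these attributes like so (all paths and registry entries are hypothetical):

```json
{
  "ext_mlx_artifact_model_uri": "mlflow://models/classifier/3",
  "ext_mlx_artifact_dataset_uri": "hf://datasets/acme/tickets",
  "ext_mlx_artifact_checkpoint_uri": "s3://checkpoints/classifier/step-5000/",
  "ext_mlx_artifact_output_uri": "s3://outputs/classifier/",
  "ext_mlx_artifact_registry": "mlflow"
}
```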

Advisory attributes for optimizing job placement.

| Attribute | Type | Description |
|---|---|---|
| `ext_mlx_scheduling_gpu_affinity` | string | `spread`, `pack`, or `dedicated` |
| `ext_mlx_scheduling_preemption_policy` | string | `never`, `save_and_retry`, or `immediate` |
| `ext_mlx_scheduling_spot_preference` | string | `spot_preferred`, `on_demand_only`, or `any` |
| `ext_mlx_scheduling_data_locality` | string | URI hint for data-local scheduling |
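For example, a training job that tolerates preemption on spot capacity might advise the scheduler as follows (illustrative values):

```json
{
  "type": "ml.train",
  "ext_mlx_scheduling_gpu_affinity": "pack",
  "ext_mlx_scheduling_preemption_policy": "save_and_retry",
  "ext_mlx_scheduling_spot_preference": "spot_preferred",
  "ext_mlx_scheduling_data_locality": "s3://datasets/train.parquet"
}
```

Because these attributes are advisory, a backend that cannot honor them may still run the job on other placements.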

A complete ML pipeline expressed as an OJS workflow chain:

```json
{
  "workflow_type": "chain",
  "id": "019539a4-ml-pipeline-001",
  "name": "classifier-training-pipeline",
  "steps": [
    {
      "type": "ml.preprocess",
      "queue": "cpu-workers",
      "ext_ml_accelerator": "cpu",
      "ext_mlx_preprocess_input_uri": "s3://raw/tickets.jsonl",
      "ext_mlx_preprocess_output_uri": "s3://processed/tickets-v3/"
    },
    {
      "type": "ml.train",
      "queue": "gpu-training",
      "ext_ml_gpu_type": "nvidia-a100",
      "ext_ml_gpu_count": 4,
      "ext_mlx_training_dataset_uri": "s3://processed/tickets-v3/train.parquet",
      "ext_mlx_training_output_model_uri": "s3://models/v1.0.0/"
    },
    {
      "type": "ml.evaluate",
      "queue": "gpu-inference",
      "ext_ml_gpu_count": 1,
      "ext_mlx_eval_model_uri": "s3://models/v1.0.0/",
      "ext_mlx_eval_thresholds": { "accuracy": { "min": 0.92 } }
    },
    {
      "type": "ml.deploy",
      "queue": "deployment",
      "ext_ml_accelerator": "cpu"
    }
  ]
}
```

Each step has independent resource requirements — the backend schedules each on appropriate hardware as the chain progresses. If evaluation thresholds are not met, the pipeline stops before deployment.