Build a production-ready ML model training service that handles resource allocation, monitors training progress, and manages model checkpoints. This cookbook demonstrates how to create a training platform using HopX.

Overview

ML training services provide cloud-based environments for training machine learning models. The service allocates resources, executes long-running training jobs, monitors progress, and saves model checkpoints.

Prerequisites

  • HopX API key (Get one here)
  • Python 3.8+ or Node.js 16+
  • Understanding of ML training workflows
  • Basic knowledge of model checkpointing
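If checkpointing is unfamiliar, the idea is simply to persist training state periodically so an interrupted job can resume instead of restarting. A minimal stdlib-only sketch (the file path, weight update, and epoch count are illustrative):

```python
import os
import pickle
import tempfile

def save_checkpoint(path, epoch, weights):
    # Write atomically: dump to a temp file, then rename over the target,
    # so a crash mid-write never leaves a corrupt checkpoint
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump({"epoch": epoch, "weights": weights}, f)
    os.replace(tmp, path)

def load_checkpoint(path):
    if not os.path.exists(path):
        return None  # no checkpoint yet: start from scratch
    with open(path, "rb") as f:
        return pickle.load(f)

# Illustrative "training" loop that checkpoints after every epoch
path = os.path.join(tempfile.gettempdir(), "demo_checkpoint.pkl")
if os.path.exists(path):
    os.remove(path)  # clean slate for this demo

weights = [0.0, 0.0]
start = (load_checkpoint(path) or {"epoch": -1})["epoch"] + 1
for epoch in range(start, 3):
    weights = [w + 0.1 for w in weights]  # stand-in for a real update step
    save_checkpoint(path, epoch, weights)

print(load_checkpoint(path)["epoch"])  # → 2 (last completed epoch)
```

On restart, the loop picks up from the last saved epoch rather than epoch 0, which is the behavior the training service relies on for long jobs.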

Architecture

┌──────────────┐
│  Training    │ Submit training job
│   Request    │
└──────┬───────┘
       │
       ▼
┌─────────────────┐
│  Training       │ Allocate resources
│   Service       │
└──────┬──────────┘
       │
       ▼
┌──────────────────┐
│  HopX Sandbox    │ Execute training
│  (High resources)│
└──────┬───────────┘
       │
       ▼
┌─────────────────┐
│  Model Storage  │ Save checkpoints
└─────────────────┘

Implementation

Step 1: Training Job Execution

The service below creates a sandbox, launches the training job in the background, monitors progress, and retrieves checkpoints:
from hopx_ai import Sandbox
import os
import time
from typing import Dict, Any

class MLTrainingService:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.sandbox = None
    
    def start_training(self, training_code: str, resources: Dict[str, int] = None) -> Dict[str, Any]:
        """Start ML training job"""
        try:
            # Create sandbox with appropriate resources
            # Note: Resources come from template, but we can request high-resource template
            self.sandbox = Sandbox.create(
                template="code-interpreter",  # Use ML-optimized template if available
                api_key=self.api_key,
                timeout_seconds=7200  # 2 hour timeout for training
            )
            
            # Start training in background
            execution_id = self.sandbox.run_code_background(training_code)
            
            return {
                "success": True,
                "execution_id": execution_id,
                "sandbox_id": self.sandbox.sandbox_id
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
    
    def monitor_training(self, execution_id: str) -> Dict[str, Any]:
        """Monitor progress of a background training execution"""
        try:
            # execution_id identifies the background run; progress is inferred
            # from running processes and checkpoint files on disk
            processes = self.sandbox.list_processes()
            training_processes = [
                p for p in processes
                if 'python' in p.get('name', '').lower() or 'train' in p.get('name', '').lower()
            ]
            
            # Check for checkpoint files
            files = self.sandbox.files.list("/workspace")
            checkpoints = [f for f in files if 'checkpoint' in f.name.lower() or f.name.endswith('.pkl')]
            
            return {
                "training_active": len(training_processes) > 0,
                "checkpoints": len(checkpoints),
                "processes": len(training_processes)
            }
        except Exception as e:
            return {
                "error": str(e)
            }
    
    def get_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
        """Retrieve model checkpoint"""
        try:
            checkpoint_data = self.sandbox.files.read(checkpoint_path)
            return {
                "success": True,
                "data": checkpoint_data,
                "size": len(checkpoint_data)
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
    
    def cleanup(self):
        """Clean up training resources"""
        if self.sandbox:
            self.sandbox.kill()
            self.sandbox = None

# Usage
service = MLTrainingService(api_key=os.getenv("HOPX_API_KEY"))

training_code = """
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import pickle

iris = load_iris()
X, y = iris.data, iris.target

model = RandomForestClassifier(n_estimators=100)
model.fit(X, y)

with open('/workspace/model.pkl', 'wb') as f:
    pickle.dump(model, f)
"""

result = service.start_training(training_code)
print(result)

# Monitor (only if the job started successfully)
if result["success"]:
    time.sleep(5)
    status = service.monitor_training(result["execution_id"])
    print(status)

service.cleanup()
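For real training jobs, a single sleep-and-check is not enough. A simple polling loop (a sketch built on the monitor_training method above; the poll interval and deadline are arbitrary) waits until no training process remains or a timeout passes:

```python
import time

def wait_for_training(service, execution_id, poll_seconds=30, max_wait=7200):
    """Poll monitor_training until the job finishes or the deadline passes."""
    deadline = time.time() + max_wait
    while time.time() < deadline:
        status = service.monitor_training(execution_id)
        if status.get("error"):
            # Monitoring itself failed: surface the error to the caller
            return {"finished": False, "error": status["error"]}
        if not status["training_active"]:
            # No training process left: the job has completed (or crashed)
            return {"finished": True, "checkpoints": status["checkpoints"]}
        time.sleep(poll_seconds)
    return {"finished": False, "error": "timeout"}
```

A caller can then branch on `finished` and fetch the final checkpoint with `get_checkpoint()` only when the job actually completed.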

Best Practices

  1. Resource Allocation: Choose a template and timeout that match the job's expected compute and duration
  2. Progress Monitoring: Poll training progress at regular intervals rather than checking once
  3. Checkpointing: Save model checkpoints frequently so interrupted jobs can resume
  4. Error Recovery: Handle training failures gracefully and always clean up sandboxes
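For point 4, a small retry wrapper keeps a transient failure (a sandbox creation hiccup, a dropped connection) from killing the whole job. This is a sketch; the attempt count and backoff values are arbitrary:

```python
import time

def with_retries(fn, attempts=3, base_delay=1.0):
    """Run fn(), retrying with exponential backoff on exceptions."""
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise  # out of retries: let the error propagate
            # Back off 1s, 2s, 4s, ... before the next attempt
            time.sleep(base_delay * (2 ** attempt))

# Usage sketch:
# result = with_retries(lambda: service.start_training(training_code))
```

Pair this with a `try`/`finally` around the training run so `service.cleanup()` executes even when retries are exhausted.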

Next Steps

  1. Implement distributed training support
  2. Add hyperparameter tuning
  3. Create training dashboard
  4. Implement model versioning
  5. Add training job scheduling
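As a starting point for item 5, job scheduling can be as simple as a FIFO queue that caps how many sandboxes run concurrently. This is a sketch only; the integration point with `MLTrainingService.start_training` is left as a comment:

```python
from collections import deque

class TrainingScheduler:
    """FIFO scheduler that limits how many training jobs run at once."""

    def __init__(self, max_concurrent=2):
        self.max_concurrent = max_concurrent
        self.pending = deque()   # (job_id, training_code) tuples, in order
        self.running = set()     # job_ids currently executing

    def submit(self, job_id, training_code):
        self.pending.append((job_id, training_code))
        self._drain()

    def job_finished(self, job_id):
        self.running.discard(job_id)
        self._drain()

    def _drain(self):
        # Start pending jobs while capacity remains
        while self.pending and len(self.running) < self.max_concurrent:
            job_id, code = self.pending.popleft()
            self.running.add(job_id)
            # A real scheduler would call service.start_training(code) here
```

A production version would add priorities, persistence, and a callback from the polling loop into `job_finished`, but the capacity-limited FIFO is the core of it.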