from hopx_ai import Sandbox
import os
import time
from typing import Dict, Any
class MLTrainingService:
    """Manage an ML training job that runs inside a remote Hopx sandbox.

    Lifecycle: start_training() creates the sandbox and launches the code in
    the background; monitor_training()/get_checkpoint() poll it; cleanup()
    tears the sandbox down. At most one sandbox is active per instance.
    """

    _NO_SANDBOX_ERROR = "No active sandbox; call start_training() first"

    def __init__(self, api_key: str):
        self.api_key = api_key
        # Active Sandbox handle, or None when no job is running.
        self.sandbox = None

    def start_training(self, training_code: str, resources: Dict[str, int] = None) -> Dict[str, Any]:
        """Start an ML training job in a freshly created sandbox.

        Args:
            training_code: Python source to execute inside the sandbox.
            resources: Currently unused — resource limits come from the
                sandbox template. Kept for interface compatibility;
                TODO(review): forward to Sandbox.create if the SDK supports it.

        Returns:
            {"success": True, "execution_id": ..., "sandbox_id": ...} on
            success, or {"success": False, "error": <message>} on failure.
        """
        try:
            # Resources come from the template, so we request a (hopefully
            # ML-optimized) template rather than passing limits explicitly.
            self.sandbox = Sandbox.create(
                template="code-interpreter",
                api_key=self.api_key,
                timeout_seconds=7200,  # 2-hour ceiling for the training run
            )
            # Launch in the background so this call returns immediately.
            execution_id = self.sandbox.run_code_background(training_code)
            return {
                "success": True,
                "execution_id": execution_id,
                "sandbox_id": self.sandbox.sandbox_id,
            }
        except Exception as e:
            # If run_code_background failed after the sandbox was created,
            # release it here so a failed start does not leak a live sandbox.
            self.cleanup()
            return {
                "success": False,
                "error": str(e),
            }

    def monitor_training(self, execution_id: str) -> Dict[str, Any]:
        """Report a coarse snapshot of training progress.

        Args:
            execution_id: Currently unused — progress is inferred from the
                process list and checkpoint files, not correlated to the
                specific execution. Kept for interface compatibility.

        Returns:
            {"training_active": bool, "checkpoints": int, "processes": int},
            or {"error": <message>} on failure or when no sandbox is active.
        """
        # Guard instead of letting self.sandbox == None surface as a
        # confusing "'NoneType' object has no attribute ..." error string.
        if self.sandbox is None:
            return {"error": self._NO_SANDBOX_ERROR}
        try:
            # Heuristic: any python/train-named process counts as training.
            processes = self.sandbox.list_processes()
            training_processes = [
                p for p in processes
                if 'python' in p.get('name', '').lower() or 'train' in p.get('name', '').lower()
            ]
            # Checkpoint heuristic: filenames containing "checkpoint" or
            # ending in ".pkl" under the workspace directory.
            files = self.sandbox.files.list("/workspace")
            checkpoints = [f for f in files if 'checkpoint' in f.name.lower() or f.name.endswith('.pkl')]
            return {
                "training_active": len(training_processes) > 0,
                "checkpoints": len(checkpoints),
                "processes": len(training_processes),
            }
        except Exception as e:
            return {
                "error": str(e),
            }

    def get_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
        """Read a model checkpoint file out of the sandbox.

        Args:
            checkpoint_path: Absolute path inside the sandbox filesystem.

        Returns:
            {"success": True, "data": <contents>, "size": <len>} on success,
            or {"success": False, "error": <message>} on failure.
        """
        if self.sandbox is None:
            return {"success": False, "error": self._NO_SANDBOX_ERROR}
        try:
            checkpoint_data = self.sandbox.files.read(checkpoint_path)
            return {
                "success": True,
                "data": checkpoint_data,
                "size": len(checkpoint_data),
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
            }

    def cleanup(self):
        """Release the sandbox. Best-effort and safe to call repeatedly."""
        if self.sandbox is not None:
            try:
                self.sandbox.kill()
            except Exception:
                # Best-effort teardown: the sandbox will also expire
                # server-side via its timeout, so swallow kill failures.
                pass
            finally:
                self.sandbox = None
# Usage
def main() -> None:
    """Example: train a small scikit-learn model in a sandbox and poll it once."""
    api_key = os.getenv("HOPX_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of passing None to the SDK.
        raise SystemExit("HOPX_API_KEY environment variable is not set")

    service = MLTrainingService(api_key=api_key)
    training_code = """
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import pickle
iris = load_iris()
X, y = iris.data, iris.target
model = RandomForestClassifier(n_estimators=100)
model.fit(X, y)
with open('/workspace/model.pkl', 'wb') as f:
    pickle.dump(model, f)
"""
    try:
        result = service.start_training(training_code)
        print(result)
        if not result.get("success"):
            # No execution_id exists on failure — indexing it would KeyError.
            return

        # Give the background job a moment to spin up before polling.
        time.sleep(5)
        status = service.monitor_training(result["execution_id"])
        print(status)
    finally:
        # Always release the sandbox, even if monitoring raised.
        service.cleanup()


if __name__ == "__main__":
    main()