Ray Distributed Model Training
- Run the following command to start the cluster.
ray up [OPTIONS] CLUSTER_CONFIG_FILE
Head Node - Run the following program in the head node to start the ray cluster.
ray start --head --dashboard-port 8000
Worker Node - Run the following command in all the worker nodes to connect to the head node.
ray start --address="ipaddress:port"
- Run the following command to check the status of the cluster.
ray status
Note
Also install prometheus and grafana to monitor the cluster. Prometheus Grafana
Model Training
Example
import torch
import torch.nn as nn
import ray
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air.config import RunConfig
from ray.air.config import CheckpointConfig
# If using CPU, set this to False.
use_gpu = True
# Define NN layers archicture, epochs, and number of workers
input_size = 1
layer_size = 32
output_size = 1
num_epochs = 200
num_workers = 3
# Define your network structure
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.layer1 = nn.Linear(input_size, layer_size)
self.relu = nn.ReLU()
self.layer2 = nn.Linear(layer_size, output_size)
def forward(self, input):
return self.layer2(self.relu(self.layer1(input)))
# Define your train worker loop
def train_loop_per_worker():
# Fetch training set from the session
dataset_shard = session.get_dataset_shard("train")
model = NeuralNetwork()
# Loss function, optimizer, prepare model for training.
# This moves the data and prepares model for distributed
# execution
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=0.01,
weight_decay=0.01)
model = train.torch.prepare_model(model)
# Iterate over epochs and batches
for epoch in range(num_epochs):
for batches in dataset_shard.iter_torch_batches(batch_size=32,
dtypes=torch.float, device=train.torch.get_device()):
# Add batch or unsqueeze as an additional dimension [32, x]
inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
output = model(inputs)
# Make output shape same as the as labels
loss = loss_fn(output.squeeze(), labels)
# Zero out grads, do backward, and update optimizer
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Print what's happening with loss per 30 epochs
if epoch % 20 == 0:
print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")
# Report and record metrics, checkpoint model at end of each
# epoch
session.report({"loss": loss.item(), "epoch": epoch},
checkpoint=Checkpoint.from_dict(
dict(epoch=epoch, model=model.state_dict()))
)
torch.manual_seed(42)
train_dataset = ray.data.from_items(
[{"x": x, "y": 2 * x + 1} for x in range(200)]
)
# Define scaling and run configs
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
scaling_config=scaling_config,
run_config=run_config,
datasets={"train": train_dataset})
result = trainer.fit()
best_checkpoint_loss = result.metrics['loss']
# Assert loss is less 0.09
assert best_checkpoint_loss <= 0.09