PyTorch Ray Distributed + Wandb Integration
Code for Integrating Ray Train, PyTorch Lightning, and Wandb
For the complete code integrating Ray Train, PyTorch Lightning, and Wandb, check it out [here]. The following is a code snippet for the same. It is written in a generalized way, so you can use it for both Ray Train distributed training and DeepSpeed distributed training.
'@Author:NavinKumarMNK'
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import torch.nn as nn
import pytorch_lightning as pl
import ray_lightning as rl
# Fix the global random seed so that runs are reproducible.
pl.seed_everything(seed=42)
class PyTorchModel(pl.LightningModule):
    """Template LightningModule for a classifier trained with cross-entropy.

    The network itself (``self.model``) and ``self.example_input_array`` are
    left as ``...`` placeholders and must be filled in before the class is
    usable. Every step logs through ``self.log``, so the values go to
    whichever logger is attached to the Trainer.
    """

    def __init__(self, **params) -> None:
        """Build the model.

        Args:
            **params: Model hyper-parameters. They are accepted so callers can
                pass a config dict, but this template does not consume them.
        """
        super().__init__()
        ## Setup the needed parameters for the model
        # Start at +inf so the first validation epoch always counts as an
        # improvement and triggers a checkpoint.
        self.best_val_loss = float("inf")
        # TODO: replace the placeholder with a representative input batch.
        self.example_input_array = ...
        # TODO: replace the placeholder with the actual layers of the network.
        self.model = nn.Sequential(
            ...
        )

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        return self.model(x)

    def loss_function(self, y, y_hat):
        """Compute and log the cross-entropy between predictions and targets.

        Args:
            y: Ground-truth targets.
            y_hat: Model output for the same batch.

        Returns:
            The loss tensor.
        """
        # Define the custom loss function for your problem.
        # The functional form is used on purpose: ``nn.CrossEntropyLoss(a, b)``
        # would only construct a loss module, not evaluate it.
        loss = F.cross_entropy(y_hat, y)
        self.log("Total loss", loss)
        return loss

    @staticmethod
    def _mean_loss(outputs, key):
        """Average the losses stored under ``key`` across all step outputs."""
        return torch.stack([step[key] for step in outputs]).mean()

    def training_step(self, batch, batch_idx):
        """Run one training step and return ``{"loss": loss}``."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y, y_hat)
        # Default logger provided by PyTorch Lightning; with the Wandb logger
        # attached this value is sent to wandb.
        self.log('train_loss', loss)
        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        """Log the mean training loss over the epoch.

        Args:
            outputs: The dicts returned by ``training_step`` for this epoch.
        """
        if not outputs:
            return
        self.log('train/loss_epoch', self._mean_loss(outputs, 'loss'))

    def validation_step(self, batch, batch_idx):
        """Run one validation step; return the loss, predictions and targets."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y, y_hat)
        self.log('val_loss', loss)
        return {"val_loss": loss, "y_hat": y_hat, "y": y}

    def validation_epoch_end(self, outputs) -> None:
        """Log the mean validation loss and checkpoint when it improves.

        Args:
            outputs: The dicts returned by ``validation_step`` for this epoch.
        """
        if not outputs:
            return
        avg_loss = self._mean_loss(outputs, 'val_loss')
        self.log('val/loss_epoch', avg_loss)
        # Save the model only when the validation loss beats the best so far.
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss.item()
            self.save_model()

    def test_step(self, batch, batch_idx):
        """Run one test step; return the loss, predictions and targets."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y, y_hat)
        self.log('test_loss', loss)
        return {"test_loss": loss, "y_hat": y_hat, "y": y}

    def test_epoch_end(self, outputs) -> None:
        """Log the mean test loss over the epoch.

        Args:
            outputs: The dicts returned by ``test_step`` for this epoch.
        """
        if not outputs:
            return
        self.log('test/loss_epoch', self._mean_loss(outputs, 'test_loss'))

    def save_model(self):
        """Write the state dict to disk and upload it as a wandb artifact."""
        checkpoint_path = "auto_encoder_model.cpkt"
        torch.save(self.state_dict(), checkpoint_path)
        # Without an active wandb run there is nowhere to upload to; the local
        # checkpoint above is still written.
        if wandb.run is None:
            return
        artifact = wandb.Artifact('auto_encoder_model.cpkt', type='model')
        # The file has to be attached explicitly, otherwise the logged
        # artifact is empty.
        artifact.add_file(checkpoint_path)
        wandb.run.log_artifact(artifact)

    def print_params(self):
        """Print the name and shape of every trainable parameter."""
        print("Model Parameters:")
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.data.shape)

    def configure_optimizers(self):
        """Return the optimizer(s) used for training: Adam with lr=1e-4."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        return [optimizer]

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model output for the inputs of one prediction batch."""
        x, _ = batch
        return self(x)

    def prediction_step(self, batch, batch_idx, dataloader_idx=None):
        """Backward-compatible alias that forwards to ``predict_step``."""
        return self.predict_step(batch, batch_idx, dataloader_idx)
class CustomCallback(pl.Callback):
    """Placeholder for a user-defined PyTorch Lightning callback.

    Override the callback hooks you need here. The base class is referenced
    as ``pl.Callback`` because the file only imports ``pytorch_lightning``
    under the alias ``pl``.
    """
    ...
def train():
    """Train ``PyTorchModel`` with PyTorch Lightning and log the run to wandb.

    Steps: create the Wandb logger, start Ray and the wandb run, set up the
    dataset, pick a distributed strategy from ``dist_env_params``, fit the
    model, save it, and close the wandb run.
    """
    # Imported locally because the file's import block does not bring `ray`
    # into scope.
    import ray

    project_name = 'CrimeDetection3'
    run_name = 'VariationalAutoEncoder'

    # Pytorch Wandb Logger
    logger = pl.loggers.WandbLogger(project=project_name, name=run_name)

    # Initialize Wandb, Ray and PyTorch Lightning.
    # NOTE(review): `utils` is not defined in the visible file - confirm where
    # `utils.ROOT_PATH` comes from.
    ray.init(runtime_env={"working_dir": utils.ROOT_PATH})
    # Start the run with the same project/name as the logger so both refer to
    # one consistently named run.
    wandb.init(project=project_name, name=run_name)

    # Setup the Dataset
    # NOTE(review): `Dataset` is not defined in the visible file; it is passed
    # to `trainer.fit` below, so presumably it is a LightningDataModule.
    dataset = Dataset()
    dataset.setup()

    # Pytorch Lightning Callbacks
    callbacks = [
        # Declare any further callbacks here.
        CustomCallback(),
    ]

    model = PyTorchModel()

    # Setup the Distributed Training Method
    dist_env_params = {
        'horovod': 0,
        'deep_speed': 0,
        'model_parallel': 0,
        'data_parallel': 0,
        'use_gpu': 0,
        'num_workers': 0,
        'num_cpus_per_worker': 0,
    }
    # All ray_lightning strategies below take the same resource arguments.
    ray_kwargs = {
        'use_gpu': dist_env_params['use_gpu'],
        'num_workers': dist_env_params['num_workers'],
        'num_cpus_per_worker': dist_env_params['num_cpus_per_worker'],
    }

    # The first flag set to 1 wins; with none set, Lightning gets strategy=None.
    strategy = None
    if int(dist_env_params['horovod']) == 1:
        strategy = rl.HorovodRayStrategy(**ray_kwargs)
    elif int(dist_env_params['deep_speed']) == 1:
        strategy = 'deepspeed_stage_1'
    elif int(dist_env_params['model_parallel']) == 1:
        strategy = rl.RayShardedStrategy(**ray_kwargs)
    elif int(dist_env_params['data_parallel']) == 1:
        strategy = rl.RayStrategy(**ray_kwargs)

    # NOTE(review): `autoencoder_params` is not defined in the visible file -
    # confirm where these Trainer keyword arguments come from.
    trainer = pl.Trainer(
        **autoencoder_params,
        callbacks=callbacks,
        strategy=strategy,
        accelerator='gpu',
        logger=logger,
        log_every_n_steps=5,
    )
    trainer.fit(model, dataset)
    model.save_model()
    wandb.finish()
# Start training only when this file is executed directly as a script.
if __name__ == '__main__':
    train()