In this tutorial you’ll learn how to train a simple PyTorch model in a sandbox environment. To do this, you start a sandbox with the files and network access the job needs, install the necessary dependencies, and run a Python script that trains a simple neural network on the UCI Zoo dataset.

Prerequisites

Install W&B Python SDK

Install the W&B Python SDK. You can do this using pip:
pip install wandb

Store API keys in environment variables

If you have not done so already, store your W&B API key. Run the wandb login CLI command and follow the prompts to log in to your W&B account:
wandb login
See the wandb login reference documentation for more information on how W&B searches for credentials.
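Before you start a sandbox, you can check which credential source is available on your machine. The sketch below is a hypothetical helper and is not part of the wandb SDK. It assumes that W&B reads the WANDB_API_KEY environment variable first and otherwise uses the ~/.netrc entry that wandb login writes for the default host api.wandb.ai. On Windows the file may be named _netrc instead; in that case, pass netrc_path explicitly.

```python
import netrc
import os
from pathlib import Path
from typing import Mapping, Optional


def find_wandb_api_key(
    environ: Mapping[str, str] = os.environ,
    netrc_path: Optional[Path] = None,
    host: str = "api.wandb.ai",  # assumption: default W&B cloud host
) -> Optional[str]:
    """Hypothetical helper: return a W&B API key if one is configured, else None."""
    # 1. An explicitly set environment variable takes precedence (assumed order).
    key = environ.get("WANDB_API_KEY")
    if key:
        return key
    # 2. Fall back to the netrc entry that `wandb login` writes.
    path = netrc_path or Path.home() / ".netrc"
    if not path.exists():
        return None
    entry = netrc.netrc(str(path)).authenticators(host)
    # authenticators() returns (login, account, password) or None.
    return entry[2] if entry else None
```

For example, find_wandb_api_key() returns None on a machine where you have neither set WANDB_API_KEY nor run wandb login.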

Copy the training script and dependencies

Copy and paste the following code into a file named requirements.txt. This file contains the dependencies for the training script.
requirements.txt
torch
pandas
ucimlrepo
scikit-learn
pyyaml
Copy and paste the following code into a YAML file named hyperparameters.yaml. This file contains the hyperparameters for the training script.
hyperparameters.yaml
learning_rate: 0.1
epochs: 1000
model_type: Multivariate_neural_network_classifier
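train.py reads only learning_rate and epochs from this file. The script does not use model_type. If you edit the file, you can check the dict that yaml.safe_load returns so that a typo surfaces before a sandbox run starts. One YAML pitfall applies here: PyYAML loads 1e-3 as a string, while 1.0e-3 loads as a float. The helper below, validate_hyperparameters, is a hypothetical sketch and is not called by train.py.

```python
from typing import Any

# Keys train.py reads, with their allowed types (model_type is not used by the script).
REQUIRED_TYPES: dict[str, tuple[type, ...]] = {
    "learning_rate": (int, float),
    "epochs": (int,),
}


def validate_hyperparameters(config: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical check on the dict returned by yaml.safe_load."""
    missing = [key for key in REQUIRED_TYPES if key not in config]
    if missing:
        raise ValueError(f"missing hyperparameters: {missing}")
    for key, allowed in REQUIRED_TYPES.items():
        value = config[key]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, allowed):
            raise TypeError(f"{key} has unexpected type {type(value).__name__}")
    if config["learning_rate"] <= 0:
        raise ValueError("learning_rate must be positive")
    if config["epochs"] < 1:
        raise ValueError("epochs must be at least 1")
    return config


# The dict that yaml.safe_load produces for hyperparameters.yaml.
config = {
    "learning_rate": 0.1,
    "epochs": 1000,
    "model_type": "Multivariate_neural_network_classifier",
}
validate_hyperparameters(config)
```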
Copy and paste the following code into a file named train.py. This script trains a simple PyTorch model on the UCI Zoo dataset and saves the trained model to a file named zoo_wandb.pth.
train.py
import argparse
import torch
from torch import nn
import yaml
from ucimlrepo import fetch_ucirepo

from sklearn.model_selection import train_test_split

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(in_features=16 , out_features=16),
            nn.Sigmoid(),
            nn.Linear(in_features=16, out_features=7)
        )

    def forward(self, x):
        logits = self.linear_stack(x)
        return logits

def main(args):
    # Load hyperparameters from the provided config file
    with open(args.config, 'r') as f:
        hyperparameter_config = yaml.safe_load(f)

    # Fetch the UCI Zoo dataset (ID 111) from the UCI Machine Learning Repository
    zoo = fetch_ucirepo(id=111)

    # Features and targets, as pandas DataFrames
    X = zoo.data.features
    y = zoo.data.targets

    print("features: ", X.shape, "type: ", type(X))
    print("labels: ", y.shape, "type: ", type(y))

    ## Process data
    # Data type of the data must match the data type of the model, the default dtype for nn.Linear is torch.float32
    dataset = torch.tensor(X.values).type(torch.float32)

    # Convert to a tensor and shift the labels from 1-7 to 0-6 for indexing
    labels = torch.tensor(y.values) - 1

    print("dataset: ", dataset.shape, "dtype: ",dataset.dtype)
    print("labels: ", labels.shape, "dtype: ",labels.dtype)

    torch.save(dataset, "zoo_dataset.pt")
    torch.save(labels, "zoo_labels.pt")

    # Record the split settings for future reference and reproducibility
    config = {
        "random_state" : 42,
        "test_size" : 0.25,
        "shuffle" : True
    }

    # Split dataset into training and test set
    X_train, X_test, y_train, y_test = train_test_split(
        dataset,labels, 
        random_state=config["random_state"],
        test_size=config["test_size"], 
        shuffle=config["shuffle"]
    )

    # Save the files locally
    torch.save(X_train, "zoo_dataset_X_train.pt")
    torch.save(y_train, "zoo_labels_y_train.pt")

    torch.save(X_test, "zoo_dataset_X_test.pt")
    torch.save(y_test, "zoo_labels_y_test.pt")


    ## Define model
    model = NeuralNetwork()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=hyperparameter_config["learning_rate"])
    print(model)

    # Start with an infinite best loss so the first logged epoch always counts as an improvement
    prev_best_loss = float("inf")

    # Training loop (runs epochs 0 through hyperparameter_config["epochs"], inclusive)
    for e in range(hyperparameter_config["epochs"] + 1):
        pred = model(X_train)
        loss = loss_fn(pred, y_train.squeeze(1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Every 100 epochs, log the loss if it improved on the best value seen so far
        if (e % 100 == 0) and (loss.item() <= prev_best_loss):
            print("epoch: ", e, "loss:", loss.item())

            # Store the new best loss as a plain float, detached from the autograd graph
            prev_best_loss = loss.item()

    print("Saving model...")
    PATH = 'zoo_wandb.pth' 
    torch.save(model.state_dict(), PATH)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a simple neural network on the zoo dataset.")
    parser.add_argument("--config", type=str, required=True, help="Path to the hyperparameter configuration file.")
    args = parser.parse_args()
    main(args)
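The shapes that train.py prints come from two small calculations. The UCI Zoo dataset has 101 animals, 16 features, and class labels 1-7. The sketch below mirrors the rounding that train_test_split applies to a float test_size (the test set is rounded up). It also mirrors the label shift to the 0-6 range that nn.CrossEntropyLoss expects for class-index targets. The helper names split_sizes and to_class_indices are hypothetical.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Return (n_train, n_test) for a float test_size.

    Mirrors scikit-learn's rounding: the test set size is rounded up and the
    training set receives the remaining samples.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def to_class_indices(labels: list[int]) -> list[int]:
    """Shift 1-based class labels (1-7) to the 0-based indices (0-6)
    that nn.CrossEntropyLoss expects as targets."""
    return [label - 1 for label in labels]


# The UCI Zoo dataset has 101 animals, and train.py uses test_size=0.25.
n_train, n_test = split_sizes(101, 0.25)
print(n_train, n_test)  # 75 26
print(to_class_indices([1, 4, 7]))  # [0, 3, 6]
```

As a result, X_train has shape (75, 16) and X_test has shape (26, 16).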

Create the sandbox and run the training script

The following script creates a sandbox, copies the training script and its dependencies into it, runs the training script, and downloads the generated model file. A line-by-line explanation follows the code. Save the script as train_in_sandbox.py in the same directory as the train.py, requirements.txt, and hyperparameters.yaml files you created in the previous step, then run it with python train_in_sandbox.py.
train_in_sandbox.py
from pathlib import Path
from wandb.sandbox import Sandbox, NetworkOptions

hyperparameters = Path("hyperparameters.yaml").read_bytes()

mounted_files = [
    {"mount_path": "train.py", "file_content": Path("train.py").read_bytes()},
    {"mount_path": "requirements.txt", "file_content": Path("requirements.txt").read_bytes()},
]

print("Starting sandbox...")
with Sandbox.run(
    mounted_files=mounted_files,
    container_image="python:3.13",
    network=NetworkOptions(egress_mode="internet"),
    max_lifetime_seconds=3600
) as sandbox:
    sandbox.write_file("hyperparameters.yaml", hyperparameters).result()

    # Install dependencies
    print("Installing dependencies...")
    sandbox.exec(["pip", "install", "-r", "requirements.txt"], check=True).result()

    # Run the script
    print("Running script...")
    result = sandbox.exec(["python", "train.py", "--config", "hyperparameters.yaml"]).result()
    print(result.stdout)
    print(result.stderr)
    print(f"Exit code: {result.returncode}")

    # Save the generated model file locally
    print("Downloading zoo_wandb.pth...")
    model_data = sandbox.read_file("zoo_wandb.pth").result()
    Path("zoo_wandb.pth").write_bytes(model_data)
    print("Saved zoo_wandb.pth")
The previous code snippet does the following:
  1. (Lines 6-9) Mount files into the sandbox at startup. The mounted_files parameter of Sandbox.run() takes a list of files to mount when the sandbox starts. Each file is a dictionary with two keys: mount_path, the path of the file inside the sandbox, and file_content, the file's content as bytes. In this example, you mount train.py and requirements.txt.
  2. (Line 12) Start the sandbox with the specified configuration. The context manager (with statement) ensures that the sandbox is properly stopped after use. The sandbox uses the python:3.13 container image, has outbound internet access, and has a maximum lifetime of 3600 seconds (1 hour). It needs internet access because pip downloads packages from PyPI and train.py downloads the dataset from the UCI Machine Learning Repository.
  3. (Line 18) Write the hyperparameters.yaml file to the sandbox. You read the content of the hyperparameters.yaml file and write it to the sandbox using the write_file() method. This allows the training script to access the hyperparameters when it runs.
  4. (Line 22) Install dependencies. You run the command pip install -r requirements.txt inside the sandbox to install the necessary dependencies for the training script.
  5. (Line 26) Run the training script. You execute the command python train.py --config hyperparameters.yaml inside the sandbox to start the training process. The script trains a simple PyTorch model on the UCI Zoo dataset and saves the trained model to a file named zoo_wandb.pth.
  6. (Lines 27-29) Print the output and exit code. After the training script finishes, you print its standard output, standard error, and exit code to the console for debugging and verification.
  7. (Lines 33-34) Download the generated model file. You read zoo_wandb.pth from the sandbox with the read_file() method and write it to a local file of the same name.
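As written, the script tries to download zoo_wandb.pth even when train.py fails. In that case the download step has nothing useful to fetch. Below is a minimal sketch of two hypothetical helpers that are not part of wandb.sandbox. build_mounted_files() produces the list of mount_path/file_content dictionaries described in step 1. ensure_success() raises an error when a command's returncode is non-zero. The helpers rely only on the returncode and stderr attributes that step 6 prints.

```python
from collections.abc import Iterable
from pathlib import Path
from typing import Any


def build_mounted_files(paths: Iterable[str]) -> list[dict[str, Any]]:
    """Build a mounted_files list: one dict per file, mounted under its
    own file name, with the file content read as bytes."""
    return [
        {"mount_path": Path(p).name, "file_content": Path(p).read_bytes()}
        for p in paths
    ]


def ensure_success(result: Any) -> None:
    """Raise if a sandbox command exited with a non-zero code.

    Expects an object with returncode and stderr attributes, like the
    result of sandbox.exec(...).result() in train_in_sandbox.py.
    """
    if result.returncode != 0:
        raise RuntimeError(
            f"command failed with exit code {result.returncode}:\n{result.stderr}"
        )
```

In train_in_sandbox.py, you could pass build_mounted_files(["train.py", "requirements.txt"]) as mounted_files. You could then call ensure_success(result) right after the training command and before reading zoo_wandb.pth.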