JAX Example

This is a complete example of JAX code that trains an MLP on MNIST and logs the training and test accuracy to W&B.

You can find this example on GitHub and see the results on W&B.

import time
import itertools
import numpy.random as npr
import wandb
import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
def loss(params, batch):
    """Return the negated mean of the target-weighted network outputs.

    Args:
        params: Network parameters, passed straight through to ``predict``.
        batch: An ``(inputs, targets)`` pair.

    Returns:
        ``-mean(predict(params, inputs) * targets)``. The mean is taken over
        every element of the product, not per example.
    """
    x, y = batch
    # `predict` ends in LogSoftmax, so this weights log-probabilities by the
    # targets (presumably one-hot labels -- confirm against the dataset loader).
    weighted_log_probs = predict(params, x) * y
    return -np.mean(weighted_log_probs)
def accuracy(params, batch):
    """Return the fraction of examples whose predicted class matches the target.

    Args:
        params: Network parameters, passed straight through to ``predict``.
        batch: An ``(inputs, targets)`` pair; classes are read off axis 1 of
            both the targets and the network output via ``argmax``.

    Returns:
        The mean of the element-wise "prediction == target" comparison.
    """
    x, y = batch
    is_correct = np.argmax(predict(params, x), axis=1) == np.argmax(y, axis=1)
    return np.mean(is_correct)
# Network definition: two 1024-unit ReLU hidden layers followed by a 10-way
# log-softmax output. `init_random_params(rng, input_shape)` creates the
# parameters and `predict(params, inputs)` runs the forward pass.
init_random_params, predict = stax.serial(
    Dense(1024),
    Relu,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
if __name__ == "__main__":
    # Train the MLP on MNIST with momentum SGD and log per-epoch metrics to W&B.
    wandb.init()
    rng = random.PRNGKey(0)

    # Record the hyperparameters on the run's config.
    wandb.config.step_size = 0.001
    wandb.config.num_epochs = 10
    wandb.config.batch_size = 128
    wandb.config.momentum_mass = 0.9

    # Read the values back from the config once (rather than reusing the
    # literals above) so training uses whatever the config actually holds,
    # and so the inner data loop does not repeat the attribute lookup for
    # every mini-batch.
    step_size = wandb.config.step_size
    num_epochs = wandb.config.num_epochs
    batch_size = wandb.config.batch_size
    momentum_mass = wandb.config.momentum_mass

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    # Round up so a final, smaller batch is still visited each epoch.
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield (images, labels) mini-batches forever, reshuffling on every pass."""
        # Separate NumPy RNG for shuffling; named so it does not shadow the
        # JAX PRNG key `rng` used for parameter initialisation.
        shuffle_rng = npr.RandomState(0)
        while True:
            perm = shuffle_rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        """Apply one optimizer step using the gradient of `loss` on `batch`."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        # NOTE(review): JAX may dispatch work asynchronously, so this wall-clock
        # figure can under-count device time -- confirm if precise timing matters.
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        # The epoch duration was previously computed and then discarded;
        # it is now reported alongside the accuracies.
        wandb.log({
            "Train Accuracy": float(train_acc),
            "Test Accuracy": float(test_acc),
            "Epoch Time (sec)": epoch_time,
        })