
Clarification on sharding strategy to combine multiple training steps together via nnx.scan #4417

Open
Teculos opened this issue Dec 4, 2024 · 2 comments


Teculos commented Dec 4, 2024

I'm trying to roughly replicate the behaviour found in this repo, where they pmap a scan transform over the train step to combine multiple train steps into one function call (see run_lib.py line 124). Since this is a pre-flax.nnx implementation, they replicate the model and pmap over the model replicas and the data (structured as [combined steps, jax.device_count(), batch_size // jax.device_count(), *data dim]).

Ergo, they pmap across the second dimension and scan across the first to distribute the forward pass across GPUs, jax.lax.pmean the gradients, update the model, and iterate to the next step in the scan.

Since pmap has no flax.nnx equivalent, my approach was to shard the data across the batch dimension (my data has the shape [combined steps, batch_size, *data dim]) and replicate the model on each GPU to distribute the forward pass, although I'm not certain I'm going about it properly. See below for a minimal example with a simple model and random data/labels.
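For readers unfamiliar with the pmap-over-scan pattern described above, here is a minimal sketch in plain JAX. It is not the code from the referenced repo: the linear model, the plain SGD update, the learning rate, and all names (`loss_fn`, `train_step`, `multi_step`) are illustrative assumptions. The data uses the [steps, devices, per-device batch, dim] layout from the description and is transposed so that the device axis comes first before calling pmap.

```python
import functools

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
steps, per_dev, dim = 5, 4, 20  # hypothetical sizes

def loss_fn(params, x, y):
    # Toy linear model with a mean squared error loss (assumption).
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

def train_step(params, batch):
    x, y = batch
    grads = jax.grad(loss_fn)(params, x, y)
    # Average the per-device gradients explicitly across the pmapped axis.
    grads = jax.lax.pmean(grads, axis_name="devices")
    # Plain SGD stands in for a real optimizer.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, None

@functools.partial(jax.pmap, axis_name="devices")
def multi_step(params, xs, ys):
    # Per device, xs has shape [steps, per_dev, dim]; scan runs over steps.
    params, _ = jax.lax.scan(train_step, params, (xs, ys))
    return params

# Data laid out as [steps, devices, per-device batch, dim].
xs = jax.random.normal(jax.random.PRNGKey(1), (steps, n_dev, per_dev, dim))
ys = jnp.ones((steps, n_dev, per_dev, 1))

params = {"w": jnp.zeros((dim, 1)), "b": jnp.zeros((1,))}
# Replicate the parameters by stacking one copy per device.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

# pmap maps over the leading axis, so move the device axis to the front.
new_params = multi_step(
    replicated, jnp.swapaxes(xs, 0, 1), jnp.swapaxes(ys, 0, 1)
)
```

Because every device applies the same averaged gradient, all parameter replicas should stay identical after the scanned steps.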

from flax import nnx
from jax.sharding import NamedSharding, PartitionSpec

import jax
import optax

# Data is of shape [steps, batch, data dim]
data = jax.random.normal(jax.random.PRNGKey(1), (5, 100, 20))
label = jax.numpy.ones((5, 100, 1))
model = nnx.Sequential(nnx.Linear(20, 30, rngs=nnx.Rngs(0)),
                       nnx.Linear(30, 1, rngs=nnx.Rngs(1)))


#Unsharded data/model
jax.debug.visualize_array_sharding(data[0])

jax.debug.visualize_array_sharding(model.layers[0].kernel.value)


#shard data
mesh = jax.make_mesh((jax.device_count(),), ("batch", ))
data_sharding = NamedSharding(mesh, PartitionSpec(None,"batch"))

sharded_data = jax.device_put(data, data_sharding)
sharded_label = jax.device_put(label, data_sharding)

#shard model
def create_sharded_model(model):
  state = nnx.state(model) # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)     # Strip out the annotations from state.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)           # The model is sharded now!
  return model

with mesh:
    sharded_model = create_sharded_model(model)

tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

# At this point the data/labels should be sharded across GPUs,
# with the model replicated (???)
jax.debug.visualize_array_sharding(sharded_data[0])

jax.debug.visualize_array_sharding(sharded_model.layers[0].kernel.value)

loss_fn = lambda model, x, y: optax.l2_loss(model(x),y).mean()

def step_fn(batch_data, batch_label, state):
    grads = nnx.grad(loss_fn)(state.model,batch_data,batch_label)
    state.update(grads=grads)
    return state

#combine multiple train steps into a single scan carrying over the state
scanned_train = nnx.jit(nnx.scan(step_fn, in_axes=(0,0,nnx.Carry), out_axes=(nnx.Carry),
                       transform_metadata={nnx.PARTITION_NAME:"batch"}))

#returns states after 5 scanned+jitted train steps
new_state = scanned_train(sharded_data, sharded_label, state)
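As a point of comparison, the same idea (several train steps fused into one jitted scan, with data sharded along the batch axis and parameters replicated) can be sketched in plain JAX without NNX. Everything here is an illustrative assumption rather than the code above: the toy linear model, the SGD update, and the names `sgd_step` and `train_many`. The mesh is built with `jax.sharding.Mesh` directly, which should be equivalent to the `jax.make_mesh` call for a single axis.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("batch",))

def loss_fn(params, x, y):
    # Toy linear model with a mean squared error loss (assumption).
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

def sgd_step(params, batch):
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss  # carry the params, stack the per-step losses

@jax.jit
def train_many(params, xs, ys):
    # scan iterates over the leading (steps) axis of xs and ys.
    return jax.lax.scan(sgd_step, params, (xs, ys))

steps, batch = 5, 4 * jax.device_count()
xs = jax.random.normal(jax.random.PRNGKey(1), (steps, batch, 20))
ys = jnp.ones((steps, batch, 1))
params = {"w": jnp.zeros((20, 1)), "b": jnp.zeros((1,))}

# Shard axis 1 (batch) across devices; leave the steps axis unsharded.
xs_sh, ys_sh = jax.device_put((xs, ys), NamedSharding(mesh, P(None, "batch")))
# An empty PartitionSpec replicates the parameters on every device.
params_rep = jax.device_put(params, NamedSharding(mesh, P()))

new_params, losses = train_many(params_rep, xs_sh, ys_sh)

# Reference run on unsharded inputs, to compare results.
ref_params, ref_losses = train_many(params, xs, ys)
```

Up to floating-point tolerance, the sharded and unsharded runs should agree, since sharding changes where the computation happens, not what is computed.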

Specifically I'd like to know:

  1. Is my approach to sharding a model with no annotations the best practice for replicating a model across devices?
  2. Am I correct in thinking this sharding formulation will have the replicated models run the forward pass on the subset of batch observations located on their respective GPUs?
    • It's a bit hard to determine exactly what's happening under the hood here.
  3. With data sharded across GPUs but a replicated model, how exactly are gradients calculated/combined?
    • The repo I'm trying to mirror has the jax.lax.pmean explicitly stated in the loss function generated in losses.py (line 229), but it seems like some flax magic is happening behind the scenes that I'm confused about, because everything appears to work without a jax.lax.pmean equivalent in my example.

Teculos commented Dec 4, 2024

Just saw that I was mistaken in thinking that there wasn't an nnx.pmap... I was confused since it isn't included in the transforms documentation.

Regardless, I'd love to know if my approach roughly approximates the pmap/pmean strategy used in the mentioned repo.


cgarciae commented Dec 5, 2024

Hey @Teculos!

Is my approach to sharding a model with no annotations the best practice for replicating a model across devices?

Because you are just replicating I think you could define pspecs as:

pspecs = jax.sharding.PartitionSpec(None)
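To illustrate what a fully replicated sharding looks like, here is a minimal sketch in plain JAX. It uses a hypothetical dict of parameters instead of the NNX model, and `jax.device_put` instead of the `with_sharding_constraint` call from the question, so treat it as an illustration rather than the recommended NNX recipe. A `PartitionSpec` that names no mesh axes, whether empty or `PartitionSpec(None)`, leaves every array dimension unsharded.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("batch",))

# Hypothetical parameters standing in for the model state.
params = {"kernel": jnp.ones((20, 30)), "bias": jnp.zeros((30,))}

# No mesh axis is named, so each device holds a full copy of every array.
replicated = NamedSharding(mesh, PartitionSpec())
params = jax.device_put(params, replicated)

# Each leaf reports whether it is fully replicated across the mesh.
all_replicated = all(
    leaf.sharding.is_fully_replicated
    for leaf in jax.tree_util.tree_leaves(params)
)

# PartitionSpec(None) also names no mesh axis, so it replicates as well.
none_spec_replicated = NamedSharding(mesh, PartitionSpec(None)).is_fully_replicated
```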

Am I correct in thinking this sharding formulation will have the replicated models run the forward pass on the subset of batch observations located on their respective GPUs?

It does feel a bit magical, but your sharding visualization indeed shows they are replicated, and jit respects and propagates the shardings so things "just work".

with data sharded across GPUs but a replicated model how exactly are gradients calculated/combined?

The compiler will insert all the communication and do the equivalent of pmean for you. Check out JAX's Distributed arrays and automatic parallelization for more info.
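One way to convince yourself of this is to compare gradients computed on batch-sharded inputs against gradients computed on the same, unsharded inputs. The sketch below assumes a toy linear model and hypothetical names (`loss_fn`, `grad_fn`). It contains no explicit `pmean`, yet the two results should match up to floating-point tolerance, because the loss is a mean over the full logical batch and the compiler is responsible for any cross-device reduction.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("batch",))

def loss_fn(params, x, y):
    # The mean is over the whole logical batch, not a per-device shard.
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

grad_fn = jax.jit(jax.grad(loss_fn))

n = 8 * jax.device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (n, 20))
y = jnp.ones((n, 1))
params = {"w": jnp.full((20, 1), 0.1), "b": jnp.zeros((1,))}

# Shard the batch axis across devices and replicate the parameters.
x_sh = jax.device_put(x, NamedSharding(mesh, P("batch")))
y_sh = jax.device_put(y, NamedSharding(mesh, P("batch")))
params_rep = jax.device_put(params, NamedSharding(mesh, P()))

grads_sharded = grad_fn(params_rep, x_sh, y_sh)  # no explicit pmean anywhere
grads_single = grad_fn(params, x, y)             # reference on unsharded inputs
```

On a multi-device host, `jax.debug.visualize_array_sharding(x_sh)` should show the rows split across devices, while the gradients match the unsharded reference.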
