I'm trying to roughly replicate the behaviour found in this repo, where they pmap a scan transform over the train step to combine multiple train steps into one function call (see run_lib.py line 124). Since this is a pre-flax.nnx implementation, they replicate the model and pmap over the model replicas and the data (structured as [combined steps, jax.device_count(), batch_size // jax.device_count(), *data dims]).
That is, they pmap across the second dimension and scan across the first to distribute the forward pass across GPUs, jax.lax.pmean the gradients, update the model, and move on to the next step in the scan.
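Roughly, the pattern I'm describing looks something like this (an illustrative sketch with toy parameters and made-up names like train_step and n_steps, not the repo's actual code; I've put the device axis first here for simplicity):

import jax
import jax.numpy as jnp

# Toy params standing in for the replicated model state.
params = {"w": jnp.zeros((20, 1)), "b": jnp.zeros((1,))}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, batch):
    x, y = batch
    grads = jax.grad(loss_fn)(params, x, y)
    # Explicitly average gradients across devices, like the pmean in losses.py.
    grads = jax.lax.pmean(grads, axis_name="devices")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, None

def n_steps(params, batches):
    # Scan over the "combined steps" axis, carrying the params between steps.
    params, _ = jax.lax.scan(train_step, params, batches)
    return params

# Replicate the params and pmap over the device axis of the data.
replicated_params = jax.device_put_replicated(params, jax.devices())
per_device = 100 // jax.device_count()
xs = jnp.ones((jax.device_count(), 5, per_device, 20))
ys = jnp.ones((jax.device_count(), 5, per_device, 1))
new_params = jax.pmap(n_steps, axis_name="devices")(replicated_params, (xs, ys))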
Since pmap has no flax.nnx equivalent, my approach was to shard the data across the batch dimension (my data has shape [combined steps, batch_size, *data dims]) and replicate the model on each GPU to distribute the forward pass, although I'm not certain I'm going about it properly. See below for a minimal example with a simple model and random data/labels.
from flax import nnx
from jax.sharding import NamedSharding, PartitionSpec
import jax
import optax

# Data is of shape [steps, batch, data dim]
data = jax.random.normal(jax.random.PRNGKey(1), (5, 100, 20))
label = jax.numpy.ones((5, 100, 1))

model = nnx.Sequential(nnx.Linear(20, 30, rngs=nnx.Rngs(0)),
                       nnx.Linear(30, 1, rngs=nnx.Rngs(1)))

# Unsharded data/model
jax.debug.visualize_array_sharding(data[0])

# Shard data
mesh = jax.make_mesh((jax.device_count(),), ("batch",))
data_sharding = NamedSharding(mesh, PartitionSpec(None, "batch"))
sharded_data = jax.device_put(data, data_sharding)
sharded_label = jax.device_put(label, data_sharding)

# Shard model
def create_sharded_model(model):
    state = nnx.state(model)                # The model's state, a pure pytree.
    pspecs = nnx.get_partition_spec(state)  # Strip out the annotations from state.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)        # The model is sharded now!
    return model

with mesh:
    sharded_model = create_sharded_model(model)

tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

# Sharded data/label: at this point the data should be sharded across GPUs
# with the model replicated (???)
jax.debug.visualize_array_sharding(sharded_data[0])

loss_fn = lambda model, x, y: optax.l2_loss(model(x), y).mean()

def step_fn(batch_data, batch_label, state):
    grads = nnx.grad(loss_fn)(state.model, batch_data, batch_label)
    state.update(grads=grads)
    return state

# Combine multiple train steps into a single scan, carrying over the state
scanned_train = nnx.jit(nnx.scan(step_fn, in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry,
                                 transform_metadata={nnx.PARTITION_NAME: "batch"}))

# Returns the state after 5 scanned + jitted train steps
new_state = scanned_train(sharded_data, sharded_label, state)
Specifically I'd like to know:
Is my approach to sharding a model with no annotations the best practice for replicating a model across devices?
Am I correct in thinking this sharding formulation will have the replicated models run the forward pass on the subset of batch observations located on their respective GPUs?
It's a bit hard to determine exactly what's happening under the hood here.
With data sharded across GPUs but a replicated model, how exactly are gradients calculated/combined?
The repo I'm trying to mirror has jax.lax.pmean stated explicitly in the loss function generated in losses.py (line 229), but it seems like some flax magic is happening behind the scenes that I'm confused about, because everything appears to work without a jax.lax.pmean equivalent in my example.
Is my approach to sharding a model with no annotations the best practice for replicating a model across devices?
Because you are just replicating, I think you could define pspecs as:
pspecs = jax.sharding.PartitionSpec(None)
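For example, one way to do the replication explicitly (just a sketch; replicate_model is a made-up helper, not an existing flax API):

from flax import nnx
from jax.sharding import NamedSharding, PartitionSpec
import jax

def replicate_model(model, mesh):
    state = nnx.state(model)
    # An empty PartitionSpec means every array is fully replicated on the mesh.
    replicated_sharding = NamedSharding(mesh, PartitionSpec())
    replicated_state = jax.tree_util.tree_map(
        lambda x: jax.device_put(x, replicated_sharding), state)
    nnx.update(model, replicated_state)
    return model

# mesh = jax.make_mesh((jax.device_count(),), ("batch",))
# sharded_model = replicate_model(model, mesh)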
Am I correct in thinking this sharding formulation will have the replicated models run the forward pass on the subset of batch observations located on their respective GPUs?
It does feel a bit magical, but your sharding visualization indeed shows they are replicated, and jit respects and propagates the shardings so things "just work".
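For example, you can check what jit actually decided by inspecting the shardings it produces (a sketch that assumes the sharded_model and sharded_data from your example; forward is just an illustrative name):

import jax
from flax import nnx

@nnx.jit
def forward(model, x):
    return model(x)

out = forward(sharded_model, sharded_data[0])
# The output inherits the batch sharding of the input...
jax.debug.visualize_array_sharding(out)
# ...while the parameters stay fully replicated.
params_state = nnx.state(sharded_model, nnx.Param)
print(jax.tree_util.tree_leaves(params_state)[0].sharding)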
With data sharded across GPUs but a replicated model, how exactly are gradients calculated/combined?
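One way to see it empirically (a sketch, not an authoritative answer, reusing sharded_model / sharded_data / sharded_label from the example above; grad_step is an illustrative name): compute the gradients under jit and inspect their sharding. With batch-sharded data and replicated parameters, the compiler inserts the cross-device reduction itself, so the gradients come out replicated without an explicit jax.lax.pmean.

import jax
import optax
from flax import nnx

loss_fn = lambda model, x, y: optax.l2_loss(model(x), y).mean()

@nnx.jit
def grad_step(model, x, y):
    return nnx.grad(loss_fn)(model, x, y)

grads = grad_step(sharded_model, sharded_data[0], sharded_label[0])
for g in jax.tree_util.tree_leaves(grads):
    # Each gradient leaf should show up as replicated: the per-device partial
    # contributions have already been combined by the compiler (the implicit "pmean").
    print(g.sharding)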