- Changes
- Upgrade Jax from 0.4.33 to 0.4.34.
- Updates the `input_base.Input` API to support configuring input partitioning behavior.
- The config fields `batch_axis_names` and `seq_axis_names` in `causal_lm.Model` are now deprecated. Please use `input_base.Input.input_partitioner` instead.
- Updates the `causal_lm.Model` API to support configuring metrics without subclassing. This requires a golden config change.
- Changes
- Upgrade Jax from 0.4.30 to 0.4.33.
- Changes
- Upgrade Python to 3.10.
- In GPU flash attention, fall back to the Triton backend when qkv is in fp32 or a bias is used.
- Changes
- Upgrade Jax from 0.4.28 to 0.4.30.
- Changes
- Add changelog.