Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add arctic model support by adding w2 to all_reduce (#6856)
As title says. Default behavior of arctic model produces shape issues with AutoTP due to the MLP layer performing `w2 * act(w1*w3)`. However, method provided to fix Mixtral-7x8b in #5257 does not work since the MLP for Arctic is also used within a ModuleList for the MoE. This results in MLP weights hiding behind individual experts as layers `#.w#`, which is not caught by the fix in #5257. This adds the check directly within replace, where it can check for actual layer names for the `w2` key in the model to patch with `all_reduce`. --------- Signed-off-by: Daniel Huang <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]>
- Loading branch information