`model` is a model returned by `AutoModelForCausalLM.from_pretrained()` or an instance of `torch.nn.Transformer`.
Make it generalizable
APIs
`model = AutoModel.from_pretrained(...)`
`ParallelMapping(model).is_mlp(name, module)`
Write a function such that (see the sketch after this list):
- `ParallelMapping(model).is_column_parallel(name, module)` returns True if the module is the first linear layer in an MLP block, a query, key, or value linear (or a fused QKV linear) of an attention layer, or an input embedding; otherwise it returns False.
- `ParallelMapping(model).is_row_parallel(name, module)` returns True if the module is the second linear layer in an MLP block or the output projection of an attention layer.
- `ParallelMapping(model).is_lm_head(name, module)` returns True if the module is the language model head.
- `ParallelMapping(model).is_text_embedding(name, module)` returns True if the module is the text embedding module.
- `ParallelMapping(model).is_mlp(name, module)` returns True if the module is an MLP layer.
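A minimal sketch of how these predicates could work, matching on module-name suffixes. This is not pipegoose's actual implementation: the suffix lists below are assumptions based on BLOOM's Hugging Face module names (`mlp.dense_h_to_4h`, `self_attention.query_key_value`, etc.), and a general solution would need entries (or structural inference) for other architectures:

```python
import torch.nn as nn

# NOTE: hypothetical suffix lists, assumed from BLOOM's module naming;
# other architectures would need their own entries.
COLUMN_PARALLEL_SUFFIXES = ["mlp.dense_h_to_4h", "self_attention.query_key_value"]
ROW_PARALLEL_SUFFIXES = ["mlp.dense_4h_to_h", "self_attention.dense"]
LM_HEAD_SUFFIXES = ["lm_head"]
TEXT_EMBEDDING_SUFFIXES = ["word_embeddings"]
MLP_SUFFIXES = ["mlp.dense_h_to_4h", "mlp.dense_4h_to_h"]


class ParallelMapping:
    def __init__(self, model: nn.Module):
        self.model = model

    @staticmethod
    def _matches(name: str, suffixes: list) -> bool:
        # `name` is a dotted path from model.named_modules(),
        # e.g. "transformer.h.0.mlp.dense_h_to_4h".
        return any(name.endswith(suffix) for suffix in suffixes)

    def is_column_parallel(self, name: str, module: nn.Module) -> bool:
        # First MLP linear, (fused) QKV linear, or input embedding.
        return self._matches(name, COLUMN_PARALLEL_SUFFIXES + TEXT_EMBEDDING_SUFFIXES)

    def is_row_parallel(self, name: str, module: nn.Module) -> bool:
        # Second MLP linear or attention output projection.
        return self._matches(name, ROW_PARALLEL_SUFFIXES)

    def is_lm_head(self, name: str, module: nn.Module) -> bool:
        return self._matches(name, LM_HEAD_SUFFIXES)

    def is_text_embedding(self, name: str, module: nn.Module) -> bool:
        return self._matches(name, TEXT_EMBEDDING_SUFFIXES)

    def is_mlp(self, name: str, module: nn.Module) -> bool:
        # Assumes "MLP layer" means the MLP's linear sublayers.
        return self._matches(name, MLP_SUFFIXES)
```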
I'm taking a stab at this issue. I'll first come up with a solution that works for at least bloom-560m, which is already mapped in ParallelMapping, so we can check whether the automatic mapping works. Then we can work on making it more generalizable.
Notes
- The existing mapping lives in `pipegoose.nn.parallel_mapping.ParallelMapping`.
- `module` is an instance yielded by `model.named_modules()`.
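For example, using the sketch above against bloom-560m (the `bigscience/bloom-560m` checkpoint), the predicates would be driven by `model.named_modules()` like this:

```python
from transformers import AutoModelForCausalLM

# Hypothetical usage, assuming the ParallelMapping sketch above.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
mapping = ParallelMapping(model)

for name, module in model.named_modules():
    if mapping.is_column_parallel(name, module):
        print(f"column parallel: {name}")
    elif mapping.is_row_parallel(name, module):
        print(f"row parallel:    {name}")
```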