VLM Support via GPTQ Hooks and Data Pipelines #914
base: main
Conversation
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
    run_sequential(
Could we do "Layer Sequential" and "Subgraph Sequential"? "Sequential" being indicative of the data/error propagation, while using "layer" and "subgraph" to differentiate between data structures?
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
    run_sequential(
Hm, let me think of other descriptors. I think we just want each of the pipelines beyond the basic pipeline to be a little more verbose in its name.
) -> HFTracer:
    """
    Get a tracer specialized for the given model. The resulting tracer will not trace
    inside of sequential targets, ignored targets, or offloaded modules.
Trying to understand this comment. If the resulting tracer does not trace offloaded modules, how does this work for cases when we have parts of the model offloaded?
What does the tracer actually trace?
Tracing within sequential targets and ignored targets is unnecessary, and tracing within offloaded modules may result in meta tensors being added to the model graph.
When a module is "not traced", this means that the internals of the module are not traced, but the module still appears in the graph as a call_module node.
As a side note, even if a module contains untraceable code internally, if the internals of the module skip tracing via ignore, sequential_targets, or has_offloaded_params, then the model graph as a whole will still be traceable, just with less granularity.
@mgoin The tracer traces all of the objects and operations that are needed to perform a forward pass of the model.
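For readers less familiar with torch.fx, here is a minimal sketch (not the PR's exact code) of how a tracer can skip a module's internals while still recording the module as a call_module node; the class and attribute names are illustrative:

```python
from typing import Set

from torch.nn import Module
from transformers.utils.fx import HFTracer


class SkipInternalsTracer(HFTracer):
    """Illustrative tracer that treats selected modules as leaves."""

    def __init__(self, skip_trace_modules: Set[Module], **kwargs):
        super().__init__(**kwargs)
        self.skip_trace_modules = skip_trace_modules  # e.g. sequential targets, ignored, offloaded

    def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
        # Leaf modules are recorded as a single call_module node rather than
        # traced into, which keeps untraceable internals and meta tensors from
        # offloaded modules out of the graph.
        if module in self.skip_trace_modules:
            return True
        return super().is_leaf_module(module, module_qualified_name)
```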
skip_trace_modules = sequential_targets | offloaded_modules | ignore


class SequentialTracer(HFTracer):
    def create_arg(self, a: Any) -> Argument:
Can you explain why `create_arg` is needed?
This special extension allows models which depend on config values to be traced.
I overload this function to insert my own definition for creating an argument which is of type `PretrainedConfig`. Many models use values from the config during execution, but `torch.fx` is only capable of "baking" a limited set of class types into the model graph.
This code says: whenever the graph would try to reference an instance of a `PretrainedConfig` (for example, to get an attribute like `config.max_sequence_length`), instead just create a `PretrainedConfig` on the fly and initialize it with all of the args from the original config.
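A minimal sketch of what that override might look like (the exact node construction here is an assumption, not copied from the PR):

```python
from typing import Any

from torch.fx.node import Argument
from transformers import PretrainedConfig
from transformers.utils.fx import HFTracer


class SequentialTracer(HFTracer):
    def create_arg(self, a: Any) -> Argument:
        # torch.fx can only bake a limited set of constant types into the graph;
        # for config objects, emit a call that rebuilds an equivalent
        # PretrainedConfig from the original config's values instead.
        if isinstance(a, PretrainedConfig):
            kwargs = {key: self.create_arg(value) for key, value in a.to_dict().items()}
            return self.create_node("call_function", a.__class__, (), kwargs)

        return super().create_arg(a)
```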
There are a few things in this file that we should consider upstreaming to HF; this might be one of them.
    :param subgraphs: list of subgraphs with empty `consumed_names` attributes
    """
    # populate consumed_names according to when inputs are last used
What is meant by "last used"?
Some input names are used by multiple subgraphs in the model (for example, the cross attention output is used by every text decoder layer in mllama), while other input names are only used once (for example, the output of a text decoder layer is only used as the input to the next text decoder layer).
All subgraph outputs are stored in the IntermediatesCache. However, we only want to keep outputs which will later be used as inputs to later subgraphs, and vacate outputs which are never used again (this is really only to reduce CPU memory usage). Therefore, for each name, we need to find the index of the subgraph which is the last user of that name. After that subgraph runs, we can vacate that output from the cache.
Side note: outputs which do not lead to inputs are automatically pruned by the instantiation of GraphModule, and this is validated by check_assumption.
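A sketch of the "last user" bookkeeping described above (assuming a Subgraph with input_names and consumed_names attributes; the function and attribute names are illustrative):

```python
from typing import List


def populate_consumed_names(subgraphs: List["Subgraph"]) -> None:
    # Find the index of the last subgraph that reads each input name, then
    # mark that name as consumed by that subgraph so the cache can drop the
    # value once that subgraph has run.
    last_user = {}
    for index, subgraph in enumerate(subgraphs):
        for name in subgraph.input_names:
            last_user[name] = index

    for name, index in last_user.items():
        subgraphs[index].consumed_names.add(name)
```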
I think these Traceable model definitions have very opaque changes compared to the reference model definitions. This architecture seems like an intensive blocker to add support for a new model, as it requires a lot of knowledge of tracing limitations. However I understand the need - I'll look in more detail tomorrow
# bug in trace throws an error for variadic
# args and kwargs in function signature
Is this just explaining that you couldn't pass `*args, **kwargs` here?
This is explaining why I need to write my own `populate_concrete_args` function rather than rely on the one provided by transformers.
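For illustration, a minimal sketch of such a helper (an assumption about the approach, not the PR's exact code): bind defaults for any forward-signature parameters that are not traced as inputs, skipping variadic parameters entirely since tracing them triggers the bug mentioned above.

```python
import inspect
from typing import Any, Callable, Dict, Set


def populate_concrete_args(forward: Callable, input_names: Set[str]) -> Dict[str, Any]:
    # Parameters not provided as graph inputs are fixed to their defaults;
    # *args and **kwargs are skipped rather than passed to the tracer.
    concrete_args = {}
    for name, param in inspect.signature(forward).parameters.items():
        if name in input_names:
            continue
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        concrete_args[name] = param.default
    return concrete_args
```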
# TRACING: Must use MistralModel
class MistralForCausalLM(MistralForCausalLM):
    def __init__(self, config):
        super(MistralPreTrainedModel, self).__init__(config)
        # TRACING: Must use MistralModel
        self.model = MistralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
This looks to be the same as the modeling definition https://github.com/huggingface/transformers/blob/12ba96aa3cb3e4ed2a3ffb77b59f53f8ce9ac1fa/src/transformers/models/mistral/modeling_mistral.py#L752-L763
What is the purpose of this comment and code?
That comment should say:
# TRACING: Must use MistralModel with wrapped _prepare_4d_causal_attention_mask_with_cache_position function
Note that we define a version of the MistralModel which wraps the problematic function, and it is this definition that is used by MistralForCausalLM.
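For context, a minimal sketch of the wrapping technique in general (a generic example, not the actual Mistral helper): torch.fx.wrap leaves the named function as a single call_function node instead of tracing its data-dependent internals.

```python
import torch
import torch.fx


def untraceable_helper(x: torch.Tensor) -> torch.Tensor:
    # data-dependent control flow like this cannot be traced symbolically
    if x.sum() > 0:
        return x
    return -x


# Must be called at module level: the traced graph records a single
# call_function node for untraceable_helper rather than tracing into it.
torch.fx.wrap("untraceable_helper")
```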
@mgoin I think the Tracing Guide will clarify how and why to make changes to your model to make it traceable and why tracing is the best and least invasive solution currently available. Also note that
Purpose
Llama 3.2-Vision Graphviz
Related Issues
Prerequisites
Changes
VLM Support
- Added VLM examples under `examples/multimodal_vision`
- Added `custom_offload_device_map` to support models which are not `XForCausalLM`
- Added VLM data collators in `src/llmcompressor/transformers/utils/data_collator.py`

GPTQModifier
- `GPTQModifier` operations previously implemented in `gptq_wrapper.py` are now implemented in `gptq_quantize.py`
- Added an `offload_hessians` parameter to `GPTQModifier`
- `GPTQModifier` falls back from the `sequential` pipeline to the `layer_sequential` pipeline, and finally to `offload_hessians`
- Changed a `ValueError` to a `_LinAlgError` so it can be ignored by the gptq pipeline fallback mechanism

Data Pipelines
- `layer_sequential` pipeline: replicates `LayerCompressor` as a straight-forward data pipeline and uses `IntermediatesCache` to handle activation offloading
- `sequential` pipeline: uses `torch.fx` to trace the graph in order to determine where sequential targets (layers) exist in the graph and what their inputs and outputs are; each subgraph (`Subgraph`) is compiled as an executable python function with the proper inputs and outputs and uses `IntermediatesCache` to handle activation offloading (see the sketch after this list)
- `IntermediatesCache` automagically handles the offloading and onloading of activations from batches and supports offloading `Tuple`s and dataclasses such as `BaseModelOutputWithPast`; tested in `tests/llmcompressor/pipelines/test_cache.py`
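As a rough illustration of how the sequential pipeline described above executes subgraphs (the cache and subgraph method names here are assumptions, not the exact API):

```python
def run_pipeline(model, subgraphs, cache, num_batches):
    # Hypothetical sketch: run each compiled subgraph over every calibration
    # batch, offloading outputs into the cache and freeing values once the
    # last subgraph that consumes them has run.
    for subgraph in subgraphs:
        for batch_index in range(num_batches):
            inputs = cache.fetch(batch_index, subgraph.input_names)  # onload activations
            outputs = subgraph.forward(model, **inputs)              # compiled forward fn
            cache.update(batch_index, outputs)                       # offload outputs
        cache.delete(subgraph.consumed_names)                        # drop values no longer needed
```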
Tracing
- `# vllm-project: no copyright` was added in similar style to text_generation.py

Future Work / Follow ups
Winogrande Evaluations
lm_eval --model vllm --model_args pretrained="path/to/model",dtype=auto,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True --tasks winogrande --num_fewshot 5 --batch_size 32
lm_eval --model vllm --model_args pretrained="path/to/model",dtype=bfloat16,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True,max_num_seqs=1 --tasks winogrande --num_fewshot 5 --batch_size 1
MMMU Evaluations
Credit to @shubhra
Testing