PEFT model doesn't update params when having changed LoRA config #2295
Comments
Your observation is correct. Of course, re-defining a completely new config works, but what you tried does not:

```python
lora_config = LoraConfig(..., target_modules=["foo"])
model = get_peft_model(base_model, lora_config)
lora_config.target_modules = ["bar"]  # <= you expect this to trigger re-initialization of peft model
```

Although it is technically possible to turn each parameter into a …
Sorry, yeah, the …

```python
model_peft = get_peft_model(model, lora_config)
model_peft.print_trainable_parameters()
```

In my opinion, if you change the …
I think what @d-kleine means is that a model is not re-initialized when being presented with a different config, which may be weird when you're in a notebook environment and experimenting with LoRA config values. To make the initial example a bit more concise:

```python
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("gpt2", num_labels=5)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    target_modules=["c_proj", "c_fc"],
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    target_modules=["h.0.mlp.c_fc"],  # changed to a specific layer
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
```

Expected:

Actual:

In the above example the active config is used:

```python
>>> peft_model.active_peft_config
LoraConfig(task_type=<TaskType.TOKEN_CLS: 'TOKEN_CLS'>, peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='gpt2', revision=None, inference_mode=False, r=8, target_modules={'h.0.mlp.c_fc'}, exclude_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=['classifier', 'score'], init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, eva_config=None, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))
```

But the parameter count is not updated.
Exactly, that is what I was trying to say.
Ah okay, I was wondering if I misunderstood something. Here the issue is that PEFT mutates the base model when calling `get_peft_model`. The source of the surprise here is probably that the base model was mutated by …
When modifying a model with `get_peft_model` that was already modified in the same way, even specifying a different config may not change the trainable parameter count, e.g. when specifying target modules that are only a subset of the previous target modules. With this patch a warning will be issued with a hint to `.unload()` when calling `get_peft_model` on an already modified model.
As an idea, what do you think about adding an …
From my perspective, this would be overkill, as users can create a second instance with a single line of code, so there is not much gained by adding the argument. IMO the core issue would still remain, i.e. that users would need to be aware of the fact that the base model will be modified.
I agree with you, that sounds reasonable. In addition to adding a warning message, I think it would be great to update the documentation to clearly explain that the base model is modified in-place. This would help users better understand the function's behavior. Also, at least from my point of view, the function name …
👍 Would you be interested in updating the docs? I think an info box could be helpful, also adding how to preserve a copy of the original model.
Yes, fully agree. Unfortunately, at this point, renaming the function would create a huge disruption and is thus not an option.
System Info
I have noticed that when updating the `target_modules` setting in the LoRA config, the PEFT model params remain unchanged. This might affect other PEFT settings too. My assumption is that `get_peft_model()` does not re-instantiate/update its settings once it has been initialized before.

System: Windows 11
Python: 3.11
peft: 0.14.0
Who can help?
@BenjaminBossan @sayakpaul
Information

Tasks

An officially supported task in the `examples` folder

Reproduction
For reproduction in a Jupyter Notebook:
This outputs
But when changing the above code without restarting the kernel to:
and retrieving the trainable params again:
it outputs again
but after the update it should be
Expected behavior
When having updated `lora_config`, `get_peft_model()` should retrieve the current config.