PEFT model doesn't update params when having changed LoRA config #2295

Open
d-kleine opened this issue Dec 23, 2024 · 9 comments

@d-kleine
Contributor

d-kleine commented Dec 23, 2024

System Info

I have noticed that when updating the target_modules setting in the LoRA config, the PEFT model params remain unchanged. This might affect other PEFT settings too.

My assumption is that get_peft_model() does not re-instantiate/update its settings once the model has already been initialized.

System: Windows 11
Python: 3.11
peft: 0.14.0

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

For reproduction in a Jupyter Notebook:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

label_list = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    pad_token_id=tokenizer.eos_token_id,
    torch_dtype=torch.bfloat16,
    device_map="auto", 
    num_labels=len(label_list)
)

# Inspect the module names to pick suitable LoRA target_modules
for name, module in model.named_modules():
    print(name)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,             
    lora_alpha=32, 
    target_modules=["q_proj", "v_proj"],  
    lora_dropout=0.1
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

This outputs

trainable params: 1,722,377 || all params: 1,237,555,218 || trainable%: 0.1392

But when changing the above code without restarting the kernel to:

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,             
    lora_alpha=32, 
    target_modules=["layers.0.self_attn.q_proj", "layers.0.self_attn.v_proj"],  # restricted to layer 0 only
    lora_dropout=0.1
)

and retrieving the trainable params again:

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

it again outputs

trainable params: 1,722,377 || all params: 1,237,555,218 || trainable%: 0.1392

but after the update it should be

trainable params: 124,937 || all params: 1,235,957,778 || trainable%: 0.0101

Expected behavior

Once lora_config has been updated, get_peft_model() should pick up the current config.

@BenjaminBossan
Member

Your observation is correct. Of course, re-defining a completely new lora_config cannot influence the model, as this just defines a new, unrelated variable that just happens to have the same name. Probably what you mean is that you would like to change the attribute on the existing lora_config:

lora_config = LoraConfig(..., target_modules=["foo"])
model = get_peft_model(base_model, lora_config)
lora_config.target_modules = ["bar"]  # <= you expect this to trigger re-initialization of peft model

Although it is technically possible to turn each parameter into a @property and define a setter that re-initializes the model each time the config is changed, I'd say this is not worth the effort. I also lean towards the current behavior being the more intuitive one, but that's hard to say.
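
For illustration only, here is a minimal sketch (not PEFT code; the name NotifyingConfig and the register callback are hypothetical) of what such a property-with-setter pattern would look like:

class NotifyingConfig:
    def __init__(self, target_modules):
        self._target_modules = list(target_modules)
        self._listeners = []

    def register(self, callback):
        # A model could register a callback that rebuilds its adapter layers.
        self._listeners.append(callback)

    @property
    def target_modules(self):
        return self._target_modules

    @target_modules.setter
    def target_modules(self, value):
        self._target_modules = list(value)
        for callback in self._listeners:
            callback(self)  # e.g. re-initialize the LoRA layers for the new targets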

@d-kleine
Contributor Author

d-kleine commented Dec 27, 2024

Your observation is correct. Of course, re-defining a completely new lora_config cannot influence the model, as this just defines a new, unrelated variable that just happens to have the same name.

Sorry, yeah, the model assignment was a bad example. What I mean is if you assign it like this:

model_peft = get_peft_model(model, lora_config)
model_peft.print_trainable_parameters()

In my opinion, if you change the lora_config, these changes should be picked up by get_peft_model(). Currently, there is not even a warning that the config has changed while the PEFT model doesn't reflect these changes. I don't find this very intuitive.

@githubnemo
Collaborator

I think what @d-kleine means is that a model is not re-initialized when being presented with a different config, which may be weird when you're in a notebook environment and experimenting with LoRA config values.

To make the initial example a bit more concise:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("gpt2", num_labels=5)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    target_modules=["c_proj", "c_fc"],
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    target_modules=["h.0.mlp.c_fc", "h.0.mlp.c_fc"],  # restricted to a single layer's module
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

Expected:

trainable params: 888,581 || all params: 125,332,234 || trainable%: 0.7090
trainable params: 34,565 || all params: 124,478,218 || trainable%: 0.0278

Actual:

trainable params: 888,581 || all params: 125,332,234 || trainable%: 0.7090
trainable params: 888,581 || all params: 125,332,234 || trainable%: 0.7090

In the above example, the new config is reflected in the active config:

>>> peft_model.active_peft_config
LoraConfig(task_type=<TaskType.TOKEN_CLS: 'TOKEN_CLS'>, peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='gpt2', revision=None, inference_mode=False, r=8, target_modules={'h.0.mlp.c_fc'}, exclude_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=['classifier', 'score'], init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, eva_config=None, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))

But the parameter count is not updated.

@d-kleine
Contributor Author

d-kleine commented Jan 5, 2025

I think what @d-kleine means is that a model is not re-initialized when being presented with a different config, which may be weird when you're in a notebook environment and experimenting with LoRA config values.

Exactly, that is what I was trying to say

@BenjaminBossan
Member

Ah okay, I was wondering if I misunderstood something.

Here the issue is that PEFT mutates the base model when calling get_peft_model, which results in all "c_proj" and "c_fc" modules being transformed into LoRA layers. Therefore, when calling get_peft_model again, this time with target_modules=["h.0.mlp.c_fc", "h.0.mlp.c_fc"] (btw. that's the same name twice), the existing LoRA layers are still there, and since the new targets are a subset of the existing targets, nothing changes when it comes to the parameter count.

The source of the surprise here is probably that the base model was mutated by get_peft_model. This has confused a couple of users in the past but it's unfortunately not something we can really avoid (we could create a copy of the model but that's very costly and not what most users would want). I'm not sure if it would have helped in this situation, but we should add a line to the get_peft_model docstring to explain that the base model will be modified in-place.
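
To make the in-place mutation visible, here is a minimal sketch (reusing the gpt2 token-classification example from above; the helper name has_lora_layers is made up) that inspects the original model variable:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("gpt2", num_labels=5)

def has_lora_layers(m):
    # LoRA injection adds submodules whose names contain "lora_" (lora_A/lora_B).
    return any("lora_" in name for name, _ in m.named_modules())

print(has_lora_layers(model))  # False: plain base model

lora_config = LoraConfig(task_type=TaskType.TOKEN_CLS, target_modules=["c_proj", "c_fc"])
peft_model = get_peft_model(model, lora_config)

# The original variable now contains the injected LoRA layers as well, which is
# why a second get_peft_model call on `model` still sees them.
print(has_lora_layers(model))  # True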

githubnemo pushed a commit to githubnemo/peft that referenced this issue Jan 6, 2025
When modifying a model with `get_peft_model` that was already modified
in the same way, even specifying a different config may not change
the trainable parameter count, e.g. when specifying target modules that
are only a subset of the previous target modules.

With this patch a warning will be issued with a hint to `.unload()`
when calling `get_peft_model` on an already modified model.
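
For illustration, a minimal sketch of the workflow that hint points to, reusing the gpt2 example from above: .unload() removes the injected LoRA layers (without merging) and returns the base model, which can then be wrapped again with the new config:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("gpt2", num_labels=5)

peft_model = get_peft_model(
    model, LoraConfig(task_type=TaskType.TOKEN_CLS, target_modules=["c_proj", "c_fc"])
)
peft_model.print_trainable_parameters()

# Strip the previously injected LoRA layers before applying a different config.
base_model = peft_model.unload()

peft_model = get_peft_model(
    base_model, LoraConfig(task_type=TaskType.TOKEN_CLS, target_modules=["h.0.mlp.c_fc"])
)
peft_model.print_trainable_parameters()  # now reflects the narrower target set
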
@d-kleine
Contributor Author

d-kleine commented Jan 7, 2025

The source of the surprise here is probably that the base model was mutated by get_peft_model. This has confused a couple of users in the past but it's unfortunately not something we can really avoid (we could create a copy of the model but that's very costly and not what most users would want). I'm not sure if it would have helped in this situation, but we should add a line to the get_peft_model docstring to explain that the base model will be modified in-place.

As an idea, what do you think about adding an inplace parameter that defaults to True? This would make the behavior explicit and give users control over whether they want to modify the original model or create a copy.

@BenjaminBossan
Member

As an idea, what do you think about adding an inplace parameter that defaults to True? This would make the behavior explicit and give users control over whether they want to modify the original model or create a copy.

From my perspective, this would be overkill, as users can create a second instance with a single line of code, so there is not much gained by adding the argument. IMO the core issue would still remain, i.e. that users would need to be aware of the fact that the base model will be modified.
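
For illustration, one possible reading of that "single line" (a sketch, assuming enough memory for a full copy; reloading the base model with from_pretrained would work just as well): keep an untouched copy before get_peft_model mutates the original.

import copy

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForTokenClassification

base_model = AutoModelForTokenClassification.from_pretrained("gpt2", num_labels=5)

# The "single line": keep a pristine copy before PEFT modifies base_model in-place.
pristine_model = copy.deepcopy(base_model)

peft_model = get_peft_model(
    base_model, LoraConfig(task_type=TaskType.TOKEN_CLS, target_modules=["c_proj", "c_fc"])
)
peft_model.print_trainable_parameters()

# Experiment with a different config on the untouched copy instead.
peft_model_2 = get_peft_model(
    pristine_model, LoraConfig(task_type=TaskType.TOKEN_CLS, target_modules=["h.0.mlp.c_fc"])
)
peft_model_2.print_trainable_parameters()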

@d-kleine
Contributor Author

d-kleine commented Jan 7, 2025

From my perspective, this would be overkill, as users can create a second instance with a single line of code, so there is not much gained by adding the argument. IMO the core issue would still remain, i.e. that users would need to be aware of the fact that the base model will be modified.

I agree with you, sounds reasonable. In addition to adding a warning message, I think it would be great to update the documentation to clearly explain that the base model is modified in-place. This would help users better understand the function's behavior.

Also, at least from my point of view, the function name get_peft_model() doesn't quite capture/reflect what's actually happening under the hood. A name like convert_to_peft_model() (or something similar and shorter) would be more intuitive since it explicitly indicates that the function transforms the base model by adding PEFT-specific layers. While renaming would require significant changes across the codebase, it would make the API more transparent and user-friendly. The current name implies it simply returns a new model instance, which could be misleading since it actually modifies the original model's architecture.

@BenjaminBossan
Member

I agree with you, sounds reasonable. In addition to adding a warning message, I think it would be great to update the documentation to clearly explain that the base model is modified in-place. This would help users better understand the function's behavior.

👍 Would you be interested in updating the docs? I think an info box could be helpful, also adding how to preserve a copy of the original model.

the function name get_peft_model() doesn't quite capture/reflect what's actually happening under the hood

Yes, fully agree. Unfortunately, at this point, renaming the function would create a huge disruption and is thus not an option.

githubnemo reopened this Jan 7, 2025