
add XLM-RoBERTa in paddlenlp #9720

Open

wants to merge 4 commits into base: develop

Conversation

jie-z-0607 (Contributor)

PR types

New features

PR changes

Models

Description

Add support for the XLM-RoBERTa model in PaddleNLP.

Examples:

```python
>>> from ppdiffusers.transformers import XLMRobertaConfig, XLMRobertaModel
```

Member:

Please update the documentation here.
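Presumably the docstring example should import from paddlenlp.transformers rather than ppdiffusers.transformers; a minimal sketch of the corrected example, assuming the new classes are exported from the paddlenlp.transformers namespace:

```python
>>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel
>>> config = XLMRobertaConfig()      # default XLM-RoBERTa configuration
>>> model = XLMRobertaModel(config)  # randomly initialized weights
```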

classifier_dropout=None,
**kwargs,
):
kwargs["return_dict"] = kwargs.pop("return_dict", True)
Member:

Here I followed the same logic as transformers and defaulted return_dict to True, but almost all PaddleNLP models default it to False; this needs a decision.

Collaborator:

Change it to False.
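A minimal sketch of the agreed change in the config constructor, following the PaddleNLP convention of return_dict defaulting to False; the surrounding fields and the PretrainedConfig import path are assumptions, not the PR's exact code:

```python
from paddlenlp.transformers.configuration_utils import PretrainedConfig


class XLMRobertaConfig(PretrainedConfig):
    model_type = "xlm-roberta"

    def __init__(
        self,
        vocab_size=250002,
        hidden_size=768,
        classifier_dropout=None,
        **kwargs,
    ):
        # PaddleNLP models default return_dict to False (transformers defaults it to True).
        kwargs["return_dict"] = kwargs.pop("return_dict", False)
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.classifier_dropout = classifier_dropout
```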

Comment on lines 484 to 485
if self.gradient_checkpointing and not hidden_states.stop_gradient:
layer_outputs = self._gradient_checkpointing_func(
Member:

gradient_checkpointing -> recompute; please rename it following the PaddleNLP convention.
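For reference, a self-contained toy sketch (not the PR's actual encoder) of the PaddleNLP-style gating, assuming paddle.distributed.fleet.utils.recompute and an enable_recompute flag:

```python
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute


class ToyEncoder(nn.Layer):
    """Illustrative only: PaddleNLP-style recompute gating in an encoder loop."""

    def __init__(self, num_layers=2, hidden_size=8):
        super().__init__()
        self.layer = nn.LayerList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        # PaddleNLP convention: enable_recompute instead of gradient_checkpointing
        self.enable_recompute = False

    def forward(self, hidden_states):
        for layer_module in self.layer:
            if self.enable_recompute and self.training and not hidden_states.stop_gradient:
                # Re-run the layer during backward to trade compute for memory.
                hidden_states = recompute(layer_module, hidden_states)
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


encoder = ToyEncoder()
out = encoder(paddle.randn([1, 4, 8]))
```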

all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

if self.gradient_checkpointing and self.training:
Member:

Same here.

super().__init__()
self.config = config
self.layer = nn.LayerList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
Member:

Please change this one as well.

@DrownFish19 (Collaborator), Dec 31, 2024:

Change it to self.enable_recompute = False.

Comment on lines 563 to 580
_deprecated_dict = {
"key": ".self_attn.q_proj.",
"name_mapping": {
# common
"encoder.layers.": "encoder.layer.",
# embeddings
"embeddings.layer_norm.": "embeddings.LayerNorm.",
# transformer
".self_attn.q_proj.": ".attention.self.query.",
".self_attn.k_proj.": ".attention.self.key.",
".self_attn.v_proj.": ".attention.self.value.",
".self_attn.out_proj.": ".attention.output.dense.",
".norm1.": ".attention.output.LayerNorm.",
".linear1.": ".intermediate.dense.",
".linear2.": ".output.dense.",
".norm2.": ".output.LayerNorm.",
},
}
Member:

Delete this; it is unused.


from paddlenlp.transformers.tokenizer_utils import AddedToken
from paddlenlp.transformers.tokenizer_utils import (
PretrainedTokenizer as PPNLPPretrainedTokenizer,
Member:

No need for the as alias here; use PretrainedTokenizer directly.

Collaborator:

Change these to relative imports.
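A minimal sketch of the combined fix (no alias, relative path), assuming the tokenizer lives at paddlenlp/transformers/xlm_roberta/tokenizer.py:

```python
from ..tokenizer_utils import AddedToken, PretrainedTokenizer
```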

__all__ = ["XLMRobertaTokenizer"]


class XLMRobertaTokenizer(PPNLPPretrainedTokenizer):
Member:

Change this here as well.

@JunnYu (Member) commented Dec 31, 2024:

The auto part also needs to be added.

Comment on lines 961 to 978
class ModuleUtilsMixin:
"""
A few utilities for `nn.Layer`, to be used as a mixin.
"""

# @property
# def device(self):
# """
# `paddle.place`: The device on which the module is (assuming that all the module parameters are on the same
# device).
# """
# try:
# return next(self.named_parameters())[1].place
# except StopIteration:
# try:
# return next(self.named_buffers())[1].place
# except StopIteration:
# return paddle.get_device()
@JunnYu (Member), Dec 31, 2024:

Adding this code may affect many existing models; it needs a careful look.

@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
Collaborator:

The Paddle copyright notice is missing here.
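For reference, the Apache-2.0 header used across PaddlePaddle projects looks like the block below (the year is illustrative); it would be added alongside the existing HuggingFace notice:

```python
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```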

classifier_dropout=None,
**kwargs,
):
kwargs["return_dict"] = kwargs.pop("return_dict", True)
Collaborator:

Change it to False.

@@ -0,0 +1,1517 @@
# coding=utf-8
Collaborator:

Add the Paddle copyright notice.

from paddle import nn
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from paddlenlp.transformers.activations import ACT2FN
Collaborator:

Change these from paddlenlp imports to relative paths.
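A minimal sketch of the relative form, assuming the file sits at paddlenlp/transformers/xlm_roberta/modeling.py:

```python
from ..activations import ACT2FN
```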

super().__init__()
self.config = config
self.layer = nn.LayerList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@DrownFish19 (Collaborator), Dec 31, 2024:

Change it to self.enable_recompute = False.

Example:

```python
>>> from ppdiffusers.transformers import AutoTokenizer, XLMRobertaForCausalLM, AutoConfig
```

Collaborator:

Same as above: fix the documentation.


from paddlenlp.transformers.tokenizer_utils import AddedToken
from paddlenlp.transformers.tokenizer_utils import (
PretrainedTokenizer as PPNLPPretrainedTokenizer,
Collaborator:

Change to relative imports.

@DrownFish19 (Collaborator):

Add the corresponding model and tokenizer mappings in the PaddleNLP/paddlenlp/transformers/auto files.
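A hedged sketch of what that registration could look like; the actual mapping structures in paddlenlp/transformers/auto/modeling.py and auto/tokenizer.py may use different names, so treat the identifiers below as assumptions:

```python
from collections import OrderedDict

# paddlenlp/transformers/auto/modeling.py (hypothetical excerpt)
MAPPING_NAMES = OrderedDict(
    [
        # ... existing entries ...
        ("XLMRoberta", "xlm_roberta"),
    ]
)

# paddlenlp/transformers/auto/tokenizer.py (hypothetical excerpt)
TOKENIZER_MAPPING_NAMES = OrderedDict(
    [
        # ... existing entries ...
        ("XLMRobertaTokenizer", "xlm_roberta"),
    ]
)
```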


codecov bot commented Dec 31, 2024

Codecov Report

Attention: Patch coverage is 17.89474% with 624 lines in your changes missing coverage. Please review.

Project coverage is 52.55%. Comparing base (dff62a1) to head (e4c1f12).
Report is 3 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/transformers/xlm_roberta/modeling.py 15.03% 548 Missing ⚠️
paddlenlp/transformers/xlm_roberta/tokenizer.py 32.18% 59 Missing ⚠️
...addlenlp/transformers/xlm_roberta/configuration.py 22.72% 17 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9720      +/-   ##
===========================================
- Coverage    53.20%   52.55%   -0.66%     
===========================================
  Files          719      722       +3     
  Lines       115583   113254    -2329     
===========================================
- Hits         61493    59515    -1978     
+ Misses       54090    53739     -351     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ZHUI (Collaborator) commented Jan 2, 2025:

Add two unit tests to cover model initialization and tokenizer loading.
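A minimal sketch of what those two tests could look like (file location, tiny config sizes, and the from_pretrained identifier are assumptions; the real tests should reuse PaddleNLP's common test mixins):

```python
# tests/transformers/xlm_roberta/test_modeling.py (illustrative location)
import unittest

import paddle

from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer


class XLMRobertaSmokeTest(unittest.TestCase):
    def test_model_init_and_forward(self):
        # Tiny config so the test stays fast; sizes are illustrative only.
        config = XLMRobertaConfig(
            vocab_size=64,
            hidden_size=32,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=64,
        )
        model = XLMRobertaModel(config)
        model.eval()
        input_ids = paddle.randint(0, 64, shape=[1, 8])
        with paddle.no_grad():
            outputs = model(input_ids=input_ids)
        # With return_dict defaulting to False (per the review above), a tuple is returned.
        self.assertEqual(outputs[0].shape, [1, 8, 32])

    def test_tokenizer_load(self):
        # Assumes a pretrained name registered by this PR; adjust if the identifier differs.
        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
        ids = tokenizer("Hello world")["input_ids"]
        self.assertGreater(len(ids), 0)


if __name__ == "__main__":
    unittest.main()
```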

@JunnYu (Member) commented Jan 2, 2025:

Add the corresponding unit test script.
