
BUG in recompute when use_reentrant=True #70508

Open
JunnYu opened this issue Dec 27, 2024 · 0 comments
JunnYu commented Dec 27, 2024

Describe the Bug

When the parameters carry a main_grad, the correct result of the code below is that both main_grad tensors end up all ones!

import paddle
from paddle.distributed.fleet.recompute import recompute

class PyLayerMatmul(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, dy):
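        # accumulate exactly one into main_grad so the check below can tell whether
        # backward saw the original parameters' main_grad or a detached copy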
        a, b = ctx.saved_tensor()
        if hasattr(a, "main_grad"):
            a.main_grad.add_(paddle.ones_like(a.main_grad))
        if hasattr(b, "main_grad"):
            b.main_grad.add_(paddle.ones_like(b.main_grad))
        grad_a = paddle.matmul(dy, b, transpose_y=True)
        grad_b = paddle.matmul(a, dy, transpose_x=True)
        return grad_a, grad_b

def pylayer_matmul(x, y):
    return PyLayerMatmul.apply(x, y)

paddle.seed(42)
x = paddle.create_parameter(shape=[3,4], dtype='float32')
# attach a main_grad to simulate the master-gradient setup
x.main_grad = paddle.zeros([3, 4])
y = paddle.create_parameter(shape=[4, 5], dtype='float32')
# attach a main_grad to simulate the master-gradient setup
y.main_grad = paddle.zeros([4, 5])

o = recompute(pylayer_matmul, x, y, use_reentrant=True)
o.mean().backward()

print(x.main_grad)
print(y.main_grad)
# Tensor(shape=[3, 4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [[0., 0., 0., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])
# Tensor(shape=[4, 5], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [[0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
# Not as expected: both main_grad tensors should be all ones.

https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/recompute/recompute.py#L49-L59
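
For context, the helper at that location presumably looks roughly like the sketch below (reconstructed from the proposed fix further down, not quoted verbatim from the repo; the EagerParamBase import path is an assumption). The whole __dict__, main_grad included, goes through copy.deepcopy, so the recomputed parameter either gets a private copy of main_grad or none at all, and the in-place updates in PyLayerMatmul.backward never reach the caller's tensors:

import copy

from paddle.base.framework import EagerParamBase  # import path is an assumption


def _varbase_help(param):
    # deep-copies every extra attribute, main_grad included
    state = copy.deepcopy({k: v for k, v in param.__dict__.items()})
    new_param = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        trainable=param.trainable,
        name=param.name,
        **state,
    )
    param._share_buffer_to(new_param)
    return new_param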

The helper linked above should be changed as follows; with this change the reproduction prints all-ones main_grad tensors:

def _varbase_help(param):
    # exclude main_grad from the deepcopy so it is not duplicated in GPU memory
    state = copy.deepcopy(
        {k: v for k, v in param.__dict__.items() if k != "main_grad"}
    )
    new_param = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        trainable=param.trainable,
        name=param.name,
        **state,
    )
    # EagerParamBase does not accept main_grad as a kwarg, so attach the original
    # tensor by reference (guarded for parameters that have no main_grad)
    if hasattr(param, "main_grad"):
        setattr(new_param, "main_grad", param.main_grad)
    param._share_buffer_to(new_param)
    return new_param
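
The key point is that main_grad must stay shared by reference between the original parameter and the recomputed one. A minimal standalone sketch (names here are hypothetical and unrelated to recompute) of why a deepcopy silently drops the in-place accumulation:

import copy

import paddle

param = paddle.zeros([2, 2])
copied = copy.deepcopy(param)          # new storage, no link back to `param`
copied.add_(paddle.ones_like(copied))  # in-place add only touches the copy
print(param)   # still all zeros
print(copied)  # all ones

In contrast, the explicit setattr above keeps new_param.main_grad pointing at the very same tensor as param.main_grad, so the accumulation done in backward is visible to the caller.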

Additional Supplementary Information

No response
