### Describe the Bug

When the parameters carry a `main_grad`, the correct output of the code below should be all ones.
```python
import paddle
from paddle.distributed.fleet.recompute import recompute


class PyLayerMatmul(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensor()
        # Accumulate into main_grad in place; with correct aliasing the
        # original parameters' main_grad should end up all ones.
        if hasattr(a, "main_grad"):
            a.main_grad.add_(paddle.ones_like(a.main_grad))
        if hasattr(b, "main_grad"):
            b.main_grad.add_(paddle.ones_like(b.main_grad))
        grad_a = paddle.matmul(dy, b, transpose_y=True)
        grad_b = paddle.matmul(a, dy, transpose_x=True)
        return grad_a, grad_b


def pylayer_matmul(x, y):
    return PyLayerMatmul.apply(x, y)


paddle.seed(42)
x = paddle.create_parameter(shape=[3, 4], dtype='float32')
# Attach a main_grad manually to simulate mixed-precision training.
x.main_grad = paddle.zeros([3, 4])
y = paddle.create_parameter(shape=[4, 5], dtype='float32')
# Attach a main_grad manually to simulate mixed-precision training.
y.main_grad = paddle.zeros([4, 5])

o = recompute(pylayer_matmul, x, y, use_reentrant=True)
o.mean().backward()
print(x.main_grad)
print(y.main_grad)
# Tensor(shape=[3, 4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [[0., 0., 0., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])
# Tensor(shape=[4, 5], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [[0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
# Still all zeros -- not the expected all ones.
```
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/recompute/recompute.py#L49-L59
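For context, the linked `_varbase_help` deep-copies the entire `param.__dict__`, so an attached `main_grad` is deep-copied along with everything else and the recomputed parameter no longer aliases the original `main_grad` storage. A paraphrased sketch of its current shape, inferred from the proposed fix below (the linked source is authoritative):

```python
def _varbase_help(param):
    # deepcopy(param.__dict__) also deep-copies main_grad, so any
    # in-place updates during backward land on the copy, not on the
    # original parameter's main_grad.
    state = copy.deepcopy(param.__dict__)

    new_param = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        trainable=param.trainable,
        name=param.name,
        **state,
    )

    param._share_buffer_to(new_param)
    return new_param
```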
The code above should be changed as follows; with this change the output is correct:
```python
def _varbase_help(param):
    # Exclude main_grad from the deepcopy so it does not allocate an
    # extra copy of the main_grad device memory.
    state = copy.deepcopy(
        {k: v for k, v in param.__dict__.items() if k != "main_grad"}
    )

    new_param = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        trainable=param.trainable,
        name=param.name,
        **state,
    )

    # EagerParamBase does not accept main_grad in its constructor, so
    # re-attach the original tensor to keep the aliasing intact.
    setattr(new_param, "main_grad", param.main_grad)

    param._share_buffer_to(new_param)
    return new_param
```
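The aliasing break can be seen in isolation: once `deepcopy` has produced a fresh tensor, in-place updates to the copy never reach the original, which is why the repro's `add_` calls in `backward` leave `x.main_grad` and `y.main_grad` at zero. A minimal sketch, independent of recompute (assuming `copy.deepcopy` is supported for leaf Paddle tensors):

```python
import copy

import paddle

p = paddle.zeros([2, 2])
holder = {"main_grad": p}

# deepcopy allocates a brand-new tensor for main_grad...
copied = copy.deepcopy(holder)
copied["main_grad"].add_(paddle.ones_like(p))

# ...so the in-place update on the copy never reaches the original.
print(p)  # still all zeros
```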
### Additional Supplementary Information
No response