Use torch.func API in torch example
elcorto committed Jul 12, 2024
1 parent f8261a5 commit 541f988
Showing 1 changed file with 61 additions and 32 deletions.
examples/test_torch.py: 93 changes (61 additions, 32 deletions)
@@ -27,6 +27,10 @@
jax.hessian(func) -> hess_func
jax.hessian(func)(x) -> hess_func(x) -> DeviceArray
As of torch 2.0, there is torch.func (formerly functorch), which implements a
subset of the jax API (e.g. torch.func.grad() behaves like jax.grad()). Custom
derivatives defined via torch.autograd.Function can also be made to work with
the torch.func API [3] (not part of this example so far; illustrative sketches
of both points follow below).
resources
https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#autograd
@@ -37,9 +41,11 @@
[1] https://github.com/pytorch/pytorch/commit/1f4a4aaf643b70ebcb40f388ae5226a41ca57d9b
[2] https://pytorch.org/docs/stable/autograd.html#functional-higher-level-api
[3] https://pytorch.org/docs/stable/notes/extending.func.html
"""

import torch
from torch.func import grad, vmap
import numpy as np

rand = torch.rand
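
# Illustrative sketch, not part of the committed example: per [3], a custom
# derivative written with torch.autograd.Function can be made to work with the
# torch.func transforms by defining forward() without ctx and populating the
# context in setup_context() instead. The class name MySin and the usage lines
# at the end are assumptions for illustration only.
class MySin(torch.autograd.Function):
    # Let torch generate a vmap rule so vmap()/elementwise use also works.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)

# Usage sketch:
#   x = rand(3)
#   assert torch.allclose(grad(lambda t: MySin.apply(t).sum())(x), torch.cos(x))
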
@@ -70,21 +76,11 @@
func_plain_torch = lambda x: torch.sin(x).pow(2.0).sum()
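
# Illustrative sketch, not part of the committed example: the jax-style
# pattern quoted in the docstring (jax.hessian(func)(x)) carries over to
# torch.func.hessian(). For func_plain_torch = sum(sin(x)**2), the Hessian is
# diag(2*cos(2*x)). The helper name _check_hessian is an assumption for
# illustration only.
def _check_hessian():
    x = torch.rand(3, dtype=torch.float64)
    H = torch.func.hessian(func_plain_torch)(x)
    assert torch.allclose(H, torch.diag(2.0 * torch.cos(2.0 * x)))
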


def copy(x, requires_grad=False):
    _x = x.clone().detach()
    if not requires_grad:
        assert not _x.requires_grad
    else:
        _x.requires_grad = requires_grad
    return _x


# -----------------------------------------------------------------------------
# poor man's jax-like API
# -----------------------------------------------------------------------------


def _wrap_input(func):
    """
    Helper for mygrad
    """

    def wrapper(_x):
        if isinstance(_x, torch.Tensor):
            x = _x
@@ -98,18 +94,11 @@ def wrapper(_x):
    return wrapper


# only to make scalar args work
@_wrap_input
def cos(x):
    return torch.cos(x)


@_wrap_input
def func(x):
    return func_plain_torch(x)

def mygrad(func):
    """
    poor man's jax-like API, for torch < 2.0
    """

def grad(func):
    @_wrap_input
    def _gradfunc(x):
        out = func(x)
@@ -121,17 +110,51 @@ def _gradfunc(x):
    return _gradfunc


elementwise_grad = grad
elementwise_mygrad = mygrad


def elementwise_grad(func):
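    # grad() requires a scalar-valued function, so vmap() maps that scalar
    # gradient over the leading dimension of x to get elementwise gradients
    # (the role elementwise_mygrad plays for torch < 2.0).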
    return vmap(grad(func))


# Use _wrap_input here only to make scalar args work
@_wrap_input
def cos(x):
    return torch.cos(x)


# Use _wrap_input here only to make scalar args work
@_wrap_input
def func(x):
    return func_plain_torch(x)


def copy(x, requires_grad=False):
    _x = x.clone().detach()
    if not requires_grad:
        assert not _x.requires_grad
    else:
        _x.requires_grad = requires_grad
    return _x


def test():
    # Check that grad() works
    assert torch.allclose(grad(torch.sin)(1.234), cos(1.234))
    # Check that mygrad() works
    assert torch.allclose(mygrad(torch.sin)(1.234), cos(1.234))
    x = rand(10) * 5 - 5
    assert torch.allclose(elementwise_mygrad(torch.sin)(x), torch.cos(x))
    assert mygrad(func)(x).shape == x.shape

    # torch.func.grad()
    #
    # Float input needs to be a tensor.
    ##assert torch.allclose(grad(torch.sin)(1.234), cos(1.234))
    assert torch.allclose(grad(torch.sin)(torch.tensor(1.234)), cos(1.234))

    assert torch.allclose(elementwise_grad(torch.sin)(x), torch.cos(x))
    assert grad(func)(x).shape == x.shape
    assert grad(func_plain_torch)(x).shape == x.shape

    # Show 4 different pytorch grad APIs
    # Show 5 different pytorch grad APIs
    x1 = rand(3, requires_grad=True)

    # 1
@@ -157,9 +180,15 @@ def test():
    g2 = torch.autograd.functional.vjp(func_plain_torch, x2)[1]
    assert (g1 == g2).all()

    # jax-like functional API defined here
    # 5
    # Limited jax-like support in torch 2.x
    x2 = copy(x1)
    g2 = torch.func.grad(func_plain_torch)(x2)
    assert (g1 == g2).all()

    # jax-like functional API defined here for torch < 2.0
    x2 = copy(x1)
    g2 = grad(func)(x2)
    g2 = mygrad(func)(x2)
    assert (g1 == g2).all()

