Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Move input transforms to GPyTorch #2114

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

saitcakmak
Copy link
Collaborator

@saitcakmak saitcakmak commented Aug 31, 2022

This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py

What this does:

  • Moves the transform_inputs from BoTorch Model to GPyTorch GP class, with some modifications to explicitly identify whether given inputs are train or test inputs.
  • Modifies the InputTransform.forward call to use is_training_input argument instead of self.training check to apply the transforms that have transform_on_train=True.
  • Removes preprocess_transform method since this is no-longer needed.
  • For ExactGP models, it transforms both train and test inputs in __call__. For train_inputs it always uses is_training_input=True. For generic inputs, it uses is_training_input=self.training which signals that these are training inputs when the model is in train mode, and that these are test inputs when the model is in eval mode.
  • For ApproximateGP models, it applies the transform to inputs in __call__ using is_training_input=self.training. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms inducing_points, thus fixes the previous bug with inducing_points getting transformed in train but not getting transformed in eval. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
  • For BoTorch SingleTaskVariationalGP, it moves the input_transform attribute down to _SingleTaskVariationalGP, which is the actual ApproximateGP instance. This makes the transform accessible from GPyTorch.

What this doesn't do:

  • It doesn't do anything about DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.
  • It does not clean up the call sites for self.transform_inputs. This is just made into a no-op and the clean-up is left for later.
  • It does not upstream the abstract InputTransform classes to GPyTorch. That'll be done if we decide to go forward with this design.
  • It does not touch PairwiseGP. PairwiseGP has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
  • I didn't look into ApproximateGP.fantasize. This may need some changes similar to ExactGP.get_fantasy_model.
  • It does not support PyroGP and DeepGP.

@saitcakmak
Copy link
Collaborator Author

cc @wjmaddox, @gpleiss, @Balandat

Copy link
Member

@gpleiss gpleiss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great to me, though we should add some unit tests and docs.

from ..module import Module


class GP(Module):
pass
def apply_input_transforms(self, X: Tensor, is_training_input: bool) -> Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to not name this arg is_training rather than is_training_input?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just being verbose and differentiating the model being in training and the inputs being the train_inputs or otherwise being treated as such.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants