-
Notifications
You must be signed in to change notification settings - Fork 562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Move input transforms to GPyTorch #2114
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great to me, though we should add some unit tests and docs.
from ..module import Module | ||
|
||
|
||
class GP(Module): | ||
pass | ||
def apply_input_transforms(self, X: Tensor, is_training_input: bool) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to not name this arg is_training
rather than is_training_input
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just being verbose and differentiating the model being in training and the inputs being the train_inputs
or otherwise being treated as such.
This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py
What this does:
transform_inputs
from BoTorchModel
to GPyTorchGP
class, with some modifications to explicitly identify whether given inputs are train or test inputs.InputTransform.forward
call to useis_training_input
argument instead ofself.training
check to apply the transforms that havetransform_on_train=True
.preprocess_transform
method since this is no-longer needed.ExactGP
models, it transforms both train and test inputs in__call__
. Fortrain_inputs
it always usesis_training_input=True
. For genericinputs
, it usesis_training_input=self.training
which signals that these are training inputs when the model is intrain
mode, and that these are test inputs when the model is ineval
mode.ApproximateGP
models, it applies the transform toinputs
in__call__
usingis_training_input=self.training
. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transformsinducing_points
, thus fixes the previous bug withinducing_points
getting transformed intrain
but not getting transformed ineval
. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).SingleTaskVariationalGP
, it moves theinput_transform
attribute down to_SingleTaskVariationalGP
, which is the actualApproximateGP
instance. This makes the transform accessible from GPyTorch.What this doesn't do:
DeterministicModel
s. Those will still need to deal with their own transforms, which is not implemented here. If we makeModel
inherit fromGP
, we can keep the existing setup with very minimal changes.self.transform_inputs
. This is just made into a no-op and the clean-up is left for later.InputTransform
classes to GPyTorch. That'll be done if we decide to go forward with this design.PairwiseGP
.PairwiseGP
has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.ApproximateGP.fantasize
. This may need some changes similar toExactGP.get_fantasy_model
.PyroGP
andDeepGP
.