Skip to content

Commit

Permalink
Merge pull request #138 from rstudio/prototype
Browse files Browse the repository at this point in the history
update `ptype_data` to `prototype_data`
  • Loading branch information
isabelizimm authored Dec 19, 2022
2 parents fea51a5 + 6403b57 commit 612d957
Show file tree
Hide file tree
Showing 30 changed files with 280 additions and 184 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ jobs:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e ".[dev]"
- name: Run Docker
python -m pip install ".[dev]"
python -m pip install --upgrade git+https://github.com/rstudio/vetiver-python@${{ github.sha }}
- name: run Docker
run: |
python script/setup-docker/docker.py
pip freeze > vetiver_requirements.txt
docker build -t mock .
docker run -d -v $PWD/pinsboard:/vetiver/pinsboard -p 8080:8080 mock
sleep 5
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ from vetiver import mock, VetiverModel
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)

v = VetiverModel(model, save_ptype=True, ptype_data=X, model_name='mock_model')
v = VetiverModel(model, model_name='mock_model', prototype_data=X)
```

You can **version** and **share** your `VetiverModel()` by choosing a [pins](https://rstudio.github.io/pins-python/) "board" for it, including a local folder, RStudio Connect, Amazon S3, and more.
Expand All @@ -63,7 +63,7 @@ You can **deploy** your pinned `VetiverModel()` using `VetiverAPI()`, an extensi

```python
from vetiver import VetiverAPI
app = VetiverAPI(v, check_ptype = True)
app = VetiverAPI(v, check_prototype = True)
```

To start a server using this object, use `app.run(port = 8080)` or your port of choice.
Expand Down
3 changes: 2 additions & 1 deletion vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Change to import.metadata when minimum python>=3.8
from importlib_metadata import version as _version

from .ptype import * # noqa
from .prototype import * # noqa
from .vetiver_model import VetiverModel # noqa
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
from .mock import get_mock_data, get_mock_model # noqa
Expand All @@ -19,6 +19,7 @@
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
from .model_card import model_card # noqa
from .types import create_prototype, Prototype # noqa

__author__ = "Isabel Zimmerman <[email protected]>"
__all__ = []
Expand Down
36 changes: 18 additions & 18 deletions vetiver/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import singledispatch
from contextlib import suppress

from ..ptype import vetiver_create_ptype
from ..prototype import vetiver_create_prototype
from ..meta import _model_meta


Expand All @@ -20,14 +20,14 @@ def __init__(


@singledispatch
def create_handler(model, ptype_data):
def create_handler(model, prototype_data):
"""check for model type to handle prediction
Parameters
----------
model: object
Description of parameter `x`.
ptype_data : object
prototype_data : object
An object with information (data) whose layout is to be determined.
Returns
Expand Down Expand Up @@ -63,7 +63,7 @@ class BaseHandler:
----------
model :
a trained model
ptype_data :
prototype_data :
An object with information (data) whose layout is to be determined.
"""

Expand All @@ -73,9 +73,9 @@ def __init_subclass__(cls, **kwargs):
with suppress(AttributeError, NameError):
create_handler.register(cls.model_class(), cls)

def __init__(self, model, ptype_data):
def __init__(self, model, prototype_data):
self.model = model
self.ptype_data = ptype_data
self.prototype_data = prototype_data

def describe(self):
"""Create description for model"""
Expand All @@ -93,21 +93,21 @@ def create_meta(

return meta

def construct_ptype(self):
def construct_prototype(self):
"""Create data prototype for a model
Parameters
----------
ptype_data : pd.DataFrame, np.ndarray, or None
Training data to create ptype
prototype_data : pd.DataFrame, np.ndarray, or None
Training data to create prototype
Returns
-------
ptype : pd.DataFrame or None
prototype : pd.DataFrame or None
Zero-row DataFrame for storing data types
"""
ptype = vetiver_create_ptype(self.ptype_data)
return ptype
prototype = vetiver_create_prototype(self.prototype_data)
return prototype

def handler_startup():
"""Include required packages for prediction
Expand All @@ -117,7 +117,7 @@ def handler_startup():
"""
...

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -128,8 +128,8 @@ def handler_predict(self, input_data, check_ptype):
----------
input_data:
Data used to generate prediction
check_ptype:
If type should be checked against `ptype` or not
check_prototype:
If type should be checked against `prototype` or not
Returns
-------
Expand All @@ -144,8 +144,8 @@ def handler_predict(self, input_data, check_ptype):


@create_handler.register
def _(model: base.BaseHandler, ptype_data):
if model.ptype_data is None and ptype_data is not None:
model.ptype_data = ptype_data
def _(model: base.BaseHandler, prototype_data):
if model.prototype_data is None and prototype_data is not None:
model.prototype_data = prototype_data

return model
4 changes: 2 additions & 2 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_meta(

return meta

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -51,7 +51,7 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""

if not check_ptype or isinstance(input_data, pd.DataFrame):
if not check_prototype or isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])
Expand Down
5 changes: 1 addition & 4 deletions vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class StatsmodelsHandler(BaseHandler):

model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)

def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)

def describe(self):
"""Create description for statsmodels model"""
desc = f"Statsmodels {self.model.__class__} model."
Expand All @@ -41,7 +38,7 @@ def create_meta(

return meta

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand Down
6 changes: 3 additions & 3 deletions vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_meta(

return meta

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -57,8 +57,8 @@ def handler_predict(self, input_data, check_ptype):
"""
if not torch_exists:
raise ImportError("Cannot import `torch`.")
if check_ptype:
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
if check_prototype:
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

# do not check ptype
Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_meta(

return meta

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand Down
8 changes: 6 additions & 2 deletions vetiver/pin_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
>>> model_board = board_temp(versioned = True, allow_pickle_read = True)
>>> X, y = vetiver.get_mock_data()
>>> model = vetiver.get_mock_model().fit(X, y)
>>> v = vetiver.VetiverModel(model = model, model_name = "my_model", ptype_data = X)
>>> v = vetiver.VetiverModel(model, "my_model", prototype_data = X)
>>> vetiver.vetiver_pin_write(model_board, v)
"""
if not board.allow_pickle_read:
Expand All @@ -51,14 +51,18 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
"with vetiver.model_card()",
)

# convert older model's ptype to prototype
if hasattr(model, "ptype"):
model.prototype = model.ptype

board.pin_write(
model.model,
name=model.model_name,
type="joblib",
description=model.description,
metadata={
"required_pkgs": model.metadata.get("required_pkgs"),
"ptype": None if model.ptype is None else model.ptype().json(),
"prototype": None if model.prototype is None else model.prototype().json(),
},
versioned=versioned,
)
Expand Down
Loading

0 comments on commit 612d957

Please sign in to comment.