Skip to content

Commit

Permalink
Merge branch 'main' into metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelizimm authored Dec 19, 2022
2 parents 3edaa19 + 612d957 commit 4e7e259
Show file tree
Hide file tree
Showing 29 changed files with 276 additions and 179 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ from vetiver import mock, VetiverModel
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)

v = VetiverModel(model, save_ptype=True, ptype_data=X, model_name='mock_model')
v = VetiverModel(model, model_name='mock_model', prototype_data=X)
```

You can **version** and **share** your `VetiverModel()` by choosing a [pins](https://rstudio.github.io/pins-python/) "board" for it, including a local folder, RStudio Connect, Amazon S3, and more.
Expand All @@ -63,7 +63,7 @@ You can **deploy** your pinned `VetiverModel()` using `VetiverAPI()`, an extensi

```python
from vetiver import VetiverAPI
app = VetiverAPI(v, check_ptype = True)
app = VetiverAPI(v, check_prototype = True)
```

To start a server using this object, use `app.run(port = 8080)` or your port of choice.
Expand Down
3 changes: 2 additions & 1 deletion vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Change to import.metadata when minimum python>=3.8
from importlib_metadata import version as _version

from .ptype import * # noqa
from .prototype import * # noqa
from .vetiver_model import VetiverModel # noqa
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
from .mock import get_mock_data, get_mock_model # noqa
Expand All @@ -19,6 +19,7 @@
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
from .model_card import model_card # noqa
from .types import create_prototype, Prototype # noqa

__author__ = "Isabel Zimmerman <[email protected]>"
__all__ = []
Expand Down
36 changes: 18 additions & 18 deletions vetiver/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from functools import singledispatch
from contextlib import suppress

from ..ptype import vetiver_create_ptype
from ..meta import VetiverMeta
from ..prototype import vetiver_create_prototype


class InvalidModelError(Exception):
Expand All @@ -20,14 +20,14 @@ def __init__(


@singledispatch
def create_handler(model, ptype_data):
def create_handler(model, prototype_data):
"""check for model type to handle prediction
Parameters
----------
model: object
Description of parameter `x`.
ptype_data : object
prototype_data : object
An object with information (data) whose layout is to be determined.
Returns
Expand Down Expand Up @@ -63,7 +63,7 @@ class BaseHandler:
----------
model :
a trained model
ptype_data :
prototype_data :
An object with information (data) whose layout is to be determined.
"""

Expand All @@ -73,9 +73,9 @@ def __init_subclass__(cls, **kwargs):
with suppress(AttributeError, NameError):
create_handler.register(cls.model_class(), cls)

def __init__(self, model, ptype_data):
def __init__(self, model, prototype_data):
self.model = model
self.ptype_data = ptype_data
self.prototype_data = prototype_data

def describe(self):
"""Create description for model"""
Expand All @@ -94,21 +94,21 @@ def create_meta(self, metadata):

return VetiverMeta.from_dict(metadata, pip_name)

def construct_ptype(self):
def construct_prototype(self):
"""Create data prototype for a model
Parameters
----------
ptype_data : pd.DataFrame, np.ndarray, or None
Training data to create ptype
prototype_data : pd.DataFrame, np.ndarray, or None
Training data to create prototype
Returns
-------
ptype : pd.DataFrame or None
prototype : pd.DataFrame or None
Zero-row DataFrame for storing data types
"""
ptype = vetiver_create_ptype(self.ptype_data)
return ptype
prototype = vetiver_create_prototype(self.prototype_data)
return prototype

def handler_startup():
"""Include required packages for prediction
Expand All @@ -118,7 +118,7 @@ def handler_startup():
"""
...

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -129,8 +129,8 @@ def handler_predict(self, input_data, check_ptype):
----------
input_data:
Data used to generate prediction
check_ptype:
If type should be checked against `ptype` or not
check_prototype:
If type should be checked against `prototype` or not
Returns
-------
Expand All @@ -145,8 +145,8 @@ def handler_predict(self, input_data, check_ptype):


@create_handler.register
def _(model: base.BaseHandler, ptype_data):
if model.ptype_data is None and ptype_data is not None:
model.ptype_data = ptype_data
def _(model: base.BaseHandler, prototype_data):
if model.prototype_data is None and prototype_data is not None:
model.prototype_data = prototype_data

return model
4 changes: 2 additions & 2 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler):
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
pip_name = "scikit-learn"

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -34,7 +34,7 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""

if not check_ptype or isinstance(input_data, pd.DataFrame):
if not check_prototype or isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])
Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler):
if sm_exists:
pip_name = "statsmodels"

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand Down
6 changes: 3 additions & 3 deletions vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TorchHandler(BaseHandler):
if torch_exists:
pip_name = "torch"

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand All @@ -41,8 +41,8 @@ def handler_predict(self, input_data, check_ptype):
"""
if not torch_exists:
raise ImportError("Cannot import `torch`.")
if check_ptype:
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
if check_prototype:
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

# do not check ptype
Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler):
if xgb_exists:
pip_name = "xgboost"

def handler_predict(self, input_data, check_ptype):
def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
The `handler_predict` function executes at each API call. Use this
Expand Down
6 changes: 5 additions & 1 deletion vetiver/pin_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
>>> model_board = board_temp(versioned = True, allow_pickle_read = True)
>>> X, y = vetiver.get_mock_data()
>>> model = vetiver.get_mock_model().fit(X, y)
>>> v = vetiver.VetiverModel(model = model, model_name = "my_model", ptype_data = X)
>>> v = vetiver.VetiverModel(model, "my_model", prototype_data = X)
>>> vetiver.vetiver_pin_write(model_board, v)
"""
if not board.allow_pickle_read:
Expand All @@ -55,6 +55,10 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
if isinstance(model.metadata, dict):
model.metadata = VetiverMeta.from_dict(model.metadata)

# convert older model's ptype to prototype
if hasattr(model, "ptype"):
model.prototype = model.ptype

board.pin_write(
model.model,
name=model.model_name,
Expand Down
Loading

0 comments on commit 4e7e259

Please sign in to comment.