diff --git a/vetiver/meta.py b/vetiver/meta.py index d2a10f93..a09e2e02 100644 --- a/vetiver/meta.py +++ b/vetiver/meta.py @@ -1,3 +1,4 @@ +import sys from dataclasses import dataclass, asdict, field from typing import Mapping @@ -10,6 +11,7 @@ class VetiverMeta: version: "str | None" = None url: "str | None" = None required_pkgs: "list | None" = field(default_factory=list) + python_version: "tuple | None" = None def to_dict(self) -> Mapping: data = asdict(self) @@ -25,9 +27,10 @@ def from_dict(cls, metadata, pip_name=None) -> "VetiverMeta": version = metadata.get("version", None) url = metadata.get("url", None) required_pkgs = metadata.get("required_pkgs", []) + python_version = tuple(metadata.get("python_version", sys.version_info)) if pip_name: if not list(filter(lambda x: pip_name in x, required_pkgs)): required_pkgs = required_pkgs + [f"{pip_name}"] - return cls(user, version, url, required_pkgs) + return cls(user, version, url, required_pkgs, python_version) diff --git a/vetiver/pin_read_write.py b/vetiver/pin_read_write.py index 8da5c872..db2b44ea 100644 --- a/vetiver/pin_read_write.py +++ b/vetiver/pin_read_write.py @@ -70,6 +70,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True): "vetiver_meta": { "required_pkgs": model.metadata.required_pkgs, "prototype": None if not model.prototype else model.prototype().json(), + "python_version": list(model.metadata.python_version), }, }, versioned=versioned, diff --git a/vetiver/tests/test_build_vetiver_model.py b/vetiver/tests/test_build_vetiver_model.py index 2832ea0c..354c5482 100644 --- a/vetiver/tests/test_build_vetiver_model.py +++ b/vetiver/tests/test_build_vetiver_model.py @@ -1,4 +1,5 @@ import sklearn +import sys import vetiver as vt from vetiver.meta import VetiverMeta @@ -113,6 +114,7 @@ def test_vetiver_model_use_ptype(): version=None, url=None, required_pkgs=["scikit-learn"], + python_version=tuple(sys.version_info), ) @@ -137,5 +139,41 @@ def test_vetiver_model_from_pin(): assert v2.metadata.user == {"test": 123} assert v2.metadata.version is not None assert v2.metadata.required_pkgs == ["scikit-learn"] + assert v2.metadata.python_version == tuple(sys.version_info) + + board.pin_delete("model") + + +def test_vetiver_model_from_pin_user_metadata(): + """ + Test if standard keys as part of :dataclass:`VetiverMeta` are picked + """ + custom_meta = { + "test": 123, + "required_pkgs": ["foo", "bar"], + "python_version": [3, 10, 6, "final", 0], + } + loaded_pkgs = custom_meta["required_pkgs"] + ["scikit-learn"] + + v = vt.VetiverModel( + model=model, + prototype_data=X_df, + model_name="model", + versioned=None, + description=None, + metadata=custom_meta, + ) + + board = pins.board_temp(allow_pickle_read=True) + vt.vetiver_pin_write(board=board, model=v) + v2 = vt.VetiverModel.from_pin(board, "model") + + assert isinstance(v2, vt.VetiverModel) + assert isinstance(v2.model, sklearn.base.BaseEstimator) + assert isinstance(v2.prototype.construct(), pydantic.BaseModel) + assert v2.metadata.user == custom_meta + assert v2.metadata.version is not None + assert v2.metadata.required_pkgs == loaded_pkgs + assert v2.metadata.python_version == tuple(custom_meta["python_version"]) board.pin_delete("model") diff --git a/vetiver/vetiver_model.py b/vetiver/vetiver_model.py index 9cef84c5..cfa7868b 100644 --- a/vetiver/vetiver_model.py +++ b/vetiver/vetiver_model.py @@ -102,6 +102,7 @@ def from_pin(cls, board, name: str, version: str = None): if "vetiver_meta" in meta.user: get_prototype = meta.user.get("vetiver_meta").get("prototype", None) required_pkgs = meta.user.get("vetiver_meta").get("required_pkgs", None) + python_version = meta.user.get("vetiver_meta").get("python_version", None) meta.user.pop("vetiver_meta") else: # ptype = meta.user.get("ptype", None) @@ -113,6 +114,7 @@ def from_pin(cls, board, name: str, version: str = None): # get_prototype = None required_pkgs = meta.user.get("required_pkgs") + python_version = meta.user.get("python_version") return cls( model=model, @@ -123,6 +125,7 @@ def from_pin(cls, board, name: str, version: str = None): "version": meta.version.version, "url": meta.local.get("url"), # None all the time, besides Connect, "required_pkgs": required_pkgs, + "python_version": python_version, }, prototype_data=json.loads(get_prototype) if get_prototype else None, versioned=True,