diff --git a/src/sk_transformers/generic_transformer.py b/src/sk_transformers/generic_transformer.py index 4fdcc6e..53a550c 100644 --- a/src/sk_transformers/generic_transformer.py +++ b/src/sk_transformers/generic_transformer.py @@ -544,7 +544,9 @@ def transform(self, Xy: pd.DataFrame) -> pd.DataFrame: pd.DataFrame: Dataframe with the queries applied. """ - Xy = check_ready_to_transform(self, Xy, force_all_finite="allow-nan") + Xy = check_ready_to_transform( + self, Xy, Xy.columns, force_all_finite="allow-nan" + ) for query in self.queries: Xy = Xy.query(query, inplace=False) return Xy diff --git a/src/sk_transformers/utils.py b/src/sk_transformers/utils.py index 248b0f0..5501b15 100644 --- a/src/sk_transformers/utils.py +++ b/src/sk_transformers/utils.py @@ -8,7 +8,7 @@ def check_ready_to_transform( transformer: Any, X: pd.DataFrame, - features: Optional[Union[str, List[str]]] = None, + features: Union[str, List[str]], force_all_finite: Union[bool, str] = True, dtype: Optional[Union[str, List[str]]] = None, ) -> pd.DataFrame: @@ -38,25 +38,25 @@ def check_ready_to_transform( pandas.DataFrame: A checked copy of original dataframe. """ + if isinstance(features, str): + features = [features] + if not isinstance(X, pd.DataFrame): raise ValueError("X must be a Pandas dataframe!") if X.empty: raise ValueError("X must not be empty!") - if features: - if isinstance(features, str): - if not features in X.columns: - raise ValueError(f"Column `{features}` not in dataframe!") - elif isinstance(features, list): - if not all(c in X.columns for c in features): - not_in_df = ( - str([c for c in features if c not in X.columns]) - .replace("[", "") - .replace("]", "") - .replace("'", "`") - ) - raise ValueError( - f"Not all provided `features` could be found in `X`! Following columns were not found in the dataframe: {not_in_df}." - ) + + if isinstance(features, list): + if not all(c in X.columns for c in features): + not_in_df = ( + str([c for c in features if c not in X.columns]) + .replace("[", "") + .replace("]", "") + .replace("'", "`") + ) + raise ValueError( + f"Not all provided `features` could be found in `X`! Following columns were not found in the dataframe: {not_in_df}." + ) if issubclass(transformer.__class__, BaseEstimator) is False: raise TypeError( @@ -67,18 +67,26 @@ def check_ready_to_transform( ) check_is_fitted(transformer, "fitted_") - X_tmp = check_array( - X.to_numpy(), + X_tmp = X[ + dict.fromkeys(X[features]).keys() + ].copy() # `dict.fromkeys` was chosen instead of `set` to maintain the order of the entries. + + X_tmp_array = check_array( + X_tmp.to_numpy(), dtype=dtype, accept_large_sparse=False, force_all_finite=force_all_finite, ) - X_tmp = pd.DataFrame(X_tmp, columns=X.columns, index=X.index) + X_tmp = pd.DataFrame(X_tmp_array, columns=X_tmp.columns, index=X_tmp.index) - for column in X.columns: + for column in X_tmp.columns: X_tmp[column] = X_tmp[column].astype(X[column].dtype) - return X_tmp.copy() + non_included_features = [c for c in X.columns if c not in features] + if non_included_features: + X_tmp = pd.concat([X_tmp, X[non_included_features]], axis=1) + + return X_tmp def check_data(X: pd.DataFrame, y: pd.Series, check_nans: bool = True) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 585c616..07cd2ba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,14 +13,14 @@ def test_check_ready_to_transform_for_empty_df() -> None: with pytest.raises(ValueError) as error: - check_ready_to_transform(None, pd.DataFrame()) + check_ready_to_transform(None, pd.DataFrame(), ["a"]) assert "X must not be empty!" == str(error.value) def test_check_ready_to_transform_for_not_dataframe() -> None: with pytest.raises(ValueError) as error: - check_ready_to_transform(None, np.ndarray([1, 2, 3])) + check_ready_to_transform(None, np.ndarray([1, 2, 3]), ["a"]) assert "X must be a Pandas dataframe!" == str(error.value) @@ -29,7 +29,10 @@ def test_check_ready_to_transform_for_wrong_column() -> None: with pytest.raises(ValueError) as error: check_ready_to_transform(None, pd.DataFrame({"a": [1, 2, 3]}), "b") - assert "Column `b` not in dataframe!" == str(error.value) + assert ( + "Not all provided `features` could be found in `X`! Following columns were not found in the dataframe: `b`." + == str(error.value) + ) def test_check_ready_to_transform_for_wrong_columns() -> None: