Skip to content
This repository has been archived by the owner on Aug 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #346 from wellcometrust/fix/tf-idf-fit-transform
Browse files Browse the repository at this point in the history
Fix/tf idf fit transform
  • Loading branch information
aCampello authored Jul 30, 2021
2 parents 4684a42 + bd6336b commit ba377a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
20 changes: 20 additions & 0 deletions tests/test_frequency_vectorizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# encoding: utf-8
import pytest

from wellcomeml.ml.frequency_vectorizer import WellcomeTfidf


Expand All @@ -25,3 +27,21 @@ def test_save_and_load(tmpdir):
X_loaded = vec.load_transformed(str(tmpfile))

assert (X_loaded != X_embed).sum() == 0


def test_fit_transform_and_transform():
    """Checks that ``fit_transform`` and a later ``transform`` agree.

    The same documents are embedded twice with one vectorizer: once via
    ``fit_transform`` and once via ``transform`` on the fitted instance.
    Both results must match element-wise.
    """
    X = [
        "This is a sentence",
        "This is another one",
        "This is a third sentence",
        "Wellcome is a global charitable foundation",
        "We want everyone to benefit from science's potential to improve health and save lives.",
    ]

    text_vectorizer = WellcomeTfidf()
    X_embed = text_vectorizer.fit_transform(X)

    X_embed_2 = text_vectorizer.transform(X)

    # Shapes have to agree before an element-wise comparison means anything
    assert X_embed.shape == X_embed_2.shape

    # Asserts that the result of transform is almost the same as fit
    # transform. Absolute differences are summed because a signed sum lets
    # positive and negative errors cancel out and hide a real mismatch.
    assert abs(X_embed - X_embed_2).sum() == pytest.approx(0, abs=1e-6)
44 changes: 23 additions & 21 deletions wellcomeml/ml/frequency_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,27 @@ class WellcomeTfidf(TfidfVectorizer):
vectorisation/embedding
"""

def __init__(self, use_regex=True, use_spacy_lemmatizer=True, **kwargs):
    """
    Args:
        use_regex: Whether ``regex_transform`` is applied to the input
            before vectorising (default True)
        use_spacy_lemmatizer: Whether ``spacy_lemmatizer`` is applied to
            the input before vectorising (default True)
        **kwargs: Any sklearn "tfidfvectorizer" arguments (min_df, etc.).
            ``stop_words`` falls back to "english" when not supplied.
    """
    self.embedding = "tf-idf"
    self.use_regex = use_regex
    self.use_spacy_lemmatizer = use_spacy_lemmatizer

    logger.info("Initialising frequency vectorizer.")

    # Only fills in the default: an explicitly passed stop_words
    # (including None) is left untouched
    kwargs.setdefault("stop_words", "english")

    super().__init__(**kwargs)

    # The lookup lemmatizer is built once per instance and kept on self,
    # so spacy_lemmatizer does not have to rebuild it on each call
    self.nlp = spacy.blank("en")
    self.nlp.add_pipe("lemmatizer", config={"mode": "lookup"})
    self.nlp.initialize()

@classmethod
def save_transformed(cls, path, X_transformed):
"""Saves transformed embedded vectors"""
Expand Down Expand Up @@ -84,10 +90,6 @@ def spacy_lemmatizer(self, X, remove_stopwords_and_punct=True):
"""

nlp = spacy.blank("en")
nlp.add_pipe("lemmatizer", config={"mode": "lookup"})
nlp.initialize()

logger.info("Using spacy pre-trained lemmatiser.")
if remove_stopwords_and_punct:
return [
Expand All @@ -98,33 +100,33 @@ def spacy_lemmatizer(self, X, remove_stopwords_and_punct=True):
and not token.is_punct
and token.lemma_ != "-PRON-"
]
for doc in nlp.pipe(X)
for doc in self.nlp.pipe(X)
]
else:
return [
[token.lemma_.lower() for token in doc] for doc in nlp.pipe(X)
[token.lemma_.lower() for token in doc] for doc in self.nlp.pipe(X)
]

def _pre_transform(self, X):
    """Applies the configured clean-up steps before vectorising

    Runs ``regex_transform`` when ``self.use_regex`` is set, then
    ``spacy_lemmatizer`` when ``self.use_spacy_lemmatizer`` is set.

    Args:
        X: Iterable of documents

    Returns:
        A list with one string per document, tokens separated by a
        single space
    """
    if self.use_regex:
        X = self.regex_transform(X)
    if self.use_spacy_lemmatizer:
        X = self.spacy_lemmatizer(X)

    # spacy_lemmatizer yields a list of tokens per document, which is
    # joined back into one string. With the lemmatizer switched off a
    # document can still be a plain string, and joining a str would put a
    # space between every character, so strings are passed through as is.
    # NOTE(review): regex_transform is not visible here - confirm what it
    # returns per document.
    return [
        text if isinstance(text, str) else " ".join(text) for text in X
    ]

def transform(self, X):
    """Pre-processes the documents and vectorises them

    Whether the regex and lemmatizer steps run is decided by
    ``_pre_transform``, not by arguments to this method.

    Args:
        X: Iterable of documents

    Returns:
        The result of the parent class ``transform`` on the
        pre-processed documents
    """
    X = self._pre_transform(X)

    return super().transform(X)

def fit(self, X, y=None):
    """Pre-processes the documents and fits the vectoriser on them

    Args:
        X: Iterable of documents
        y: Not used by this method; accepted so the signature matches
            ``fit_transform``

    Returns:
        self, so that calls can be chained
    """
    X = self._pre_transform(X)

    super().fit(X)

    return self

def fit_transform(self, X, y=None):
    """Fits the vectoriser and returns the embedded documents in one step

    The documents go through ``_pre_transform`` first, the same
    pre-processing used by ``fit`` and ``transform``.

    Args:
        X: Iterable of documents
        y: Forwarded unchanged to the parent class ``fit_transform``

    Returns:
        The result of the parent class ``fit_transform`` on the
        pre-processed documents
    """
    prepared_docs = self._pre_transform(X)
    return super().fit_transform(prepared_docs, y=y)

0 comments on commit ba377a4

Please sign in to comment.