diff --git a/turftopic/models/keynmf.py b/turftopic/models/keynmf.py index 2aa47fd..513c2e9 100644 --- a/turftopic/models/keynmf.py +++ b/turftopic/models/keynmf.py @@ -125,11 +125,12 @@ def extract_keywords( embedding = embeddings[i].reshape(1, -1) if self.keyword_scope == 'document': mask = terms > 0 - if not np.any(mask): - keywords.append(dict()) - continue else: - mask = np.ones(shape=terms.shape, dtype=bool) + tot_freq = document_term_matrix.sum(axis=0) + mask = tot_freq != 0 + if not np.any(mask): + keywords.append(dict()) + continue important_terms = np.squeeze(np.asarray(mask)) word_embeddings = self.vocab_embeddings[important_terms] sim = cosine_similarity(embedding, word_embeddings) @@ -284,7 +285,7 @@ def prepare_topic_data( except (NotFittedError, AttributeError): doc_topic_matrix = self.nmf_.fit_transform(dtm) self.components_ = self.nmf_.components_ - console.log("Model fiting done.") + console.log("Model fitting done.") res: TopicData = { "corpus": corpus, "document_term_matrix": dtm,