Commit
Merge pull request #68 from x-tabdeveloping/embedding_dim
Informative error message
KennethEnevoldsen authored Oct 30, 2024
2 parents 0bc8a0e + 636f667 commit 11e3c6e
Showing 4 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ line-length=79

[tool.poetry]
name = "turftopic"
version = "0.7.0"
version = "0.7.1"
description = "Topic modeling with contextual representations from sentence transformers."
authors = ["Márton Kardos <[email protected]>"]
license = "MIT"
13 changes: 13 additions & 0 deletions turftopic/models/_keynmf.py
@@ -16,6 +16,12 @@

from turftopic.base import Encoder

NOT_MATCHING_ERROR = (
    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
    + "Try to initialize the model with the encoder you used for computing the embeddings."
)


def batched(iterable, n: int) -> Iterable[list[str]]:
    "Batch data into tuples of length n. The last batch may be shorter."
@@ -143,6 +149,13 @@ def batch_extract_keywords(
                self.term_embeddings[self.key_to_index[term]]
                for term in batch_vocab[important_terms]
            ]
            if self.term_embeddings.shape[1] != embeddings.shape[1]:
                raise ValueError(
                    NOT_MATCHING_ERROR.format(
                        n_dims=embeddings.shape[1],
                        n_word_dims=self.term_embeddings.shape[1],
                    )
                )
            sim = cosine_similarity(embedding, word_embeddings).astype(
                np.float64
            )
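A note for context (not part of the diff): the mismatch this check reports typically comes from fitting on precomputed document embeddings while letting the model encode the vocabulary with its default encoder. A minimal sketch, assuming KeyNMF's default encoder outputs embeddings of a different size than the 768-dimensional all-mpnet-base-v2 used below:

# Hypothetical reproduction of the mismatch the new check reports.
from sentence_transformers import SentenceTransformer
from turftopic import KeyNMF

corpus = ["first document", "second document"]
encoder = SentenceTransformer("all-mpnet-base-v2")  # 768-dimensional embeddings
embeddings = encoder.encode(corpus)

# No encoder passed: term embeddings come from the default encoder, so their
# dimensionality differs from the precomputed document embeddings and fitting
# now fails with the informative ValueError instead of a bare shape error
# from the cosine similarity call.
model = KeyNMF(10)
model.fit(corpus, embeddings=embeddings)

# Passing the encoder that produced the embeddings keeps the dimensions consistent.
model = KeyNMF(10, encoder=encoder)
model.fit(corpus, embeddings=embeddings)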
16 changes: 16 additions & 0 deletions turftopic/models/cluster.py
@@ -39,6 +39,12 @@
feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid'
"""

NOT_MATCHING_ERROR = (
    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
    + "Try to initialize the model with the encoder you used for computing the embeddings."
)


def smallest_hierarchical_join(
    topic_vectors: np.ndarray,
@@ -370,6 +376,16 @@ def estimate_components(
            self.vocab_embeddings = self.encoder_.encode(
                self.vectorizer.get_feature_names_out()
            )  # type: ignore
            if (
                self.vocab_embeddings.shape[1]
                != self.topic_vectors_.shape[1]
            ):
                raise ValueError(
                    NOT_MATCHING_ERROR.format(
                        n_dims=self.topic_vectors_.shape[1],
                        n_word_dims=self.vocab_embeddings.shape[1],
                    )
                )
            self.components_ = cluster_centroid_distance(
                self.topic_vectors_,
                self.vocab_embeddings,
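The clustering models gain the same guard on the centroid-based term importance path. A rough sketch of the usage that would now fail fast (assumed API and toy data, not taken from the diff):

# Hypothetical: document embeddings precomputed with one encoder, vocabulary
# re-encoded with the model's own default encoder when estimating components.
from sentence_transformers import SentenceTransformer
from turftopic import ClusteringTopicModel

corpus = [f"document number {i}" for i in range(200)]  # needs enough documents to cluster
encoder = SentenceTransformer("all-mpnet-base-v2")
embeddings = encoder.encode(corpus)

model = ClusteringTopicModel(feature_importance="centroid")
model.fit(corpus, embeddings=embeddings)  # raises ValueError with the message above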
13 changes: 13 additions & 0 deletions turftopic/models/decomp.py
@@ -11,6 +11,12 @@
from turftopic.base import ContextualModel, Encoder
from turftopic.vectorizer import default_vectorizer

NOT_MATCHING_ERROR = (
    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
    + "Try to initialize the model with the encoder you used for computing the embeddings."
)


class SemanticSignalSeparation(ContextualModel):
    """Separates the embedding matrix into 'semantic signals' with
@@ -115,6 +121,13 @@ def fit_transform(
            console.log("Term extraction done.")
            status.update("Encoding vocabulary")
            self.vocab_embeddings = self.encoder_.encode(vocab)
            if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
                raise ValueError(
                    NOT_MATCHING_ERROR.format(
                        n_dims=self.embeddings.shape[1],
                        n_word_dims=self.vocab_embeddings.shape[1],
                    )
                )
            console.log("Vocabulary encoded.")
            status.update("Estimating term importances")
            vocab_topic = self.decomposition.transform(self.vocab_embeddings)
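Likewise for SemanticSignalSeparation, which encodes the vocabulary with its own encoder inside fit_transform. A hedged sketch, assuming the default encoder is 384-dimensional, along with roughly what the formatted message would say:

# Hypothetical: precomputed 768-dimensional document embeddings, vocabulary
# encoded by the model's (different) default encoder.
from sentence_transformers import SentenceTransformer
from turftopic import SemanticSignalSeparation

corpus = [f"document about topic {i}" for i in range(100)]
encoder = SentenceTransformer("all-mpnet-base-v2")
embeddings = encoder.encode(corpus)

model = SemanticSignalSeparation(5)  # encoder deliberately omitted
model.fit_transform(corpus, embeddings=embeddings)
# ValueError: Document embedding dimensionality (768) doesn't match term
# embedding dimensionality (384). Perhaps you are using precomputed embeddings
# but forgot to pass an encoder to your model. Try to initialize the model
# with the encoder you used for computing the embeddings.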
