Skip to content

Commit

Permalink
Small fixes for expected context transformer model (#211)
Browse files Browse the repository at this point in the history
KC and I tested this locally with the notebook; it worked with an older version of scipy. I will file an issue later about the problem related to newer versions of scipy. Thank you, Vivian! Will merge now.
  • Loading branch information
vianxnguyen authored May 18, 2024
1 parent 3303ddf commit 7410b8f
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions convokit/expected_context_framework/expected_context_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def transform(self, utt_vects):
return self._snip(utt_vects * self.term_reprs_full / self.context_s, self.snip_first_dim)

def compute_utt_ranges(self, utt_vects):
- return np.dot(normalize(utt_vects, norm="l1"), self.term_ranges)
+ return np.dot(normalize(np.array(utt_vects), norm="l1"), self.term_ranges)

def transform_context_utts(self, context_utt_vects):
return self._snip(context_utt_vects * self.context_V / self.context_s, self.snip_first_dim)
Expand Down Expand Up @@ -698,17 +698,21 @@ def load(self, dirname):
self.snip_first_dim = meta_dict["snip_first_dim"]
self.cluster_on = meta_dict["cluster_on"]

- self.context_U = np.load(os.path.join(dirname, "context_U.npy"))
+ self.context_U = np.load(os.path.join(dirname, "context_U.npy"), allow_pickle=True)
  self.train_context_reprs = self._snip(self.context_U, self.snip_first_dim)
- self.context_V = np.load(os.path.join(dirname, "context_V.npy"))
+ self.context_V = np.load(os.path.join(dirname, "context_V.npy"), allow_pickle=True)
  self.context_term_reprs = self._snip(self.context_V, self.snip_first_dim)
- self.context_s = np.load(os.path.join(dirname, "context_s.npy"))
- self.context_terms = np.load(os.path.join(dirname, "context_terms.npy"))
- self.terms = np.load(os.path.join(dirname, "terms.npy"))
- self.term_reprs_full = np.matrix(np.load(os.path.join(dirname, "term_reprs.npy")))
+ self.context_s = np.load(os.path.join(dirname, "context_s.npy"), allow_pickle=True)
+ self.context_terms = np.load(os.path.join(dirname, "context_terms.npy"), allow_pickle=True)
+ self.terms = np.load(os.path.join(dirname, "terms.npy"), allow_pickle=True)
+ self.term_reprs_full = np.matrix(
+ np.load(os.path.join(dirname, "term_reprs.npy"), allow_pickle=True)
+ )
  self.term_reprs = self._snip(self.term_reprs_full, self.snip_first_dim)
- self.term_ranges = np.load(os.path.join(dirname, "term_ranges.npy"))
- self.train_utt_reprs = np.load(os.path.join(dirname, "train_utt_reprs.npy"))
+ self.term_ranges = np.load(os.path.join(dirname, "term_ranges.npy"), allow_pickle=True)
+ self.train_utt_reprs = np.load(
+ os.path.join(dirname, "train_utt_reprs.npy"), allow_pickle=True
+ )

try:
km_obj = ClusterWrapper(self.n_clusters)
Expand Down Expand Up @@ -761,7 +765,7 @@ def _get_default_ids(self, ids, n):
def _snip(self, vects, snip_first_dim=True, dim=None):
if dim is None:
dim = vects.shape[1]
- return normalize(vects[:, int(snip_first_dim) : dim])
+ return normalize(np.array(vects[:, int(snip_first_dim) : dim]))


class ClusterWrapper:
Expand Down Expand Up @@ -818,7 +822,7 @@ def load(self, dirname):
self.random_state = meta_dict["random_state"]

self.km_df = pd.read_csv(os.path.join(dirname, "cluster_km_df.tsv"), sep="\t", index_col=0)
- self.cluster_names = np.load(os.path.join(dirname, "cluster_names.npy"))
+ self.cluster_names = np.load(os.path.join(dirname, "cluster_names.npy"), allow_pickle=True)
self.km_model = joblib.load(os.path.join(dirname, "km_model.joblib"))

def dump(self, dirname):
Expand Down

0 comments on commit 7410b8f

Please sign in to comment.