Skip to content

Commit

Permalink
axis mismatch bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
victor7246 committed Jun 12, 2020
1 parent 6aaff1a commit c70a299
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion jointtsmodel/JST.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,11 @@ def conditionalDistribution(self, d, v):

secondFactor[:,s] = ((self.n_dst[d, s, :] + self.gammaVec) / \
(self.n_ds[d, s] + np.sum(self.gammaVec)))

thirdFactor = (self.n_vts[v,:, :] + self.beta) / \
(self.n_ts + self.n_vts.shape[0] * self.beta)
probabilities_ts *= firstFactor[:, np.newaxis]

probabilities_ts *= firstFactor[np.newaxis,:]
probabilities_ts *= secondFactor * thirdFactor
probabilities_ts /= np.sum(probabilities_ts)

Expand Down
4 changes: 2 additions & 2 deletions jointtsmodel/TSWE.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ def conditionalDistribution(self, d, v):
# forthFactor[k,:] = np.exp(np.dot(self.topic_embeddings[k,:],self.word_embeddings[v,:]))/np.sum(np.exp(np.dot(self.topic_embeddings[k,:],self.word_embeddings.T)))

forthFactor = np.exp(np.dot(self.topic_embeddings,self.word_embeddings[v,:]))/np.sum(np.exp(np.dot(self.topic_embeddings,self.word_embeddings.T)),-1)
probabilities_ts *= firstFactor[:, np.newaxis]
probabilities_ts *= firstFactor[np.newaxis,:]
#probabilities_ts *= secondFactor * thirdFactor
probabilities_ts *= secondFactor * ((1-self.lambda_)*thirdFactor + self.lambda_*forthFactor)
probabilities_ts *= secondFactor * ((1-self.lambda_)*thirdFactor + self.lambda_*forthFactor[:,np.newaxis])
probabilities_ts /= np.sum(probabilities_ts)

return probabilities_ts
Expand Down
2 changes: 1 addition & 1 deletion jointtsmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.5"
__version__ = "1.6"

from .JST import JST
from .RJST import RJST
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if __name__ == "__main__":
setup(
name="jointtsmodel",
version="1.5",
version="1.6",
description="jointtsmodel - library of joint topic-sentiment models",
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit c70a299

Please sign in to comment.