sasrec tests #193
base: experimental/sasrec
rectools/models/sasrec.py
Outdated
raise NotImplementedError() | ||
|
||
|
||
class IdEmbeddingsItemNet(ItemNetBase): | ||
""" | ||
Base class for item embeddings. To use more complicated logic then just id embeddings inherit | ||
from this class and pass your custom ItemNet to your model params | ||
Network for item embeddings. |
Network for item embeddings. | |
Network for item embeddings based only on item ids. |
rectools/models/sasrec.py
Outdated
output = self.ff_relu(self.ff_dropout1(self.ff_linear1(seqs))) | ||
fin = self.ff_dropout2(self.ff_linear2(output)) | ||
return fin | ||
|
||
|
||
class SASRecTransformerLayers(TransformerLayersBase): | ||
"""Exactly SASRec authors architecture but with torch MHA realisation""" | ||
""" | ||
Exactly SASRec author's transformer blocks architecture but with torch MHA realisation. |
Exactly SASRec author's transformer blocks architecture but with torch MHA realisation. | |
Exactly SASRec author's transformer blocks architecture but with pytorch Multi-Head Attention realisation. |
rectools/models/sasrec.py
Outdated
Parameters | ||
---------- | ||
n_blocks: int | ||
Number of self-attention blocks. |
Number of self-attention blocks. | |
Number of transformer blocks. |
rectools/models/sasrec.py
Outdated
Parameters | ||
---------- | ||
n_blocks: int | ||
Number of self-attention blocks. |
Number of self-attention blocks. | |
Number of transformer blocks. |
rectools/models/sasrec.py
Outdated
use_causal_attn: bool, default True | ||
If ``True``, causal mask is used in multi-head self-attention. | ||
transformer_layers_type: Type(TransformerLayersBase), default `SasRecTransformerLayers` | ||
Type of transformer layers used for training. |
Type of transformer layers used for training. | |
Type of transformer layers architecture. |
rectools/models/sasrec.py
Outdated
Parameters | ||
---------- | ||
sessions: torch.Tensor | ||
User sessions consisting of items. |
User sessions consisting of items. | |
User sessions in the form of sequences of items ids. |
rectools/models/sasrec.py
Outdated
Parameters | ||
---------- | ||
sessions: torch.Tensor | ||
User sessions consisting of items. |
User sessions consisting of items. | |
User sessions in the form of sequences of items ids. |
rectools/models/sasrec.py
Outdated
Returns | ||
------- | ||
torch.Tensor | ||
User sessions with positional encoding if use_pos_emb is ``True``. |
User sessions with positional encoding if use_pos_emb is ``True``. | |
Encoded user sessions with added positional encoding if `use_pos_emb` is ``True``. |
rectools/models/sasrec.py
Outdated
Parameters | ||
---------- | ||
sessions: torch.Tensor | ||
User sessions consisting of items. |
User sessions consisting of items. | |
User sessions in the form of sequences of items ids. |
Parameters | ||
---------- | ||
sessions: List[List[int]] | ||
User interaction sequences. |
User interaction sequences. | |
User sessions in the form of sequences of items ids. |
tests/models/test_sasrec.py
Outdated
model.fit(dataset=dataset) | ||
users = np.array([10, 30, 40]) | ||
actual = model.recommend(users=users, dataset=dataset, k=3, filter_viewed=filter_viewed) | ||
actual[Columns.Item] = actual[Columns.Item].apply(int) |
This is bad. If items in the dataset were `int`, we need to receive `int` ids from the `recommend` method. If this doesn't work, we need to find where exactly we are breaking them and fix it. Is it in `IdMap`, when creating a new id map with the string `["PAD"]` and then adding the other ids as `int`?
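The suspicion above can be illustrated with plain pandas. This is only a sketch of the general mechanism, not a reproduction of what `IdMap` does internally; the `external_ids` series and its values are hypothetical.

```python
import pandas as pd

# Hypothetical external ids: a string padding token followed by integer item ids.
external_ids = pd.Series(["PAD", 11, 12, 13])

# Mixing str and int in one Series forces the generic object dtype.
print(external_ids.dtype)

# Dropping the padding token does not restore an integer dtype:
# the remaining values are Python ints, but the container stays object.
item_ids = external_ids.iloc[1:]
print(item_ids.dtype)
print(type(item_ids.iloc[0]))
```

If something similar happens in the model, the values coming out of `recommend` would be correct ints stored under an `object` dtype.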
Let's not use `apply(int)` or `astype(int)`, because they can modify float values. Pass `check_dtype=False` to `assert_frames_equal` or `assert_series_equal` instead.
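A minimal sketch of the difference, using `pandas.testing.assert_series_equal` (the pandas function; the project's own assertion helpers may differ). The series values are made up for illustration.

```python
import pandas as pd
from pandas.testing import assert_series_equal

# Hypothetical recommendation ids: correct int values stored with object dtype.
actual = pd.Series([11, 12, 13], dtype=object)
expected = pd.Series([11, 12, 13])

# The strict comparison fails only because of the dtype mismatch.
try:
    assert_series_equal(actual, expected)
    strict_failed = False
except AssertionError:
    strict_failed = True

# With check_dtype=False the values are still compared, the dtype is not.
assert_series_equal(actual, expected, check_dtype=False)

# Why apply(int) is risky in a test: it silently truncates a wrong float value,
# so a bug that produced 11.7 would be hidden instead of reported.
wrong = pd.Series([11.7, 12.0, 13.0])
coerced = wrong.apply(int)
print(strict_failed, coerced.tolist())
```

Relaxing the dtype check keeps the test honest about values while tolerating the container type, whereas coercing first can turn a wrong result into a passing one.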
So we still expect incorrect id types in reco? Could you explain the reason?
We receive int item ids, but the pd.Series has dtype object. So the values are correct and they are not floats, but the dtype is not int.
Let's convert them back to int in the code (I mean in the main code, not the test).
If I'm a user and I give ints, I expect ints in the reco. Besides, object is much slower for future iterations. Also, there may be problems with metric calculation, since pandas will fail trying to merge int and object.
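One possible shape for such a conversion, as a sketch only: the helper name `restore_id_dtype` and the column names are hypothetical, and the real fix may belong closer to where the ids are mapped. It relies on `Series.infer_objects`, which converts an object column to a better dtype only when the values allow it without loss.

```python
import pandas as pd


def restore_id_dtype(reco: pd.DataFrame, column: str) -> pd.DataFrame:
    """Return a copy of `reco` with `column` converted from object to its natural dtype.

    Hypothetical helper for illustration; it is not part of the library.
    """
    result = reco.copy()
    # infer_objects is lossless: ints become int64, while a column that really
    # contains mixed types (for example strings and ints) stays object.
    result[column] = result[column].infer_objects()
    return result


# Made-up recommendations table with int ids stored under object dtype.
reco = pd.DataFrame(
    {
        "user_id": [10, 10, 30],
        "item_id": pd.Series([11, 12, 13], dtype=object),
        "rank": [1, 2, 1],
    }
)
fixed = restore_id_dtype(reco, "item_id")
print(reco["item_id"].dtype, fixed["item_id"].dtype)
```

Because the conversion never rounds or truncates, it avoids the concern raised earlier about `astype(int)` altering float values.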
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal | ||
|
||
|
||
@pytest.mark.filterwarnings("ignore::pytorch_lightning.utilities.warnings.PossibleUserWarning") |
Could you please add comments for this and the next line (what the warnings are and why we ignore them)?
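For reference, the mark uses the `action::category` filter format. The stdlib-only sketch below shows the same effect and where an explanatory comment could sit. The `PossibleUserWarning` class here is a local stand-in, and the warning message and the stated reason for ignoring it are hypothetical, not taken from the PR.

```python
import warnings


class PossibleUserWarning(UserWarning):
    """Local stand-in for the Lightning warning class, defined only for this sketch."""


def noisy_fit() -> str:
    # Hypothetical warning raised during training on a tiny test dataset.
    warnings.warn("hypothetical hint about the training setup", PossibleUserWarning)
    return "fitted"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Ignore PossibleUserWarning: in this sketch it is assumed to be an advisory
    # hint that is irrelevant for a small test fixture.
    warnings.filterwarnings("ignore", category=PossibleUserWarning)
    result = noisy_fit()

print(result, len(caught))
```

A comment of this kind next to each `filterwarnings` mark tells the reader which warning is silenced and why that is safe in the test.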
Merge with updated branch
added sasrec test