Hello,
Thanks for a great book. I've been working through the examples and it's really helpful. One issue I noticed in the agnews classifier: when I was running the predict_category function, I got an error when trying to predict one of the samples in the sports category:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-18-1b3d5180f20c> in <module>
3 print("="*30)
4 for sample in sample_group:
----> 5 pred = predict_category(sample, classifier, dc.vectorizer, dc._train_ds.max_seq_length+1)
6 print(f"Prediction: {pred['category']} (p={pred['prob']:0.2f})")
7 print(f"\t + Sample: {sample}")
<ipython-input-16-75dc8b0470ad> in predict_category(title, classifer, vectorizer, max_length)
12 """
13 title = preprocess_text(title)
---> 14 vectorized_title = torch.tensor(vectorizer.vectorize(title, max_length))
15
16 # add batch dim so you have a batch of size 1
~/nlpbook/ag/ag/vectorizer.py in vectorize(self, title, vector_len)
33
34 out_vector = np.zeros(vector_len, dtype=np.int64)
---> 35 out_vector[:len(vector)] = vector
36 out_vector[len(vector):] = self.title_vocab.mask_idx
37
ValueError: cannot copy sequence with size 22 to array axis with dimension 21
The max_seq_length in the training set was 20, and in this line in the text, we pass in max_seq_length+1 effectively making the sequence 21 tokens long:
However, the required length is 22, so when I changed the function call to pass max_seq_length+2, it worked. This raises a more general question:
When the test data's max_seq_length could potentially be larger than that of the training data, what do we do? How do we handle that? Do we just pass in a larger value for the max_seq_length? Even if we do that, how do we foresee how big of a value we might need?
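To make the question concrete, here is a minimal sketch of one option I can think of: clip any title that is longer than the target length instead of letting the copy fail. This is not the book's code. `pad_or_truncate` is a hypothetical helper, the index values are made up, and the idea that the two extra positions come from begin/end-of-sequence markers is my assumption. It uses plain lists in place of the NumPy array so it runs on its own.

```python
def pad_or_truncate(indices, vector_len, mask_idx=0):
    """Return a list of exactly `vector_len` indices.

    Hypothetical helper, not part of the book's vectorizer.
    Longer inputs are clipped; shorter ones are padded with `mask_idx`.
    """
    if vector_len < 0:
        raise ValueError("vector_len must be non-negative")
    # Clip first, so the copy can never exceed the target length.
    clipped = list(indices[:vector_len])
    # Pad whatever space is left with the mask index.
    return clipped + [mask_idx] * (vector_len - len(clipped))


# Made-up indices: 20 tokens plus assumed begin (1) and end (2) markers,
# which gives 22 positions, one more than max_seq_length + 1 = 21.
sample = [1] + list(range(10, 30)) + [2]

fixed = pad_or_truncate(sample, 21)           # clipped to 21 positions
short = pad_or_truncate([1, 10, 11, 2], 21)   # padded up to 21 positions
```

The obvious downside is that clipping silently drops the tail of a long title (here, the assumed end-of-sequence marker), so I'm not sure whether this or a larger fixed length is the intended approach.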
Thanks.