Improve tokenizer validation to make it memory safe
We now do careful out-of-bounds checking and validate the token packing,
to avoid out-of-bounds accesses when loading malformed files (assuming
assertions are not compiled out, of course).
zeux committed Apr 27, 2024
1 parent ef1688a commit 6048634
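
For illustration, here is a minimal standalone sketch of the validation pattern this commit applies: the vocabulary arrives as one buffer of vocab_size null-terminated strings packed back to back, and asserting that the final byte is '\0' guarantees every strlen terminates inside the buffer, so the per-token offset checks are then safe. The helper name and the MAX_TOKEN_LENGTH value below are assumptions made for the sketch, not the project's actual code.

/* Sketch only: mirrors the bounds checks added in this commit.
   MAX_TOKEN_LENGTH is a placeholder value; the real limit lives elsewhere. */
#include <assert.h>
#include <stdio.h>
#include <string.h>

#define MAX_TOKEN_LENGTH 512

/* Validates that `tokens` holds exactly `vocab_size` null-terminated
   strings packed back to back in `total_length` bytes. */
static void validate_packed_tokens(const char* tokens, int vocab_size, int total_length) {
    /* the final byte must be '\0' so strlen can never scan past the buffer */
    assert(total_length > 0 && tokens[total_length - 1] == '\0');

    int offset = 0;
    for (int i = 0; i < vocab_size; ++i) {
        int length = (int)strlen(tokens + offset);
        /* each token must respect the length cap and stay inside the buffer */
        assert(length <= MAX_TOKEN_LENGTH && offset + length + 1 <= total_length);
        offset += length + 1;
    }

    /* no trailing bytes: the buffer holds exactly vocab_size strings */
    assert(offset == total_length);
}

int main(void) {
    const char packed[] = "a\0bc\0<eos>"; /* 3 tokens; sizeof counts the final '\0' */
    validate_packed_tokens(packed, 3, (int)sizeof(packed));
    printf("packed token buffer ok\n");
    return 0;
}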
Showing 3 changed files with 13 additions and 10 deletions.
src/run.c: 2 changes (1 addition, 1 deletion)
@@ -131,7 +131,7 @@ void build_tokenizer(struct Tokenizer* t, struct Tensors* tensors, int vocab_siz
     int bos_id = atoi(tensors_metadata(tensors, "bos_token_id"));
     int eos_id = atoi(tensors_metadata(tensors, "eos_token_id"));
 
-    tokenizer_init(t, tokens, scores, bos_id, eos_id, vocab_size);
+    tokenizer_init(t, tokens, scores, bos_id, eos_id, vocab_size, tensor->shape[0]);
 }
 
 size_t count_bytes(struct Tensors* tensors, const char* prefix, const char* filter, size_t* out_params) {
src/tokenizer.c: 19 changes (11 additions, 8 deletions)
@@ -18,27 +18,30 @@ static int str_lookup(char* str, struct TokenIndex* sorted_vocab, int vocab_size
     return res != NULL ? res->id : -1;
 }
 
-void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size) {
+void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length) {
     tokenizer->vocab_size = vocab_size;
     tokenizer->bos_id = bos_id;
     tokenizer->eos_id = eos_id;
     tokenizer->eot_id = -1;
 
     // malloc space to hold the scores and the strings
     tokenizer->vocab = (char**)malloc(vocab_size * sizeof(char*));
     tokenizer->sorted_vocab = (struct TokenIndex*)malloc(vocab_size * sizeof(struct TokenIndex));
+    tokenizer->vocab_scores = scores;
+
+    assert(tokens[total_length - 1] == '\0');
+    int token_offset = 0;
 
-    // TODO: validate tokens are null terminated
     for (int i = 0; i < vocab_size; ++i) {
-        tokenizer->vocab[i] = tokens;
-        tokenizer->sorted_vocab[i].str = tokens;
+        tokenizer->vocab[i] = tokens + token_offset;
+        tokenizer->sorted_vocab[i].str = tokens + token_offset;
         tokenizer->sorted_vocab[i].id = i;
 
-        assert(strlen(tokens) <= MAX_TOKEN_LENGTH);
-        tokens += strlen(tokens) + 1;
+        int token_length = strlen(tokens + token_offset);
+        assert(token_length <= MAX_TOKEN_LENGTH && token_offset + token_length + 1 <= total_length);
+        token_offset += token_length + 1;
     }
 
-    tokenizer->vocab_scores = scores;
+    assert(token_offset == total_length);
 
     qsort(tokenizer->sorted_vocab, vocab_size, sizeof(struct TokenIndex), compare_tokens);
src/tokenizer.h: 2 changes (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ enum TokenizerFlags {
     TF_ENCODE_EOS = 1 << 1,
 };
 
-void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size);
+void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length);
 void tokenizer_free(struct Tokenizer* tokenizer);
 
 int tokenizer_bound(int bytes);
