From 60486347f72ebf990c7410e66b24f2e180d212d8 Mon Sep 17 00:00:00 2001
From: Arseny Kapoulkine
Date: Sat, 27 Apr 2024 12:27:26 -0700
Subject: [PATCH] Improve tokenizer validation to make it memory safe

We now do careful out-of-bounds checking and validate the token packing
to avoid out-of-bounds accesses for malformed files (assuming assertions
are not compiled out, of course).
---
 src/run.c       |  2 +-
 src/tokenizer.c | 19 +++++++++++--------
 src/tokenizer.h |  2 +-
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/run.c b/src/run.c
index 70dabca..53286f5 100644
--- a/src/run.c
+++ b/src/run.c
@@ -131,7 +131,7 @@ void build_tokenizer(struct Tokenizer* t, struct Tensors* tensors, int vocab_siz
 	int bos_id = atoi(tensors_metadata(tensors, "bos_token_id"));
 	int eos_id = atoi(tensors_metadata(tensors, "eos_token_id"));
 
-	tokenizer_init(t, tokens, scores, bos_id, eos_id, vocab_size);
+	tokenizer_init(t, tokens, scores, bos_id, eos_id, vocab_size, tensor->shape[0]);
 }
 
 size_t count_bytes(struct Tensors* tensors, const char* prefix, const char* filter, size_t* out_params) {
diff --git a/src/tokenizer.c b/src/tokenizer.c
index 8fa58a2..94eacb7 100644
--- a/src/tokenizer.c
+++ b/src/tokenizer.c
@@ -18,27 +18,30 @@ static int str_lookup(char* str, struct TokenIndex* sorted_vocab, int vocab_size
 	return res != NULL ? res->id : -1;
 }
 
-void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size) {
+void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length) {
 	tokenizer->vocab_size = vocab_size;
 	tokenizer->bos_id = bos_id;
 	tokenizer->eos_id = eos_id;
 	tokenizer->eot_id = -1;
 
-	// malloc space to hold the scores and the strings
 	tokenizer->vocab = (char**)malloc(vocab_size * sizeof(char*));
 	tokenizer->sorted_vocab = (struct TokenIndex*)malloc(vocab_size * sizeof(struct TokenIndex));
+	tokenizer->vocab_scores = scores;
+
+	assert(tokens[total_length - 1] == '\0');
+	int token_offset = 0;
 
-	// TODO: validate tokens are null terminated
 	for (int i = 0; i < vocab_size; ++i) {
-		tokenizer->vocab[i] = tokens;
-		tokenizer->sorted_vocab[i].str = tokens;
+		tokenizer->vocab[i] = tokens + token_offset;
+		tokenizer->sorted_vocab[i].str = tokens + token_offset;
 		tokenizer->sorted_vocab[i].id = i;
-		assert(strlen(tokens) <= MAX_TOKEN_LENGTH);
-		tokens += strlen(tokens) + 1;
+		int token_length = strlen(tokens + token_offset);
+		assert(token_length <= MAX_TOKEN_LENGTH && token_offset + token_length + 1 <= total_length);
+		token_offset += token_length + 1;
 	}
 
-	tokenizer->vocab_scores = scores;
+	assert(token_offset == total_length);
 
 	qsort(tokenizer->sorted_vocab, vocab_size, sizeof(struct TokenIndex), compare_tokens);
 
diff --git a/src/tokenizer.h b/src/tokenizer.h
index e6fbb4d..86ad065 100644
--- a/src/tokenizer.h
+++ b/src/tokenizer.h
@@ -24,7 +24,7 @@ enum TokenizerFlags {
 	TF_ENCODE_EOS = 1 << 1,
 };
 
-void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size);
+void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length);
 void tokenizer_free(struct Tokenizer* tokenizer);
 
 int tokenizer_bound(int bytes);
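
Note (not part of the patch): the validation the patch adds to tokenizer_init can be illustrated in isolation. The standalone sketch below approximates those checks for a packed token buffer, i.e. vocab_size NUL-terminated strings laid back to back in total_length bytes: the buffer must end in a terminator (so strlen never runs past it), each token must fit MAX_TOKEN_LENGTH, and the last token must end exactly at total_length. The helper name validate_packed_tokens, the MAX_TOKEN_LENGTH value, and the return-code style (instead of the patch's asserts) are assumptions for this sketch, not code from the repository.

#include <stdio.h>
#include <string.h>

#define MAX_TOKEN_LENGTH 512 /* placeholder; the real limit is defined by the project */

/* Returns 1 if `tokens` holds exactly `vocab_size` NUL-terminated strings
 * packed back to back into `total_length` bytes, each at most
 * MAX_TOKEN_LENGTH characters long; returns 0 otherwise. */
static int validate_packed_tokens(const char* tokens, int vocab_size, int total_length) {
	if (total_length <= 0 || tokens[total_length - 1] != '\0')
		return 0; /* buffer must end on a terminator so strlen below cannot run past it */

	int token_offset = 0;
	for (int i = 0; i < vocab_size; ++i) {
		if (token_offset >= total_length)
			return 0; /* buffer holds fewer strings than vocab_size claims */
		int token_length = (int)strlen(tokens + token_offset);
		if (token_length > MAX_TOKEN_LENGTH || token_offset + token_length + 1 > total_length)
			return 0;
		token_offset += token_length + 1;
	}
	return token_offset == total_length; /* no trailing bytes allowed after the last token */
}

int main(void) {
	const char packed[] = "<s>\0</s>\0hello"; /* 3 packed tokens, 15 bytes including the final NUL */
	printf("%d\n", validate_packed_tokens(packed, 3, (int)sizeof(packed))); /* prints 1 */
	printf("%d\n", validate_packed_tokens(packed, 4, (int)sizeof(packed))); /* prints 0 */
	return 0;
}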