diff --git a/train.py b/train.py
index f692ad8..ab50d6b 100644
--- a/train.py
+++ b/train.py
@@ -14,12 +14,12 @@
 ###########
 # All with chunk size of 1024
 # ?.?MB: Target dictionary count 100,000, Prune 10,000,000
-# 8.5MB: Target dictionary count 25,000, Prune 40,000
-# 3.6MB: Target dictionary count 10,000, Prune 40,000
+# 8.5MB: Target dictionary count 25,000, Prune 10,000,000
+# 3.6MB: Target dictionary count 10,000, Prune 10,000,000
 
 PRUNE_FREQUENCY = 10 * 1000 * 1000 # Every this many document positions
 CHUNK_SIZE = 1024 # 1KB per chunk
-TARGET_DICTIONARY_COUNT = 100 * 1000
+TARGET_DICTIONARY_COUNT = 10 * 1000
 
 # Define a flag to indicate when an interrupt has been caught
 interrupted = False
@@ -97,6 +97,15 @@ def update_scores(trie_store, path, context_slug):
         trie_store['scores'][path][context_slug] = 0
     trie_store['scores'][path][context_slug] += 1
 
+def save_position(progress_file, current_position, trie_store):
+    # Every now and then save our progress.
+    print(f"Saving the current position of %s" % current_position)
+    # Save the current progress (file position)
+    with open(progress_file, 'w') as f:
+        f.write(str(current_position))
+    print(f"Passed %s positions. Time to optimize before continuing..." % PRUNE_FREQUENCY)
+    flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)
+
 # Define a main function to orchestrate the training process
 def main():
     global prune_position_marker
@@ -162,12 +171,6 @@ def main():
 
         words = row.split()
 
-        # Every now and then save our progress.
-        print(f"Saving the current position of %s" % current_position)
-        # Save the current progress (file position)
-        with open(progress_file, 'w') as f:
-            f.write(str(current_position))
-
         if interrupted:
             print("Saving data. Script will terminate when done.")
             save_trie_store(trie_store)
@@ -175,9 +178,8 @@ def main():
 
         # Every now and then, prune unpopular entries.
        if (current_position - prune_position_marker > PRUNE_FREQUENCY):
+            save_position(progress_file, current_position, trie_store)
             prune_position_marker = current_position
-            print(f"Passed %s positions. Time to optimize before continuing..." % PRUNE_FREQUENCY)
-            flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)
 
         # Process words three at a time with shifting window
         for i in range(len(words) - 2):
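
The net behavioral change in this patch is that the file position is now checkpointed only when a prune cycle runs, instead of on every row. Below is a minimal sketch of the resulting loop shape under stated assumptions: read_rows(), the position arithmetic, and training_loop() are invented for illustration; only save_position, PRUNE_FREQUENCY, and CHUNK_SIZE come from train.py.

# Illustrative sketch only; not code from train.py.

PRUNE_FREQUENCY = 10 * 1000 * 1000
CHUNK_SIZE = 1024

def read_rows(corpus_path):
    # Hypothetical chunked reader; train.py's actual reader is not shown in this diff.
    with open(corpus_path, 'r') as f:
        while True:
            chunk = f.read(CHUNK_SIZE)
            if not chunk:
                break
            yield chunk

def training_loop(corpus_path, progress_file, trie_store, save_position):
    prune_position_marker = 0
    current_position = 0
    for row in read_rows(corpus_path):
        current_position += len(row)
        words = row.split()  # n-gram counting over `words` would continue below
        # After this patch, progress is persisted together with the pruning
        # pass, so at most roughly PRUNE_FREQUENCY positions of progress can
        # be lost on an unclean exit (versus a write on every row before).
        if current_position - prune_position_marker > PRUNE_FREQUENCY:
            save_position(progress_file, current_position, trie_store)
            prune_position_marker = current_position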