Skip to content

Commit

Permalink
Fixed some stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
AG committed Mar 4, 2024
1 parent 945d4db commit eb5fbd5
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
###########
# All with chunk size of 1024
# ?.?MB: Target dictionary count 100,000, Prune 10,000,000
# 8.5MB: Target dictionary count 25,000, Prune 40,000
# 3.6MB: Target dictionary count 10,000, Prune 40,000
# 8.5MB: Target dictionary count 25,000, Prune 10,000,000
# 3.6MB: Target dictionary count 10,000, Prune 10,000,000

PRUNE_FREQUENCY = 10 * 1000 * 1000 # Every this many document positions
CHUNK_SIZE = 1024 # 1KB per chunk
TARGET_DICTIONARY_COUNT = 100 * 1000
TARGET_DICTIONARY_COUNT = 10 * 1000

# Define a flag to indicate when an interrupt has been caught
interrupted = False
Expand Down Expand Up @@ -97,6 +97,15 @@ def update_scores(trie_store, path, context_slug):
trie_store['scores'][path][context_slug] = 0
trie_store['scores'][path][context_slug] += 1

def save_position(progress_file, current_position, trie_store):
    """Record the current position on disk, then optimize the trie store.

    Writes ``str(current_position)`` to ``progress_file`` (opened in ``'w'``
    mode, so previous contents are replaced), then calls
    ``flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)``.

    Args:
        progress_file: Target handed to ``open()`` for the position record.
        current_position: Position value to report and persist.
        trie_store: Store forwarded unchanged to ``flatten_to_dictionary``.
    """
    # Every now and then save our progress.
    # Interpolate with a real f-string placeholder: the previous
    # f-prefix-plus-% form would raise TypeError if the value were a tuple.
    print(f"Saving the current position of {current_position}")
    # Save the current progress (file position)
    with open(progress_file, 'w') as f:
        f.write(str(current_position))
    print(f"Passed {PRUNE_FREQUENCY} positions. Time to optimize before continuing...")
    flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)

# Define a main function to orchestrate the training process
def main():
global prune_position_marker
Expand Down Expand Up @@ -162,22 +171,15 @@ def main():

words = row.split()

# Every now and then save our progress.
print(f"Saving the current position of %s" % current_position)
# Save the current progress (file position)
with open(progress_file, 'w') as f:
f.write(str(current_position))

if interrupted:
print("Saving data. Script will terminate when done.")
save_trie_store(trie_store)
sys.exit(0)

# Every now and then, prune unpopular entries.
if (current_position - prune_position_marker > PRUNE_FREQUENCY):
save_position(progress_file, current_position, trie_store)
prune_position_marker = current_position
print(f"Passed %s positions. Time to optimize before continuing..." % PRUNE_FREQUENCY)
flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)

# Process words three at a time with shifting window
for i in range(len(words) - 2):
Expand Down

0 comments on commit eb5fbd5

Please sign in to comment.