Commit a48a2aa: first commit
jessieziyun committed Mar 22, 2021 (0 parents)
Showing 14 changed files with 670 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .flaskenv
@@ -0,0 +1,2 @@
FLASK_APP=poetry_generator
FLASK_ENV=development
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
__pycache__/
.DS_Store
Screenshots/
6 changes: 6 additions & 0 deletions app.py
@@ -0,0 +1,6 @@
from poetry_generator import create_app, socketio

app = create_app()

if __name__ == '__main__':
    socketio.run(app)
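
Running python app.py starts the development server through socketio.run, which wraps Flask's built-in server with WebSocket support; the .flaskenv above also lets flask run locate the app, though that path would bypass socketio.run. A hedged sketch of an equivalent entry point with an explicit host and port (the values are illustrative, not from this commit):

from poetry_generator import create_app, socketio

app = create_app()

if __name__ == '__main__':
    # Explicit host/port are assumptions; socketio.run accepts both.
    socketio.run(app, host='127.0.0.1', port=5000)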
14 changes: 14 additions & 0 deletions poetry_generator/__init__.py
@@ -0,0 +1,14 @@
from flask import Flask
from flask_socketio import SocketIO

socketio = SocketIO()

def create_app(config_file='settings.py'):
    app = Flask(__name__, static_url_path="/static", static_folder="static")
    app.config.from_pyfile(config_file)

    from .routes import generator
    app.register_blueprint(generator)

    socketio.init_app(app)
    return app
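
create_app reads its configuration from settings.py and registers a generator blueprint from poetry_generator/routes.py; both files are part of this commit but are not rendered in this view. A minimal hypothetical sketch of such a blueprint (the route and template name are assumptions, not the actual file contents):

from flask import Blueprint, render_template

generator = Blueprint('generator', __name__)

@generator.route('/')
def index():
    # Serve the front-end page that talks to the SocketIO server.
    return render_template('index.html')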
68 changes: 68 additions & 0 deletions poetry_generator/conditional_generator.py
@@ -0,0 +1,68 @@
import fire
import json
import os
import numpy as np
import tensorflow as tf

from poetry_generator import model
from poetry_generator import sample
from poetry_generator import encoder

class AI:
    def generate_text(self, input_text):
        model_name = 'poet'
        seed = None
        nsamples = 1
        batch_size = 1
        length = 50
        temperature = 1
        top_k = 40
        top_p = 0.0

        self.response = ""

        if batch_size is None:
            batch_size = 1
        assert nsamples % batch_size == 0

        enc = encoder.get_encoder(model_name)
        hparams = model.default_hparams()
        cur_path = os.path.dirname(__file__) + "/models" + "/" + model_name
        with open(cur_path + "/hparams.json") as f:
            hparams.override_from_dict(json.load(f))

        if length is None:
            length = hparams.n_ctx // 2
        elif length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

        with tf.Session(graph=tf.Graph()) as sess:
            context = tf.placeholder(tf.int32, [batch_size, None])
            np.random.seed(seed)
            tf.set_random_seed(seed)
            output = sample.sample_sequence(
                hparams=hparams, length=length,
                context=context,
                batch_size=batch_size,
                temperature=temperature, top_k=top_k, top_p=top_p
            )

            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(cur_path)
            saver.restore(sess, ckpt)

            context_tokens = enc.encode(input_text)
            print("Title: " + input_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    self.response = text

        return self.response

ai = AI()
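
For illustration, a hedged usage sketch of the conditional generator above. It assumes the fine-tuned 'poet' checkpoint (hparams.json, encoder.json, vocab.bpe, and model weights) is present under poetry_generator/models/poet; the title string is invented:

from poetry_generator.conditional_generator import ai

# generate_text seeds the model with the title and returns up to
# 50 tokens of continuation, sampled at temperature 1 with top-k 40.
poem = ai.generate_text("Ode to a Quiet Morning")
print(poem)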
118 changes: 118 additions & 0 deletions poetry_generator/encoder.py
@@ -0,0 +1,118 @@
"""Byte pair encoding utilities"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))

def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs

class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(model_name):
    cur_path = os.path.dirname(__file__) + "/models" + "/" + model_name
    with open(cur_path + '/encoder.json', 'r') as f:
        encoder = json.load(f)
    with open(cur_path + '/vocab.bpe', 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
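
A hedged sketch of the encoder round trip, assuming the same 'poet' model directory; the sample text is invented. Because bytes_to_unicode is a reversible byte-level mapping, decode(encode(text)) recovers the original string:

from poetry_generator import encoder

enc = encoder.get_encoder('poet')
ids = enc.encode("Shall I compare thee")          # text -> BPE token ids
assert enc.decode(ids) == "Shall I compare thee"  # lossless round trip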
60 changes: 60 additions & 0 deletions poetry_generator/generator.py
@@ -0,0 +1,60 @@
import fire
import json
import os
import numpy as np
import tensorflow as tf

from poetry_generator import model
from poetry_generator import sample
from poetry_generator import encoder

class AI:
    def generate_poetry(self):
        model_name = 'poet'
        seed = None
        nsamples = 1
        batch_size = 1
        length = 50
        temperature = 0.75
        top_k = 40
        top_p = 0.0

        self.response = ""

        enc = encoder.get_encoder(model_name)
        cur_path = os.path.dirname(__file__) + "/models" + "/" + model_name
        hparams = model.default_hparams()
        with open(cur_path + '/hparams.json') as f:
            hparams.override_from_dict(json.load(f))

        if length is None:
            length = hparams.n_ctx
        elif length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

        with tf.Session(graph=tf.Graph()) as sess:
            np.random.seed(seed)
            tf.set_random_seed(seed)

            output = sample.sample_sequence(
                hparams=hparams, length=length,
                start_token=enc.encoder['<|endoftext|>'],
                batch_size=batch_size,
                temperature=temperature, top_k=top_k, top_p=top_p
            )[:, 1:]

            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(cur_path)
            saver.restore(sess, ckpt)

            generated = 0
            while nsamples == 0 or generated < nsamples:
                out = sess.run(output)
                for i in range(batch_size):
                    generated += batch_size
                    text = enc.decode(out[i])
                    self.response = text

        return self.response

ai = AI()
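
A matching sketch for the unconditional generator; unlike generate_text it starts sampling from the <|endoftext|> token rather than a user prompt, so it takes no input:

from poetry_generator.generator import ai

# Samples up to 50 tokens from the fine-tuned model at temperature 0.75.
print(ai.generate_poetry())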
(The remaining changed files in this commit are not rendered in this view.)
