-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
160 lines (126 loc) · 4.56 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import streamlit as st
import torch
from base.predictor import Predictor
from third_party.google_trans import GoogleTranslator
from transformer import Transformer
from utils import *
import dill
import argparse
import os
import pathlib
# Saved so the Windows checkpoint-loading workaround in __main__ can restore
# pathlib.PosixPath after it has (possibly) been patched.
posix_backup = pathlib.PosixPath
# The run folder (containing config.yaml, best.pt, src_field.pt, trg_field.pt)
# is passed as a positional CLI argument, e.g.:
#   streamlit run app.py -- runs/<experiment>
parser = argparse.ArgumentParser()
parser.add_argument("runs_path", type=str, help="Path to training result folder (e.g. runs/...)")
args = parser.parse_args()
# Module-level state shared between run_ui() reruns:
# results of the last "Translate" click (None = nothing to show yet).
model_output_text = None
gg_output_text = None
# Defaults only; both are overwritten from config.yaml inside load_model().
max_len = 200
beam_size = 1
@st.cache_resource
def load_model():
    '''
    Load the Transformer predictor from the run folder given on the command
    line (cached once per process by Streamlit).

    Side effects: overwrites the module-level ``max_len`` and ``beam_size``
    with the values stored in the run's config.yaml, so the UI uses the same
    limits the model was trained/evaluated with.

    Returns:
        Predictor: wraps the model plus source/target fields for inference.
    '''
    global max_len, beam_size
    config_path = os.path.join(args.runs_path, 'config.yaml')
    ckpt_path = os.path.join(args.runs_path, 'best.pt')
    src_field_path = os.path.join(args.runs_path, 'src_field.pt')
    trg_field_path = os.path.join(args.runs_path, 'trg_field.pt')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config_dict = load_config(config_path)
    max_len = config_dict['DATA']['MAX_LEN']
    beam_size = config_dict['PREDICTOR']['BEAM_SIZE']

    # The torchtext Field objects were pickled with dill, so torch.load needs
    # the same pickle module. map_location makes a GPU-trained artifact load
    # on a CPU-only machine instead of raising.
    src_field = torch.load(src_field_path, pickle_module=dill, map_location=device)
    trg_field = torch.load(trg_field_path, pickle_module=dill, map_location=device)
    print("Load fields successfully!")

    src_vocab_size = len(src_field.vocab)
    trg_vocab_size = len(trg_field.vocab)
    model = Transformer(config_path=config_path,
                        src_vocab_size=src_vocab_size,
                        trg_vocab_size=trg_vocab_size)
    # Same map_location fix for the checkpoint weights themselves; the
    # original loaded them with the default location, which fails on
    # CPU-only hosts for CUDA-saved checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    predictor = Predictor(model, src_field, trg_field, device=device)
    return predictor
@st.cache_resource
def load_gg_trans():
    '''
    Build the Google Translate client exactly once per process
    (memoized by Streamlit's resource cache) and return it.
    '''
    return GoogleTranslator()
def setup_page():
    '''
    Configure the Streamlit page: title, wide layout, expanded sidebar.
    Must run before any other Streamlit call in the script.
    '''
    page_options = {
        'page_title': "Transformer for Machine Translation - CNTN20 - Vu & Thien",
        'layout': "wide",
        'initial_sidebar_state': "expanded",
    }
    st.set_page_config(**page_options)
def show_header():
    '''
    Render the page header: project title, the two authors' credits, and a
    centered banner image.
    '''
    # Header (raw HTML so the headings can be center-aligned).
    st.markdown('''
    <h2 align="center">
        Statistical Learning - Final Project - CNTN20
        </br>
        Transformer for Machine Translation (English to Vietnamese)
    </h2>
    <h5 align="center">
        Hoàng Trọng Vũ - 20120025
        </br>
        Trần Hữu Thiên - 20120584
    </h5>
    ''', unsafe_allow_html=True)
    st.markdown('<br>', unsafe_allow_html=True)
    # Three columns are used purely for centering: the banner goes in the
    # narrow middle column, the outer two act as margins.
    col1, col2, col3 = st.columns([2, 1, 2])
    with col2:
        st.image('./images/app/banner.png', use_column_width=True)
def run_ui():
    '''
    Render the translation UI: an input column on the left, and a results
    column on the right showing both the local model's output and Google
    Translate's output for comparison.

    Reads and writes the module-level ``model_output_text`` /
    ``gg_output_text`` so results survive Streamlit reruns within a session.
    '''
    import time  # stdlib; used only for the retry back-off below

    global model_output_text, gg_output_text, max_len, beam_size
    # Outer columns are just margins; only the middle two hold widgets.
    _, left_column, right_column, _ = st.columns([1, 2, 2, 1])
    with left_column:
        st.write("**Input sentence (in English):**")
        input_text = st.text_area("",
                                  height=150,
                                  max_chars=max_len * 20,
                                  label_visibility="collapsed")
        # Button
        if st.button('Translate'):
            if input_text.strip() == '':
                st.warning('Please input a sentence!')
                model_output_text = None
                gg_output_text = None
                return
            input_text = input_text.strip()
            # Long inputs are split into sentence-sized chunks so each piece
            # fits the model's maximum sequence length.
            splitted_sens = split_text_by_sens(input_text, max_len=max_len)
            model_outputs = [
                load_model()(preprocess_text(sen), max_len=max_len, beam_size=beam_size)
                for sen in splitted_sens
            ]
            model_output_text = postprocess_text(' '.join(model_outputs))
            # Google Translate is flaky, so retry — but with a bound. The
            # original `while True` + bare `except: pass` could spin forever
            # on a persistent failure and also swallowed KeyboardInterrupt.
            gg_output_text = None
            for _attempt in range(5):
                try:
                    gg_output_text = load_gg_trans().translate(
                        input_text, lang_src='en', lang_tgt='vi').lower()
                    break
                except Exception:
                    time.sleep(0.5)
    with right_column:
        st.write("**Output from model:**")
        if model_output_text is not None:
            st.write(model_output_text)
        st.write("**Output from Google Translate:**")
        if gg_output_text is not None:
            st.write(gg_output_text)
if __name__ == '__main__':
    setup_page()
    show_header()
    # Workaround for loading Linux-pickled checkpoints on Windows, which
    # otherwise raises:
    #   NotImplementedError: cannot instantiate 'PosixPath' on your system
    try:
        # pathlib.PosixPath = pathlib.WindowsPath
        load_model()
    except Exception as e:
        # Best-effort warm-up: log the failure but still render the UI
        # (run_ui() calls load_model() again on demand).
        print(e)
    finally:
        # Restore the original PosixPath so later code sees the platform
        # default even if the patch above was enabled.
        pathlib.PosixPath = posix_backup
    load_gg_trans()
    run_ui()