-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathbase_class.py
146 lines (118 loc) · 4.23 KB
/
base_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import time
import os
import abc
import pandas as pd
import pickle
class Pool():
    """Dictionary-like store whose entries expire a fixed time after being set.

    Every membership test, read and write first evicts entries whose most
    recent assignment is older than ``time_interval`` seconds.
    """

    def __init__(self, time_interval=60 * 60):
        """Create an empty pool.

        Args:
            time_interval: lifetime of an entry in seconds, measured from its
                most recent assignment. Defaults to one hour.
        """
        self.dict = {}
        # key -> time.time() of the most recent __setitem__ for that key
        self.update_time = {}
        self.time_interval = time_interval

    def __contains__(self, key):
        """Return True if ``key`` is present and not expired."""
        self._del_outdated()
        return key in self.dict

    def __getitem__(self, key):
        """Return the value for ``key``; raise KeyError if absent or expired."""
        self._del_outdated()
        return self.dict[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and restart its expiry clock."""
        self.dict[key] = value
        self.update_time[key] = time.time()
        self._del_outdated()

    def _del_outdated(self,):
        """Evict every entry last set more than ``time_interval`` seconds ago."""
        now = time.time()
        # Collect the expired keys first: deleting from a dict while iterating
        # over it raises "RuntimeError: dictionary changed size during iteration".
        expired = [
            key for key, stamp in self.update_time.items()
            if now - stamp > self.time_interval
        ]
        for key in expired:
            self.dict.pop(key, None)
            del self.update_time[key]

    def __delitem__(self, key):
        """Remove ``key``; raise KeyError if it is not present."""
        del self.dict[key]
        # Drop the timestamp too; otherwise _del_outdated would later try to
        # delete the already-removed key from self.dict and raise KeyError.
        self.update_time.pop(key, None)
class SimilarityAlg(metaclass=abc.ABCMeta):
    """Interface for scoring a query embedding against a collection of embeddings.

    Subclasses must implement ``__call__``; the base class itself cannot be
    instantiated because that method is abstract.
    """

    def __init__(self) -> None:
        """Set up the algorithm; the base class holds no state."""

    @abc.abstractmethod
    def __call__(self, query_embedding, embeddings) -> None:
        """Compute the similarity between ``query_embedding`` and ``embeddings``.

        Must be overridden by concrete subclasses.
        """
class Embedding_Model(metaclass=abc.ABCMeta):
    """Embedding Model to compute embedding of a text"""

    def __init__(self, model_name) -> None:
        """Initialize the embedding model and its on-disk embedding cache.

        The cache file is ``~/ckpt/embedding_cache_{model_name}.pkl``. If it
        exists it is loaded with ``pd.read_pickle``; if not, an empty dict is
        used. In both cases the cache is then written back to that path with
        ``pickle.dump``.

        Args:
            model_name: name of the embedding model. It is interpolated into
                the cache file name, so a name containing ``/`` places the
                file in a sub-directory of ``~/ckpt``.
        """
        user_path = os.path.expanduser('~')
        ckpt_path = os.path.join(user_path, "ckpt")
        embedding_cache_path = os.path.join(ckpt_path, f"embedding_cache_{model_name}.pkl")
        # Create the cache file's own parent directory rather than just
        # ~/ckpt: a name such as "org/model" nests the file one level deeper,
        # and open(..., "wb") below would otherwise fail. exist_ok also avoids
        # the race between an exists() check and makedirs().
        os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
        self.embedding_cache_path = embedding_cache_path
        # load the cache if it exists, and save a copy to disk
        # NOTE: this unpickles a local file; only trusted cache files belong here.
        try:
            embedding_cache = pd.read_pickle(embedding_cache_path)
        except FileNotFoundError:
            embedding_cache = {}
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
        self.embedding_cache = embedding_cache
        self.model_name = model_name

    @abc.abstractmethod
    def __call__(self, text) -> None:
        """Compute the embedding of the text"""
        pass
class AbstractPDFParser(metaclass=abc.ABCMeta):
    """ PDF parser to parse a PDF file"""

    def __init__(self, db_name) -> None:
        """Initialize the pdf database and its on-disk cache.

        The cache file is ``~/ckpt/pdf_parser_{db_name}.pkl``. If it exists it
        is loaded with ``pd.read_pickle``; if not, an empty dict is used. In
        both cases the cache is then written back to that path with
        ``pickle.dump``.

        Args:
            db_name: name of the database. It is interpolated into the cache
                file name, so a name containing ``/`` places the file in a
                sub-directory of ``~/ckpt``.
        """
        user_path = os.path.expanduser('~')
        ckpt_path = os.path.join(user_path, "ckpt")
        db_cache_path = os.path.join(ckpt_path, f"pdf_parser_{db_name}.pkl")
        # Create the cache file's own parent directory rather than just
        # ~/ckpt: a name such as "papers/doc" nests the file one level deeper,
        # and open(..., "wb") below would otherwise fail. exist_ok also avoids
        # the race between an exists() check and makedirs().
        os.makedirs(os.path.dirname(db_cache_path), exist_ok=True)
        self.db_cache_path = db_cache_path
        # load the cache if it exists, and save a copy to disk
        # NOTE: this unpickles a local file; only trusted cache files belong here.
        try:
            db_cache = pd.read_pickle(db_cache_path)
        except FileNotFoundError:
            db_cache = {}
        with open(db_cache_path, "wb") as cache_file:
            pickle.dump(db_cache, cache_file)
        self.db_cache = db_cache
        self.db_name = db_name

    @abc.abstractmethod
    def parse_pdf(self,) -> None:
        """Parse the PDF file"""
        pass

    @abc.abstractmethod
    def _get_metadata(self, ) -> None:
        """Get the metadata of the PDF file"""
        pass

    def get_paragraphs(self, ) -> None:
        """Get the paragraphs of the PDF file.

        Not abstract: this default does nothing and returns None; subclasses
        may override it.
        """
        pass

    @abc.abstractmethod
    def get_split_paragraphs(self, ) -> None:
        """
        Get the split paragraphs of the PDF file
        Return:
            split_paragraphs: dict of metadata and corresponding list of split paragraphs
        """
        pass

    def _determine_metadata_of_paragraph(self, paragraph) -> None:
        """
        Determine the metadata of a paragraph

        Not abstract: this default does nothing and returns None; subclasses
        may override it.
        Return:
            metadata: metadata of the paragraph
        """
        pass

    # @abc.abstractmethod
    # def _determine_optimal_split_of_pargraphs(self, ) -> None:
    #     """
    #     Determine the optimal split of paragraphs
    #     Return:
    #         split_paragraphs: dict of metadata and corresponding list of split paragraphs
    #     """
    #     pass
class ChatbotEngine(metaclass=abc.ABCMeta):
    """Interface for a chatbot that answers user queries.

    Subclasses must implement ``query``; the base class itself cannot be
    instantiated because that method is abstract.
    """

    def __init__(self,) -> None:
        """Set up the engine; the base class holds no state."""

    @abc.abstractmethod
    def query(self, user_query):
        """Answer ``user_query``; must be overridden by concrete subclasses."""