Skip to content

Commit

Permalink
Merge pull request #520 from Team-WeQuiz/feat/#505
Browse files Browse the repository at this point in the history
[ML] 검색된 문서들과 요약 간의 유사도 검사
  • Loading branch information
noooey authored May 29, 2024
2 parents 71fb200 + 3b1609b commit 718d22b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
6 changes: 3 additions & 3 deletions model/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def generate_quiz_async(generate_request, id, summary_split_docs, vector_s
async for doc in vector_split_docs:
vector_docs.append(doc)
# Generate quiz
quiz_generator = QuizGenerator(vector_docs)
quiz_generator = QuizGenerator(vector_docs, summary)

async def generate_question(idx, i):
max_attempts = generate_request.num_of_quiz
Expand Down Expand Up @@ -203,7 +203,7 @@ async def generate_question(idx, i):

# 5문제가 생성되었거나 마지막 문제인 경우 DynamoDB에 업데이트
if idx % 5 == 0 or idx == generate_request.num_of_quiz:
log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push. {questions}')
log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push. idx: {idx}')

retries = 0
while retries < QUIZ_UPDATE_RETRY:
Expand Down Expand Up @@ -268,7 +268,7 @@ async def generate(generate_request: GenerateRequest):

# keyword 추출
keywords = await extract_keywords(keyword_docs, top_n=min(generate_request.num_of_quiz * 2, len(sentences) - 1)) # 키워드는 개수를 여유롭게 생성합니다.
log('info', f'[app.py > quiz] Extracted Keywords: {keywords}')
log('info', f'[app.py > quiz] Extracted Keywords for creating {generate_request.num_of_quiz} quizs.: {keywords}')

# quiz 생성 (비동기)
asyncio.create_task(generate_quiz_async(generate_request, res["id"], summary_docs, vector_docs, keywords))
Expand Down
10 changes: 2 additions & 8 deletions model/app/data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
setup_logging()

class QuizGenerator():
def __init__(self, split_docs):
self.embedding = OpenAIEmbeddings(model="text-embedding-3-large")
try:
self.indices = FAISS.from_documents(split_docs, self.embedding)
except Exception as e:
raise e
self.chain = QuizPipeline(self.indices)
def __init__(self, split_docs, summary):
self.chain = QuizPipeline(split_docs, summary)

def get_type(self, text):
# self.types = ['1. Multiple choice', '2. Short answer type that can be easily answered in one word', '3. yes/no quiz']
Expand All @@ -30,7 +25,6 @@ def get_type(self, text):
return "OX퀴즈"

def adjust_result(self, type, option_list, correct):
log("warning", f"[generator.py > quiz] set type({type}), options({option_list}), correct({correct}) ")
if type == "단답형" or type == None:
if (correct.strip().lower() in NO_LIST) or (correct.strip().lower() in YES_LIST):
type = "OX퀴즈"
Expand Down
2 changes: 1 addition & 1 deletion model/app/data/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SUMMARY_SENTENCE_OVERLAP = 2

# vector chunk
VECTOR_CHUNK_SIZE = 320
VECTOR_CHUNK_SIZE = 400
VECTOR_SENTENCE_OVERLAP = 0

K = 2
Expand Down
36 changes: 31 additions & 5 deletions model/app/model/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from langchain.chains.base import Chain
from langchain.schema import BaseRetriever
from typing import List, Dict, Any
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from sklearn.metrics.pairwise import cosine_similarity

from model.prompt import CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE, GENERATE_QUIZ_TEMPLATE, JSON_FORMAT_TEMPLATE
from data.settings import *
Expand All @@ -18,6 +21,8 @@
# Retrieval ์ฒด์ธ
class RetrievalChain(Chain):
retriever: BaseRetriever
embeddings: OpenAIEmbeddings
summary: str

@property
def input_keys(self) -> List[str]:
Expand All @@ -30,7 +35,23 @@ def output_keys(self) -> List[str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
message = inputs["message"]
relevant_docs = self.retriever.invoke(message)
context = " ".join([doc.page_content for doc in relevant_docs])
log('info', f'[chain.py > RetrievalChain] Retrieved {len(relevant_docs)} Docs for {message}')

# 검색된 문서들과 summary 임베딩
doc_embeddings = self.embeddings.embed_documents([doc.page_content for doc in relevant_docs])
summary_embedding = self.embeddings.embed_query(self.summary)
summary_embedding = [summary_embedding] # 1D 배열을 2D 배열로 변환

# 유사도 계산
similarities = cosine_similarity(summary_embedding, doc_embeddings)

# 유사도가 높은 문서 K개 선택
top_k_indices = similarities.argsort()[0][-K:][::-1]
top_k_docs = [relevant_docs[i] for i in top_k_indices]
log('info', f'[chain.py > RetrievalChain] Selected {len(top_k_docs)} Docs for {message} with summary.')

context = "\n".join([doc.page_content for doc in top_k_docs])

if len(context) < VECTOR_CHUNK_SIZE * 0.3:
log('error', f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. - "{message}", len: {len(context)}')
raise ValueError(f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. - "{message}", len: {len(context)}')
Expand Down Expand Up @@ -90,9 +111,14 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

# ๋ฌธ์ œ ์ƒ์„ฑ ํด๋ž˜์Šค
class QuizPipeline:
def __init__(self, indices: object):
self.indices = indices
self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K)) # 인덱스 객체로부터 retriever 초기화
def __init__(self, split_docs, summary):
self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
try:
self.indices = FAISS.from_documents(split_docs, self.embeddings)
self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K+1)) # 인덱스 객체로부터 retriever 초기화
except Exception as e:
raise e
self.summary = summary
self.llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
self.question_templates = [CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE]
# self.output_schemas = [ChoiceOutput, ShortOutput]
Expand All @@ -114,7 +140,7 @@ def generate_quiz(self, message: str) -> dict:
partial_variables={"format_instructions": json_output_parser.get_format_instructions()}
)

retrieval_chain = RetrievalChain(retriever=self.retriever)
retrieval_chain = RetrievalChain(retriever=self.retriever, embeddings=self.embeddings, summary=self.summary)
quiz_generation_chain = QuizGenerationChain(prompt= quiz_prompt, llm=self.llm)
json_formatter_chain = JSONFormatterChain(prompt=json_prompt, llm=self.llm, output_parser=json_output_parser)

Expand Down

0 comments on commit 718d22b

Please sign in to comment.