diff --git a/model/app/app.py b/model/app/app.py index 35cb9b8e..f3655046 100644 --- a/model/app/app.py +++ b/model/app/app.py @@ -168,7 +168,7 @@ async def generate_quiz_async(generate_request, id, summary_split_docs, vector_s async for doc in vector_split_docs: vector_docs.append(doc) # Generate quiz - quiz_generator = QuizGenerator(vector_docs) + quiz_generator = QuizGenerator(vector_docs, summary) async def generate_question(idx, i): max_attempts = generate_request.num_of_quiz @@ -203,7 +203,7 @@ async def generate_question(idx, i): # 5문제가 생성되었거나 마지막 문제인 경우 DynamoDB에 업데이트 if idx % 5 == 0 or idx == generate_request.num_of_quiz: - log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push. {questions}') + log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push. idx: {idx}') retries = 0 while retries < QUIZ_UPDATE_RETRY: @@ -268,7 +268,7 @@ async def generate(generate_request: GenerateRequest): # keyword 추출 keywords = await extract_keywords(keyword_docs, top_n=min(generate_request.num_of_quiz * 2, len(sentences) - 1)) # 키워드는 개수를 여유롭게 생성합니다. - log('info', f'[app.py > quiz] Extracted Keywords: {keywords}') + log('info', f'[app.py > quiz] Extracted Keywords for creating {generate_request.num_of_quiz} quizs.: {keywords}') # quiz 생성 (비동기) asyncio.create_task(generate_quiz_async(generate_request, res["id"], summary_docs, vector_docs, keywords)) diff --git a/model/app/data/generator.py b/model/app/data/generator.py index 1e09c0be..a39513b0 100644 --- a/model/app/data/generator.py +++ b/model/app/data/generator.py @@ -12,13 +12,8 @@ setup_logging() class QuizGenerator(): - def __init__(self, split_docs): - self.embedding = OpenAIEmbeddings(model="text-embedding-3-large") - try: - self.indices = FAISS.from_documents(split_docs, self.embedding) - except Exception as e: - raise e - self.chain = QuizPipeline(self.indices) + def __init__(self, split_docs, summary): + self.chain = QuizPipeline(split_docs, summary) def get_type(self, text): # self.types = ['1. Multiple choice', '2. Short answer type that can be easily answered in one word', '3. yes/no quiz'] @@ -30,7 +25,6 @@ def get_type(self, text): return "OX퀴즈" def adjust_result(self, type, option_list, correct): - log("warning", f"[generator.py > quiz] set type({type}), options({option_list}), correct({correct}) ") if type == "단답형" or type == None: if (correct.strip().lower() in NO_LIST) or (correct.strip().lower() in YES_LIST): type = "OX퀴즈" diff --git a/model/app/data/settings.py b/model/app/data/settings.py index d01edeb9..86399fef 100644 --- a/model/app/data/settings.py +++ b/model/app/data/settings.py @@ -20,7 +20,7 @@ SUMMARY_SENTENCE_OVERLAP = 2 # vector chunk -VECTOR_CHUNK_SIZE = 320 +VECTOR_CHUNK_SIZE = 400 VECTOR_SENTENCE_OVERLAP = 0 K = 2 diff --git a/model/app/model/chain.py b/model/app/model/chain.py index 6a1576f3..7ea99d2b 100644 --- a/model/app/model/chain.py +++ b/model/app/model/chain.py @@ -7,6 +7,9 @@ from langchain.chains.base import Chain from langchain.schema import BaseRetriever from typing import List, Dict, Any +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import FAISS +from sklearn.metrics.pairwise import cosine_similarity from model.prompt import CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE, GENERATE_QUIZ_TEMPLATE, JSON_FORMAT_TEMPLATE from data.settings import * @@ -18,6 +21,8 @@ # Retrieval 체인 class RetrievalChain(Chain): retriever: BaseRetriever + embeddings: OpenAIEmbeddings + summary: str @property def input_keys(self) -> List[str]: @@ -30,7 +35,23 @@ def output_keys(self) -> List[str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: message = inputs["message"] relevant_docs = self.retriever.invoke(message) - context = " ".join([doc.page_content for doc in relevant_docs]) + log('info', f'[chain.py > RetrievalChain] Retrieved {len(relevant_docs)} Docs for {message}') + + # 검색된 문서들과 summary 임베딩 + doc_embeddings = self.embeddings.embed_documents([doc.page_content for doc in relevant_docs]) + summary_embedding = self.embeddings.embed_query(self.summary) + summary_embedding = [summary_embedding] # 1D 배열을 2D 배열로 변환 + + # 유사도 계산 + similarities = cosine_similarity(summary_embedding, doc_embeddings) + + # 유사도가 높은 문서 K개 선택 + top_k_indices = similarities.argsort()[0][-K:][::-1] + top_k_docs = [relevant_docs[i] for i in top_k_indices] + log('info', f'[chain.py > RetrievalChain] Selected {len(top_k_docs)} Docs for {message} with summary.') + + context = "\n".join([doc.page_content for doc in top_k_docs]) + if len(context) < VECTOR_CHUNK_SIZE * 0.3: log('error', f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. - "{message}", len: {len(context)}') raise ValueError(f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. - "{message}", len: {len(context)}') @@ -90,9 +111,14 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: # 문제 생성 클래스 class QuizPipeline: - def __init__(self, indices: object): - self.indices = indices - self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K)) # 인덱스 객체로부터 retriever 초기화 + def __init__(self, split_docs, summary): + self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + try: + self.indices = FAISS.from_documents(split_docs, self.embeddings) + self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K+1)) # 인덱스 객체로부터 retriever 초기화 + except Exception as e: + raise e + self.summary = summary self.llm = ChatOpenAI(model="gpt-3.5-turbo-0125") self.question_templates = [CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE] # self.output_schemas = [ChoiceOutput, ShortOutput] @@ -114,7 +140,7 @@ def generate_quiz(self, message: str) -> dict: partial_variables={"format_instructions": json_output_parser.get_format_instructions()} ) - retrieval_chain = RetrievalChain(retriever=self.retriever) + retrieval_chain = RetrievalChain(retriever=self.retriever, embeddings=self.embeddings, summary=self.summary) quiz_generation_chain = QuizGenerationChain(prompt= quiz_prompt, llm=self.llm) json_formatter_chain = JSONFormatterChain(prompt=json_prompt, llm=self.llm, output_parser=json_output_parser)