-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
260b9c0
commit d416f9d
Showing
11 changed files
with
15,996 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import os | ||
from docx import Document | ||
import re | ||
import torch.nn as nn | ||
from hflayers import Hopfield, HopfieldPooling, HopfieldLayer | ||
import torch.nn.functional as F | ||
import torch | ||
from sentence_transformers import SentenceTransformer | ||
import pickle | ||
|
||
class HopfieldRetrievalModel(nn.Module): | ||
def __init__(self, beta=0.125, update_steps_max=3): | ||
# def __init__(self, beta=0.125): | ||
super(HopfieldRetrievalModel, self).__init__() | ||
self.hopfield = Hopfield( | ||
scaling=beta, | ||
update_steps_max=update_steps_max, | ||
update_steps_eps=1e-5, | ||
|
||
# do not project layer input | ||
state_pattern_as_static=True, | ||
stored_pattern_as_static=True, | ||
pattern_projection_as_static=True, | ||
|
||
# do not pre-process layer input | ||
normalize_stored_pattern=False, | ||
normalize_stored_pattern_affine=False, | ||
normalize_state_pattern=False, | ||
normalize_state_pattern_affine=False, | ||
normalize_pattern_projection=False, | ||
normalize_pattern_projection_affine=False, | ||
|
||
# do not post-process layer output | ||
disable_out_projection=True) | ||
|
||
def forward(self, memory, trg): | ||
memory = torch.unsqueeze(memory, 0) | ||
trg = torch.unsqueeze(trg, 0) | ||
output = self.hopfield((memory, trg, memory)) | ||
output = output.squeeze(0) | ||
memories = memory.squeeze(0) | ||
# temp = torch.bmm(F.softmax(attn_output_weights_init, dim=-1), memory).squeeze(0) | ||
pair_list = F.normalize(output) @ F.normalize(memories).t() # step1 | ||
return pair_list | ||
|
||
|
||
|
||
def read_external_knowledge(path): | ||
path = '/Users/jmy/Desktop/ai_for_health_final/exsit_knowledge/my_dict.pkl' | ||
with open(path, 'rb') as file: | ||
loaded_data = pickle.load(file) | ||
paragraph = [] | ||
for i in loaded_data: | ||
paragraph.append(loaded_data[i]) | ||
return paragraph | ||
|
||
|
||
def read_reports(path): | ||
reports = [] | ||
for filename in os.listdir(path): | ||
if filename.endswith(".txt"): | ||
filepath = os.path.join(path, filename) | ||
|
||
# Read the .docx file | ||
with open(filepath, 'r') as f: | ||
txt = f.read() | ||
reports.extend(txt.split('\n')) | ||
|
||
return reports | ||
|
||
|
||
|
||
def retrieval_info(reports, path, k): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
paragraphs = read_external_knowledge(path + '/exsit_knowledge') | ||
|
||
# sentence_embedding with paragraphs | ||
model = SentenceTransformer('all-mpnet-base-v2') | ||
# p_embeddings = [] | ||
# for i in paragraphs: | ||
# p_embeddings.append(model.encode(i)) | ||
p_embeddings = model.encode(paragraphs) | ||
# sentence_embedding with reports | ||
report_embeddings = model.encode(reports) | ||
retrievaler = HopfieldRetrievalModel().to(device) | ||
result = retrievaler(torch.tensor(p_embeddings).to(device) * 100, torch.tensor(report_embeddings).to(device) * 100) | ||
input_ids = torch.topk(result, k, dim=1).indices | ||
|
||
# mask = ~(input_ids == input_ids[0]).any(dim=1) | ||
# input_ids = input_ids[mask] | ||
indices = input_ids[0] | ||
# indices = set() | ||
# for input_id in input_ids: | ||
# for id in input_id: | ||
# indices.add(id.item()) | ||
knowledge = [] | ||
for indice in indices: | ||
knowledge.append(paragraphs[indice]) | ||
knowledge = [x for x in knowledge if x != ''] | ||
return knowledge | ||
|
||
|
||
|
||
|
||
if __name__ == '__main__': | ||
reports = read_reports('/Users/jmy/Desktop/ai_for_health_final/dataset_folder/health_report_{2343}') # 13452 | ||
know = retrieval_info(reports,'/Users/jmy/Desktop/ai_for_health_final/',3) | ||
for i in know: | ||
print(i) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import openai, os | ||
import pandas as pd | ||
import re | ||
from Hopfield import retrieval_info | ||
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader | ||
from llama_index import Prompt | ||
from llama_index import StorageContext, load_index_from_storage | ||
from langchain.prompts.chat import ( | ||
ChatPromptTemplate, | ||
HumanMessagePromptTemplate, | ||
SystemMessagePromptTemplate, | ||
) | ||
import csv | ||
|
||
|
||
|
||
|
||
def answer_from_gpt(ques, context, work): | ||
|
||
storage_context = StorageContext.from_defaults(persist_dir='./storage') | ||
index = load_index_from_storage(storage_context, index_id="index_health") | ||
list_score = [] | ||
|
||
t = 0 | ||
for i in ques: | ||
my_context = context + work[t] | ||
QA_PROMPT = get_systemprompt_template(my_context) | ||
query_engine = index.as_query_engine(text_qa_template=QA_PROMPT) | ||
response = query_engine.query(i) | ||
stt = str(response) | ||
score = extract_score(stt) | ||
list_score.append(score) | ||
print(score) | ||
t = t + 1 | ||
|
||
return list_score | ||
|
||
|
||
|
||
def get_systemprompt_template(exist_context): | ||
|
||
chat_text_qa_msgs = [ | ||
SystemMessagePromptTemplate.from_template( | ||
exist_context | ||
), | ||
HumanMessagePromptTemplate.from_template( | ||
"Give the answer in jason format with only one number between 0 and 1 that is: 'score'\n" | ||
"The score number must be an decimals\n" | ||
"This is the rule of answer: 0-0.2 is mild or none, 0.3-0.6 is moderate, and above 0.7 is severe.\n" | ||
"This is a patient‘s medical record. Context information in below\n" | ||
"---------------------\n" | ||
"{context_str}" | ||
"Given the context information, you are a helpful health consultant " | ||
"answer the question: {query_str}\n" | ||
) | ||
] | ||
chat_text_qa_msgs_lc = ChatPromptTemplate.from_messages(chat_text_qa_msgs) | ||
text_qa_template = Prompt.from_langchain_prompt(chat_text_qa_msgs_lc) | ||
|
||
return text_qa_template | ||
|
||
|
||
def extract_score(string): | ||
numbers = re.findall(r'\d+\.\d+|\d+', string) | ||
if numbers: | ||
for i in numbers: | ||
return float(i) | ||
else: | ||
return 0.0 | ||
|
||
|
||
def generate_question(path): | ||
my_feature_list = [] | ||
related_work = [] | ||
with open(path, 'r') as file: | ||
for line in file: | ||
line = line.strip() | ||
my_feature_list.append(line) | ||
question = [] | ||
for i in my_feature_list: | ||
sentence = f"Does the person described in the case have {i} symptoms? Do you think it is serious?" | ||
list_sentence = [sentence] | ||
retrieval = retrieval_info(list_sentence, '/Users/jmy/Desktop/ai_for_health_final/', 1) | ||
question.append(sentence) | ||
related_work.append(retrieval[0]) | ||
print(retrieval[0]) | ||
|
||
return question, related_work, my_feature_list | ||
|
||
|
||
def count_subfolders(folder_path): | ||
subfolder_count = 0 | ||
subfolder_paths = [] | ||
|
||
for root, dirs, files in os.walk(folder_path): | ||
if root != folder_path: | ||
subfolder_count += 1 | ||
|
||
basepath = '/Users/jmy/Desktop/ai_for_health_final/dataset_folder/health_report_' | ||
for i in range(subfolder_count): | ||
path_rr = basepath+str({i}) | ||
subfolder_paths.append(path_rr) | ||
|
||
return subfolder_count, subfolder_paths | ||
|
||
|
||
# def answer_from_gpt(ques,context): | ||
# # rebuild storage context | ||
# storage_context = StorageContext.from_defaults(persist_dir='./storage') | ||
# # load index | ||
# index = load_index_from_storage(storage_context, index_id="index_health") | ||
# | ||
# list_score = [] | ||
# for i in ques: | ||
# QA_PROMPT = get_systemprompt_template(context) | ||
# query_engine = index.as_query_engine(text_qa_template=QA_PROMPT) | ||
# response = query_engine.query(i) | ||
# stt = str(response) | ||
# | ||
# score = extract_score(stt) | ||
# list_score.append(score) | ||
# print(score) | ||
# | ||
# return list_score | ||
|
||
|
||
def load_doc(folder_path,question,work): | ||
print(len(work)) | ||
count,dict = count_subfolders(folder_path) | ||
list_k = [] | ||
context = 'Here is some additional professional health knowledge that can help you better analyze the report' | ||
for i in range(400,500): | ||
documents = SimpleDirectoryReader(dict[i]).load_data() | ||
index = GPTVectorStoreIndex.from_documents(documents) | ||
index.set_index_id("index_health") | ||
index.storage_context.persist('./storage') | ||
content = context | ||
list = answer_from_gpt(question, content, work) | ||
list_k.append(list) | ||
return list_k | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
|
||
|
||
openai.api_key = os.environ.get("OPENAI_API_KEY") | ||
path = '/Users/jmy/Desktop/ai_for_health_final/label and feature/input_feature.txt' | ||
question, related_work, features_list = generate_question(path) | ||
folder_path = '/Users/jmy/Desktop/ai_for_health_final/dataset_folder' | ||
list = load_doc(folder_path, question, related_work) | ||
|
||
with open('training/train.txt', 'w') as file: | ||
for item in list: | ||
file.write(''.join(str(item)) + '\n\n') | ||
|
||
with open('training/combined7.csv', 'w', newline='') as file: | ||
writer = csv.writer(file) | ||
# 首先写入特征行 | ||
writer.writerow(features_list) | ||
# 然后写入矩阵的每一行 | ||
for row in list: | ||
writer.writerow(row) | ||
|
||
print("CSV file has been created successfully.") | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import openai | ||
import os | ||
|
||
|
||
def generate_feature(path): | ||
my_feature_list = [] | ||
related_work = [] | ||
with open(path, 'r') as file: | ||
for line in file: | ||
line = line.strip() | ||
my_feature_list.append(line) | ||
question = [] | ||
for i in my_feature_list: | ||
openai.api_key = "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" | ||
messages = [] | ||
system_message = "You are a medical health assistant robot" | ||
messages.append({"role": "system", "content": system_message}) | ||
message = f'Please list the symptoms of {i}? You only need to list the names of the symptoms in order. You do not need to describe the symptoms in detail.' | ||
messages.append({"role": "user", "content": message}) | ||
response = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=messages | ||
) | ||
reply = response["choices"][0]["message"]["content"] | ||
question.append(reply) | ||
print(reply) | ||
|
||
|
||
return question | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
|
||
OPENAI_API_KEY = "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" | ||
|
||
path = '/Users/jmy/Desktop/ai_for_health_final/label and feature/output_target.txt' | ||
question = generate_feature(path) | ||
with open("/Users/jmy/Desktop/ai_for_health_final/training/feature.txt", "w") as file: | ||
for item in question: | ||
file.write("%s\n" % item) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
尿频 | ||
烧心 | ||
呕血 | ||
腹胀 | ||
焦躁 | ||
咽部灼烧感 | ||
寒战 | ||
四肢麻木 | ||
打嗝 | ||
体重下降 | ||
食欲不振 | ||
水肿 | ||
嗜睡 | ||
肠鸣 | ||
精神不振 | ||
黑便 | ||
发热 | ||
脱水 | ||
粘便 | ||
呼吸困难 | ||
淋巴结肿大 | ||
鼻塞 | ||
咳嗽 | ||
头晕 | ||
咽部痛 | ||
喷嚏 | ||
口苦 | ||
有痰 | ||
菌群失调 | ||
肠梗阻 | ||
痉挛 | ||
过敏 | ||
气促 | ||
胃肠不适 | ||
胃肠功能紊乱 | ||
胃痛 | ||
螺旋杆菌感染 | ||
稀便 | ||
痔疮 | ||
尿急 | ||
乏力 | ||
消化不良 | ||
胸痛 | ||
反流 | ||
吞咽困难 | ||
背痛 | ||
腹泻 | ||
腹痛 | ||
头痛 | ||
月经紊乱 | ||
便血 | ||
饥饿感 | ||
贫血 | ||
肌肉酸痛 | ||
呕吐 | ||
细菌感染 | ||
黄疸 | ||
心悸 | ||
恶心 | ||
肛周疼痛 |
Oops, something went wrong.