新建
AI4S Cup - LLM挑战赛 - 大模型科学文献分析-解决方案-[天命人]-推理代码
大陆
推荐镜像 :Basic Image:ubuntu:22.04-py3.10-cuda12.1
推荐机型 :c2_m4_cpu
赞
数据集
Mini_cpm(v1)
LLMS-chat(v3)
yolo_model(v2)
torch_py10_linux(v1)
LLMS-version(v3)
all-MiniLM-L6-v2(v1)
Train_data(v5)
img2smiles(v3)
双击即可修改
代码
文本
[27]
# Install the pinned core runtime packages from the Tsinghua PyPI mirror.
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple torch==2.1.2
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple torchvision==0.16.2
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple transformers==4.44.0
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple Pillow==10.1.0
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple sentencepiece==0.1.99
# decord is the only package in this cell installed without a version pin.
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple decord
代码
文本
[28]
import logging

# Route INFO-and-above records to ./log.txt, each prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    filename='./log.txt',
)
logging.info('--start')
代码
文本
[29]
# Refresh the apt index, then install the Tesseract OCR binary and the pytesseract bindings.
!apt-get update
!apt-get install -y tesseract-ocr
!pip install Pillow pytesseract
# Print the Tesseract version as a quick check that the binary is on PATH.
!tesseract --version
# poppler-utils is installed here (presumably for pdf2image, which is imported later in this file).
!apt-get install -y poppler-utils
代码
文本
[30]
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple torch
logging.info('--torch installed')
!pip install decimer
logging.info('--decimer installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple PyPDF2==3.0.1
logging.info('--PyPDF2 installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple transformers==4.44.0 # 4.40.2
logging.info('--transformers installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple sentence-transformers==3.0.1 # 2.7.0
logging.info('--sentence-transformers installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple pdf2image==1.17.0
logging.info('--pdf2image installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple langchain==0.2.6
logging.info('--langchain installed')
!pip install --index-url https://pypi.tuna.tsinghua.edu.cn/simple opencv-python==4.10.0.84
logging.info('--opencv-python installed')
!pip install rdkit-pypi==2022.9.5
logging.info('--rdkit-pypi installed')
!pip install faiss-cpu==1.8.0.post1
logging.info('--faiss-cpu installed')
!pip install pandas>=1.2.4
logging.info('--pandas installed')
!pip install -U langchain-community==0.2.6
logging.info('--pandas installed')
!pip install accelerate==0.33.0
logging.info('--accelerate installed')
!pip install tiktoken==0.7.0
logging.info('--tiktoken installed')
!pip install numpy #==1.26.4
logging.info('--numpy installed')
!pip install albumentations==1.1.0
logging.info('--albumentations installed')
!pip install SmilesPE==0.0.3
logging.info('--SmilesPE installed')
!pip install timm==0.4.12
logging.info('--timm installed')
!pip install OpenNMT-py==2.2.0
logging.info('--OpenNMT installed')
!pip install matplotlib==3.9.1
logging.info('--matplotlib installed')
!pip install seaborn==0.13.2
logging.info('--seaborn installed')
!apt-get install -y poppler-utils
logging.info('--poppler-utils installed')
代码
文本
[31]
# 导包
import os
import re
import sys
import cv2
import glob
import time
import torch
import faiss
import json
import PyPDF2
import pytesseract
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModel
from sentence_transformers import SentenceTransformer, util
from pdf2image import convert_from_path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.docstore import InMemoryDocstore
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
from DECIMER import predict_SMILES
logging.info('导包成功')
代码
文本
[32]
# Path configuration
# The dataset root comes from the DATA_PATH environment variable; a fixed default is used
# (with a warning) when the variable is missing or empty.
DATA_PATH = os.getenv('DATA_PATH')
if not DATA_PATH:
    DATA_PATH = '/bohr/v-test-zup1/v6/exampleData_v5/'
    print("Warning: DATA_PATH environment variable is not set. Using default path:", DATA_PATH)
pdfs_dir = DATA_PATH+'/pdfs/'
test_input_path = DATA_PATH+'/question.jsonl'
test_output_path = 'submission.jsonl'
llm_dir = "/bohr/1014-8ud9/v3/"
vlm_dir = "/bohr/1015-t5o4/v1/"
embedding_model_path = "/bohr/all-MiniLM-L6-v2-c74o/v1/all-MiniLM-L6-v2/"
molscribe_path = '/bohr/img2smiles-zhs7/v3/swin_base_char_aux_1m680k.pth'
# Scratch directory for cropped molecule images. os.makedirs(exist_ok=True) replaces the
# shell "mkdir", which printed an error whenever the cell was re-run.
os.makedirs('image_list_address', exist_ok=True)
# Single device object shared by the model-loading and generation code.
device = torch.device('cuda')
logging.info('配置执行成功')
代码
文本
[33]
# Detector used later to crop molecule regions out of page images; loaded from the local
# hub checkout with custom weights and kept on the CPU.
yolo_model = torch.hub.load(
    "/bohr/yolo-r2gq/v2/yolo/",
    "custom",
    path="/bohr/yolo-r2gq/v2/best.pt",
    source="local",
    device="cpu",
)
logging.info('yolo_model执行成功')
代码
文本
[34]
# Detector used later to locate tables on page images; same local hub checkout as
# yolo_model but with the table weights, also kept on the CPU.
table_yolo_model = torch.hub.load(
    "/bohr/yolo-r2gq/v2/yolo",
    "custom",
    path="/bohr/yolo-r2gq/v2/table.pt",
    source="local",
    device="cpu",
)
logging.info('table_yolo_model执行成功')
代码
文本
[35]
# LangChain embedding wrapper; passed to FAISS.from_documents when text chunks are indexed.
embeddings_model = HuggingFaceEmbeddings(
    model_name=embedding_model_path,
)
logging.info('embeddings_model执行成功')
代码
文本
[36]
# Plain SentenceTransformer built from the same checkpoint path; used with the raw faiss index.
sentence_embeddings_model = SentenceTransformer(
    embedding_model_path,
)
logging.info('sentence_embeddings_model执行成功')
代码
文本
[37]
# Load the large language model (text-only) in bfloat16 and put it in eval mode on `device`.
llm_tokenizer = AutoTokenizer.from_pretrained(llm_dir, trust_remote_code=True)
llm = AutoModelForCausalLM.from_pretrained(
    llm_dir,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to(device).eval()
logging.info('llm执行成功')

# The vision-language model is loaded on a best-effort basis. Both names are always bound
# (None on failure) so later code fails in a predictable way instead of with a NameError,
# and the traceback is written to the log instead of being silently discarded.
vlm_tokenizer = None
vlm = None
try:
    vlm_tokenizer = AutoTokenizer.from_pretrained(vlm_dir, trust_remote_code=True)
    vlm = AutoModel.from_pretrained(
        vlm_dir,
        trust_remote_code=True,
        attn_implementation='sdpa',
        torch_dtype=torch.bfloat16,
    )
    vlm = vlm.eval().cuda()
    logging.info('vlm执行成功')
except Exception:
    vlm_tokenizer = None
    vlm = None
    logging.exception('vlm执行失败')
代码
文本
[ ]
# Task 1 prompt: few-shot, multi-participant "collaboration" template for multiple-choice
# questions. The only placeholder is {question}; the model is asked to end with
# "Final answer: <option>", which llm_answer_generator parses.
# Fix: the last verification step of example task1 previously read "12 + 12 = 12".
task1_prompt = """When faced with a task, first identify domain experts as participants who will help solve the task. Then, multiple rounds of collaborative processes are initiated until a final answer is reached. Participants will make critical comments and detailed recommendations where necessary.
Here are some examples:
---
example task1:
question: Use numbers and basic arithmetic operations (+ - * /) to obtain 24. You need to use all numbers, and each number can only be used once.
a) 6 12 1 1
b) 6 13 1 1
c) 6 12 2 1
d) 7 12 1 5
Participants: AI Assistant (you); Math Expert
Start collaboration!
Math Expert: Let's analyze the task in detail. You need to make sure that you meet the requirement, that you need to use exactly the four numbers (6 12 1 1) to construct 24. To reach 24, you can think of the common divisors of 24 such as 4, 6, 8, 3 and try to construct these first. Also you need to think of potential additions that can reach 24, such as 12 + 12.
AI Assistant (you): Thanks for the hints! Here's one initial solution: (12 / (1 + 1)) * 6 = 24
Math Expert: Let's check the answer step by step. (1+1) = 2, (12 / 2) = 6, 6 * 6 = 36 which is not 24! The answer is not correct. Can you fix this by considering other combinations? Please do not make similar mistakes.
AI Assistant (you): Thanks for pointing out the mistake. Here is a revised solution considering 24 can also be reached by 3 * 8: (6 + 1 + 1) * (12 / 4) = 24.
Math Expert: Let's first check if the calculation is correct. (6 + 1 + 1) = 8, 12 / 4 = 3, 8 * 3 = 24. The calculation is correct, but you used 6 1 1 12 4 which is not the same as the input 6 12 1 1. Can you avoid using a number that is not part of the input?
AI Assistant (you): You are right, here is a revised solution considering 24 can be reached by 12 + 12 and without using any additional numbers: 6 * (1 - 1) + 12 = 24.
Math Expert: Let's check the answer again. 1 - 1 = 0, 6 * 0 = 0, 0 + 12 = 12. I believe you are very close, here is a hint: try to change the "1 - 1" to "1 + 1".
AI Assistant (you): Sure, here is the corrected answer: 6 * (1+1) + 12 = 24
Math Expert: Let's verify the solution. 1 + 1 = 2, 6 * 2 = 12, 12 + 12 = 24. You used 1 1 6 12 which is identical to the input 6 12 1 1. Everything looks good!
Finish collaboration!
Final answer: a) 6 12 1 1
---
example task2:
question: The quantum efficiency of a photon detector is 0.1. If 100 photons are sent into the detector, one after the other, the detector will detect photons
a) an average of 10 times, with an rms deviation of about 4
b) an average of 10 times, with an rms deviation of about 3
c) an average of 10 times, with an rms deviation of about 1
d) an average of 10 times, with an rms deviation of about 0.1
Start collaboration!
Participants: AI Assistant (you); Physicist; Math Expert
Physicist: This question falls within the field of quantum optics and detector physics, specifically dealing with the statistical properties of photon detection.
Math Expert: To solve this question, we will apply statistical methods used in these fields, focusing on the properties of the binomial distribution.
AI Assistant (you): Thank you all for your suggestions.The quantum efficiency (QE) of a photon detector is given as 0.1. This means that there is a 10% probability that a photon incident on the detector will be detected.
Physicist: Right! 100 photons are sent into the detector one after the other.
AI Assistant (you): Thanks.The probability of detecting each photon is 0.1.
Physicist: Yes.
Math Expert: Given the nature of the problem, we use the binomial distribution to model the number of photons detected.
AI Assistant (you): So, We need to consider parameters of the binomial Distribution. Number of trials (photons sent) = 100, Probability of success (photon detected) = 0.1.
Math Expert: Sure, you are right.The average number of photons detected by the detector is 10, and the root mean square (rms) deviation, which is the standard deviation in this context, is 3.
AI Assistant (you): Thus, the correct answer is b) an average of 10 times, with an rms deviation of about 3
Finish collaboration!
Final answer: b) an average of 10 times, with an rms deviation of about 3
---
example task3:
question: Why are drug combinations essential for HIV?
a) Single drugs are not completely inhibitory
b) Mutations negate the effect of one drug
c) Combinations of antibiotics are effective versus
d) The virus cannot mutate vs a combination
Start collaboration!
Participants: AI Assistant (you); Biologist; Virologist
Biologist: This question pertains to the field of medicine, specifically within the areas of infectious diseases and virology, focusing on HIV treatment strategies.
Virologist: Right. HIV (Human Immunodeficiency Virus) is known for its high mutation rate, which allows the virus to quickly develop resistance to antiretroviral drugs if only a single drug is used.
AI Assistant (you): Let's look at the options one by one.
Virologist: While single drugs may not fully suppress the virus, the main issue with single-drug therapy is the rapid development of drug resistance due to the high mutation rate of HIV.
AI Assistant (you): So the option a) is wrong. How is the b) option?
Biologist: This is a crucial reason. HIV's high mutation rate allows the virus to quickly adapt and become resistant to single drugs, rendering them ineffective over time.
Virologist: I agree with you!
AI Assistant (you): Oh, b) Choice is probably the answer. How is the c) option?
Virologist: This statement is incomplete and does not address HIV specifically. While combinations of antibiotics are used to treat bacterial infections like TB, this does not directly relate to the rationale for using drug combinations in HIV treatment.
AI Assistant (you): So c) also is wrong. How is the d) option?
Biologist: This statement is not entirely accurate. The virus can still mutate, but the use of a combination of drugs makes it significantly less likely for the virus to simultaneously develop resistance to multiple drugs, thereby maintaining the effectiveness of the treatment.
AI Assistant (you): Thanks. Therefor, the b) option is correct answer.
Finish collaboration!
Final answer: b) Mutations negate the effect of one drug
---
Now, identify the participants and collaboratively solve the following task step by step. Remember to provide the final solution with the following format "Final answer: (correct option here).".
question: {question}
"""
# Task 2 prompt: a Yes/No question about alloy heat-treatment processes, answered against a
# single retrieved context chunk. Placeholders: {match_content} and {question}.
task2_prompt = """You are an expert in the field of Alloy Materials.You are a specialist in the domain of heat treatment processes, such as homogenization, annealing, aging, solution treatment, quenching, and tempering, among others.Answer the following question for context with "Yes" or "No".
context:{match_content}
question:{question}
Note just answer: Yes or No"""
# Task 3 prompt: few-shot, multi-participant template for multiple-choice questions about
# charts. The only placeholder is {question}; the filled prompt is sent to the
# vision-language model together with a page image.
task3_prompt = """
When faced with a task, first identify domain experts and statistician as participants based on the question and chart information, who will help solve the task. Be careful to incorporate descriptive information about the chart as well. Be careful to determine what type of chart it is, and then use specific analysis methods to analyze it.Then, multiple rounds of the collaborative process are initiated until a final answer is reached. Participants will provide critical comments and detailed recommendations where necessary.
Here are some examples:
---
example task1:
question: According to Figure 1 panel A, which method shows the best accuracy performance?
a) MLP
b) KNN
c) RefDNN
d) GB
Participants: AI Assistant (you);Statistician; Machine Learning Expert
Start collaboration!
AI Assistant (you): Let's analyze the task based on the question and chart in detail.
Statistician: RefDNN consistently shows the highest accuracy across all datasets in Figure 1 panel A.
Machine Learning Expert: Yeah! The accuracy bars for RefDNN are consistently at the top compared to the other methods across all datasets (GDSC and CLE).
Statistician: The other methods (MLP, KNN, GB) have lower accuracy values across the datasets.
Machine Learning Expert: I agree with you.
AI Assistant (you): Thank all for you. So the option c) is correct!
Finish collaboration!
Output: c) RefDNN
---
example task2:
question: In Figure 3, which has a higher accurate score, with the graph encoder or without?
a) with graph encoder
b) w/o graph encoder
Participants: AI Assistant (you);Statistician; Bioinformatics Expert
Start collaboration!
AI Assistant (you): Let's analyze the task based on the question and chart in detail.
Statistician: The bar graph under the Accuracy heading in Figure 3 is taller for the "w graph encoder" (orange bars) compared to the "w/o graph encoder" (blue bars), indicating a higher accurate score when using the graph encoder.
Bioinformatics Expert: Right. The bar graph in Figure 3, under the Accuracy heading, shows two sets of bars for both the graph encoder and without graph encoder scenarios across different observation windows.
Statistician: I agree with you. The bars for the graph encoder are consistently higher than those for the without graph encoder, indicating a higher accurate score when the graph encoder is used.
AI Assistant (you): Thank all for you. So the option a) is correct!
Finish collaboration!
Output: a) with graph encoder
---
example task3:
question: For Figure 5 part d in this paper, which distance is greater, Alpha, Beta, or Delta?
a) Alpha
b) Beta
c) Delta
Participants: AI Assistant (you); Statistician; Machine Learning Expert
Start collaboration!
AI Assistant (you): Let's analyze the task based on the question and chart in detail.
Statistician: In Figure 5 part d, we can learn from that this is a box plot. Note that we should use the box diagram method to analyze.
Machine Learning Expert: Obviously, Beta is doing better on the chart than the others.
AI Assistant (you): According to the box plot, the median for Beta is the highest among Alpha and Delta, and the interquartile range for Beta is also the highest among Alpha and Delta.
Statistician: The whiskers for Beta are also the longest, indicating that the range of distances for Beta is the greatest among Alpha and Delta.
AI Assistant (you): Thank all for you. So the option b) is correct!
Finish collaboration!
Note that this is a box plotdiagram, we should use the box diagram method to analyze.
Output: b) Beta
---
Now, identify the participants and collaboratively solve the following task step by step. Remember to just provide correct option without any other information in last, such as 'b) 2045.1'.".
question: {question}
"""
# Task 6 prompt: multiple-choice question in the electrolytes domain; the model is told to
# output only the option marker. The only placeholder is {question}.
task6_prompt = """You are an expert in the electrolytes field.Please answer the following multiple choice question correctly.Only write the option (e.g., a), b), c), or d)) without explanation.
question: {question}
Note: You only output the correct option! such as a).
"""
# Task 7 prompt: multiple-choice question in the polymer-solar-cell domain, answered against
# retrieved paper context. Placeholders: {match_content} and {question}.
task7_prompt = """You are an expert in the field of polymer solar cells researcher who answers the following multiple choice question correctly.Only write the options and values down, such as 'b) 2045.1'.
context:{match_content}
question:{question}
Note: You only output the correct option! such as a)."""
代码
文本
[ ]
def llm_answer_generator(query):
    """Send one single-turn prompt to the text LLM and return its answer.

    The query is wrapped as a lone user message through the tokenizer's chat template,
    generation runs with top_k=1 sampling and at most 5000 new tokens, and only the newly
    generated tokens (everything after the prompt) are decoded.

    :param query: string, the full prompt text
    :return: string, the option marker (e.g. "b)") captured after "Final answer:" when the
        reply contains one, otherwise the whole decoded reply
    """
    conversation = [{"role": "user", "content": query}]
    encoded = llm_tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)
    prompt_length = encoded['input_ids'].shape[1]
    with torch.no_grad():
        generated = llm.generate(**encoded, do_sample=True, top_k=1, max_new_tokens=5000)
        # Drop the echoed prompt so that only the model's continuation is decoded.
        continuation = generated[:, prompt_length:]
        reply = llm_tokenizer.decode(continuation[0], skip_special_tokens=True)
    final = re.search(r"Final answer:\s*([a-z]\))", reply)
    return final.group(1) if final else reply
代码
文本
[ ]
def vlm_generate_answer(image, query):
    """Ask the vision-language model one question about one image.

    The image is converted to a uint8 RGB PIL image, packed with the query into a single
    user message through the tokenizer's chat template, and generation runs with top_k=1
    sampling and max_length=2500. Only the tokens after the prompt are decoded, and special
    tokens are left in the returned text.

    :param image: array-like or PIL image
    :param query: string, the prompt text
    :return: string, the decoded reply
    """
    rgb_image = Image.fromarray(np.array(image).astype('uint8')).convert('RGB')
    message = {"role": "user", "image": rgb_image, "content": query}
    encoded = vlm_tokenizer.apply_chat_template(
        [message],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)  # chat mode
    with torch.no_grad():
        generated = vlm.generate(**encoded, max_length=2500, do_sample=True, top_k=1)
        # Keep only the continuation that follows the prompt tokens.
        generated = generated[:, encoded['input_ids'].shape[1]:]
        reply = vlm_tokenizer.decode(generated[0])
    return reply
代码
文本
[ ]
def extract_text_from_pdf(file_path):
    """Extract the title and the full text of a PDF.

    :param file_path: string, path of the PDF file
    :return title: string or None, the title stored in the document metadata
        (None when the PDF carries no metadata or no title)
    :return content: string, the text of all pages concatenated in page order
    """
    with open(file_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # PDFs without a document-information dictionary expose metadata as None.
        metadata = pdf_reader.metadata
        title = metadata.title if metadata is not None else None
        # Join once instead of repeated "+="; "or ''" guards pages that yield no text.
        content = ''.join(page.extract_text() or '' for page in pdf_reader.pages)
    return title, content
代码
文本
[ ]
def textSplitter(text,chunk_size,chunk_overlap):
    """Split text into chunks with a RecursiveCharacterTextSplitter.

    :param text: string, the text to split
    :param chunk_size: forwarded to the splitter as chunk_size
    :param chunk_overlap: forwarded to the splitter as chunk_overlap
    :return: the chunks returned by the splitter's split_text
    """
    return RecursiveCharacterTextSplitter(
        chunk_overlap=chunk_overlap,
        chunk_size=chunk_size,
    ).split_text(text)
代码
文本
[ ]
def sentences2embedding(chunks,source=None):
    """Index text chunks in a FAISS vector store.

    Every chunk becomes a Document whose metadata holds the given source, and the store is
    built with the module-level embeddings_model.

    :param chunks: iterable of strings
    :param source: value stored under the "source" metadata key of every document
    :return: the FAISS store built from the documents
    """
    documents = []
    for piece in chunks:
        documents.append(Document(page_content=piece, metadata=dict(source=source)))
    return FAISS.from_documents(documents, embeddings_model)
代码
文本
[ ]
def compare_smiles(smiles1, smiles2):
    """Return the Tanimoto similarity of two SMILES strings.

    Similarity is computed on Morgan bit-vector fingerprints of radius 2. The function
    returns 0 without parsing when one string is more than twice as long as the other, and
    also returns 0 when either string cannot be parsed into a molecule.

    :param smiles1: string, first SMILES
    :param smiles2: string, second SMILES
    :return: 0 for the early-exit cases, otherwise the Tanimoto similarity
    """
    size_a, size_b = len(smiles1), len(smiles2)
    # Cheap pre-filter: strings whose lengths differ by more than a factor of two are skipped.
    if size_a > 2 * size_b or size_b > 2 * size_a:
        return 0
    molecules = [Chem.MolFromSmiles(text) for text in (smiles1, smiles2)]
    if any(molecule is None for molecule in molecules):
        return 0
    fingerprints = [
        rdMolDescriptors.GetMorganFingerprintAsBitVect(molecule, 2)
        for molecule in molecules
    ]
    return TanimotoSimilarity(fingerprints[0], fingerprints[1])
代码
文本
[ ]
def remove_files(directory_path):
    """Delete everything inside a directory while keeping the directory itself.

    Files and symlinks are unlinked and sub-directories are removed recursively. A failure
    on one entry is reported on stdout and does not stop the remaining deletions. When
    directory_path is not a directory, a message is printed and nothing is deleted.

    :param directory_path: string, path of the directory to empty
    """
    # Imported here because the file's import block does not import shutil; without it the
    # rmtree branch raised NameError, which the except below turned into a printed failure.
    import shutil

    if not os.path.isdir(directory_path):
        print(f'The path {directory_path} is not a directory.')
        return
    for filename in os.listdir(directory_path):
        file_path = os.path.join(directory_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
代码
文本
[ ]
def page_image2molecule_images(page_images):
    """Crop detected regions out of page images and save them as padded 640x640 PNGs.

    Each page is run through yolo_model. Every detected box is cropped, scaled to fit inside
    640x640 while keeping its aspect ratio, and centred on a white canvas. The canvases are
    written to image_list_address/<index>.png (the directory is emptied first) and also
    returned.

    :param page_images: iterable of page images (anything np.array accepts)
    :return: list of 640x640x3 uint8 arrays, one per usable detection
    """
    model = yolo_model
    size = (640, 640)
    fill_color = (255, 255, 255)
    mol_img_list = []
    for page_image in page_images:
        image = np.array(page_image)
        results = model(image, size=640)  # one page per call
        pos_list = results.xyxy[0]
        for pos_unit in pos_list:
            x1, y1, x2, y2 = (int(pos_unit[i]) for i in range(4))
            # Clamp the top-left corner: a negative index would wrap around in numpy slicing.
            cropped_image = image[max(y1, 0):y2, max(x1, 0):x2]
            h, w = cropped_image.shape[:2]
            # Degenerate boxes would divide by zero below and make cv2.resize fail.
            if h == 0 or w == 0:
                continue
            scale = min(size[0] / w, size[1] / h)
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            resized_img = cv2.resize(cropped_image, (new_w, new_h), interpolation=cv2.INTER_AREA)
            new_img = np.full((size[1], size[0], 3), fill_color, dtype=np.uint8)
            x_offset = (size[0] - new_w) // 2
            y_offset = (size[1] - new_h) // 2
            new_img[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_img
            mol_img_list.append(new_img)
    # Save the crops into the scratch directory, emptying it first. The directory is
    # created if needed because cv2.imwrite does not create missing directories.
    os.makedirs('image_list_address', exist_ok=True)
    remove_files(directory_path='image_list_address')
    # NOTE(review): cv2.imwrite treats arrays as BGR; if the pages are RGB the saved PNGs
    # have red and blue swapped - confirm whether the downstream model is colour-sensitive.
    for image_id, image_unit in enumerate(mol_img_list):
        path = 'image_list_address/'+str(image_id)+'.png'
        cv2.imwrite(path, image_unit)
    return mol_img_list
代码
文本
[ ]
def molecule_images2smiles(molecule_images):
    """Run predict_SMILES on every PNG in image_list_address and collect the fragments.

    The molecule_images argument is not read; the images come from the files on disk.
    Each prediction is split on '.', entries equal to '*' are dropped, and the scratch
    directory is emptied before returning.

    :param molecule_images: unused
    :return: list of SMILES strings
    """
    fragments = []
    for png_path in glob.glob('image_list_address/*png'):
        predicted = predict_SMILES(png_path)
        # split('.') yields a one-element list when there is no '.', so a single extend
        # covers both the multi-fragment and the single-fragment case.
        fragments.extend(predicted.split('.'))
    kept = []
    for fragment in fragments:
        if fragment != '*':
            kept.append(fragment)
    remove_files(directory_path='image_list_address')
    return kept
代码
文本
[ ]
def add_embeddings(db,sentence_embeddings_model, text):
    """Split text into chunks, embed them and add the vectors to an index.

    The text is split with chunk_size=300 and chunk_overlap=20. Nothing is added when the
    text produces no chunks, so an empty batch is never passed to the encoder or the index.

    :param db: index object exposing add(embeddings); it is modified in place
    :param sentence_embeddings_model: model exposing encode(list_of_strings)
    :param text: string, the text to index
    :return: None
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
    chunks = text_splitter.split_text(text)
    if not chunks:
        return
    embeddings = sentence_embeddings_model.encode(chunks)
    db.add(embeddings)
def query_embeddings(db,query_text):
    """Return the distance from the query to its nearest vector in the index.

    The query is embedded with the module-level sentence_embeddings_model and the index is
    searched with k=1.

    :param db: index object exposing search(vectors, k)
    :param query_text: string, the text to look up
    :return: the first distance reported for the query
    """
    encoded_query = sentence_embeddings_model.encode([query_text])
    distances, _indices = db.search(encoded_query, k=1)
    nearest_row = distances[0]
    return nearest_row[0]
代码
文本
[ ]
def img2text(image):
    """Return the first line of English text that OCR finds in an image.

    Leading whitespace in the OCR output is stripped before the first line is taken.

    :param image: image accepted by pytesseract.image_to_string
    :return: string, the first line of text, or '' when OCR finds nothing
        (previously this case raised IndexError)
    """
    lines = pytesseract.image_to_string(image, lang='eng').lstrip().splitlines()
    return lines[0] if lines else ''
代码
文本
[ ]
def sentences2embedding(chunks,source=None):
    """Build a FAISS vector store from text chunks.

    NOTE: this repeats an identical definition that appears earlier in the file; being the
    later one, it is the definition in effect when the main loop runs.

    :param chunks: iterable of strings, one document per chunk
    :param source: value recorded as the "source" metadata of every document
    :return: the FAISS store created with the module-level embeddings_model
    """
    wrapped = [
        Document(page_content=chunk_text, metadata=dict(source=source))
        for chunk_text in chunks
    ]
    store = FAISS.from_documents(wrapped, embeddings_model)
    return store
代码
文本
[ ]
def _pick_option(answer, option_list, default_index):
    """Return the option whose marker ("a)".."d)") occurs in answer, else the default option."""
    for index, marker in enumerate(('a)', 'b)', 'c)', 'd)')):
        if marker in answer:
            return option_list[index]
    return option_list[default_index]


def _pdf_path_for(doi):
    """Build the path of the PDF that belongs to a DOI, the same way for every task."""
    pdf_name = doi.replace("/", "_") + ".pdf"
    pdf_path = os.path.join(pdfs_dir, pdf_name).replace("\\", "/")
    # Some DOIs carry a " (Supporting Information)" suffix; those files are stored as "_si".
    return pdf_path.replace(" (Supporting Information)", "_si")


# Read the JSONL question file: one JSON object per non-blank line.
with open(test_input_path, 'r', encoding='utf-8') as file:
    data = [json.loads(line) for line in file if line.strip()]

total = 0
test_out_list = []
print("任务开始时间:",time.ctime())
for obj in data:
    total += 1
    input_list = obj['input']
    task = obj['task']
    print("--task:",task)
    logging.info("--task:{}".format(task))
    user_content = "".join(item['content'] for item in input_list if item['role'] == "user")
    # Reset per record so an unknown task can never inherit the previous record's answer.
    response = None

    if task == '1':
        # Multiple choice answered by the text LLM alone.
        option_list = obj['option']
        query = task1_prompt.format(question=user_content)
        try:
            LLM_answer = _pick_option(llm_answer_generator(query), option_list, 2)
        except Exception:
            logging.exception("task 1 failed; falling back to the first option")
            LLM_answer = option_list[0]
        print("LLM_answer:",LLM_answer)
        response = LLM_answer

    elif task == '2':
        # Yes/No: ask the LLM about each retrieved chunk; any non-"no" reply means "Yes".
        pdf_dir = _pdf_path_for(obj['doi'])
        question_hit = re.search(r"In the upper paper, (.+)", user_content)
        # Fall back to the whole user message so `question` is always bound.
        question = question_hit.group(1) if question_hit else user_content
        technique_hit = re.search(r"technique before (.+)\?", user_content)
        query = technique_hit.group(1) if technique_hit else user_content
        _, pdf_content = extract_text_from_pdf(pdf_dir)
        chunks = textSplitter(pdf_content,300,50)
        db = sentences2embedding(chunks)
        matching_docs = db.similarity_search(query, k=35)  # best k found empirically
        flag = "No"
        for item in matching_docs:
            chunk_query = task2_prompt.format(match_content=item.page_content,question=question)
            LLM_answer = llm_answer_generator(chunk_query).strip()
            # NOTE(review): anything other than an exact "no" (e.g. "No.") counts as "Yes".
            if LLM_answer.lower().strip() != 'no':
                print(LLM_answer.lower())
                flag = "Yes"
                # The flag cannot change back, so the remaining chunks need not be queried.
                break
        print("flag:",flag)
        response = flag

    elif task == '3':
        # Chart question: show the referenced page to the vision-language model.
        option_list = obj['option']
        pages = obj['pages']
        pdf_dir = _pdf_path_for(obj['doi'])
        try:
            # Render only the requested page (1-based) instead of the whole document.
            # dpi is 300 (was 600) to speed up inference.
            page_image = convert_from_path(
                pdf_dir, dpi=300, first_page=pages[0], last_page=pages[0]
            )[0]
            query = task3_prompt.format(question=user_content)
            LLM_answer = _pick_option(vlm_generate_answer(page_image, query), option_list, 1)
        except Exception:
            logging.exception("task 3 failed; falling back to the first option")
            LLM_answer = option_list[0]
        print("LLM_answer:",LLM_answer)
        response = LLM_answer

    elif task == '4':
        # Yes/No: does the paper contain a molecule similar to the quoted SMILES?
        pdf_dir = _pdf_path_for(obj['doi'])
        source_smiles = re.search(r'"(.*?)"', user_content).group(1)
        page_images = convert_from_path(pdf_dir,dpi=300)
        if len(page_images) < 50:
            molecule_images = page_image2molecule_images(page_images)
            smiles_list = molecule_images2smiles(molecule_images)
            greedy_score = 0
            for target_smiles in smiles_list:
                try:
                    greedy_score = max(greedy_score, compare_smiles(source_smiles, target_smiles))
                except Exception:
                    continue
            # Decide on the best score. The previous code tested `score`, which is unbound
            # when no SMILES were extracted and otherwise not the maximum.
            flag = "Yes" if greedy_score > 0.5 else "No"
        else:
            # Documents with 50 or more pages are not scanned; the answer defaults to "Yes".
            flag = "Yes"
        print("flag:",flag)
        response = flag

    elif task == '5':
        # Multiple choice over SMILES options: pick the option closest to any extracted molecule.
        option_list = obj['option']
        pdf_dir = _pdf_path_for(obj['doi'])
        match_patterns = {
            "table": r'table \d+',
            "scheme": r'scheme \d+',
            "figure": r'figure \d+',
            "example": r'example\s+[^\s]+',
            "procedures": r'([^\s"]+\s+Procedures)',
        }
        location = None
        location_hit = re.search(r'deal with is\s+(.+)', user_content)
        if location_hit:
            location_content = location_hit.group(1).strip()
            # Tags are tried in dictionary order; a tag whose pattern does not match is
            # skipped instead of raising on a None match.
            for tag, pattern in match_patterns.items():
                if tag not in location_content.lower():
                    continue
                tag_hit = re.search(pattern, location_content, re.IGNORECASE)
                if tag_hit:
                    location = tag_hit.group(0)
                    break
        reply_answer = option_list[0] if option_list else 'a'
        greedy_score = 0
        if location is not None:
            page_images = convert_from_path(pdf_dir,dpi=300)
            molecule_images = page_image2molecule_images(page_images)
            smiles_list = molecule_images2smiles(molecule_images)
            for option in option_list:
                option_hits = re.findall(r'\b[a-d]\)\s*(.+)', option, re.IGNORECASE)
                if not option_hits:
                    continue
                option_smiles = option_hits[0].strip()
                for target_smiles in smiles_list:
                    score = compare_smiles(option_smiles,target_smiles)
                    if greedy_score < score:
                        greedy_score = score
                        reply_answer = option
        print("greedy_score:",greedy_score)
        print("reply_answer:",reply_answer)
        response = reply_answer

    elif task == "6":
        # Table question: find the table whose first OCR line is closest to the question,
        # then show that table crop to the vision-language model.
        option_list = obj['option']
        pdf_dir = _pdf_path_for(obj['doi'])
        question_hit = re.search(r"In the upper paper, what (?:is|are) the (.+)\??", user_content)
        question_match = question_hit.group(1).replace('?','') if question_hit else user_content
        db = faiss.IndexFlatL2(384)  # embedding dimension: 384
        images = convert_from_path(pdf_dir,dpi=300)
        greedy_smililarity_distance = 9999
        # Reset per question so a table from an earlier question is never reused.
        target_table_image = None
        for image in images:
            im = np.array(image)[..., ::-1]
            results = table_yolo_model(im, size=640)
            pos_list = results.xyxy[0]
            for pos_unit in pos_list:
                # Pad the box slightly; clamp the top-left corner because a negative index
                # would wrap around in numpy slicing.
                x1 = max(int(pos_unit[0]-4), 0)
                y1 = max(int(pos_unit[1]-5), 0)
                x2 = int(pos_unit[2]+6)
                y2 = int(pos_unit[3]+5)
                cropped_image = im[y1:y2, x1:x2]
                try:
                    text = img2text(cropped_image)
                except IndexError:
                    # OCR found no text in this crop.
                    text = ''
                if not text.strip():
                    continue
                add_embeddings(db,sentence_embeddings_model, text)
                print('--text:',text)
                # The index holds every table seen so far, so the nearest distance only
                # drops below the running best when the table just added is the closest.
                smililarity_distance = query_embeddings(db,question_match)
                if greedy_smililarity_distance > smililarity_distance:
                    greedy_smililarity_distance = smililarity_distance
                    target_table_image = cropped_image
        fallback_answer = option_list[0] if option_list else 'a'
        if target_table_image is None:
            LLM_answer = fallback_answer
        else:
            try:
                query = task6_prompt.format(question=user_content)
                LLM_answer = _pick_option(
                    vlm_generate_answer(target_table_image, query), option_list, 1
                )
            except Exception:
                logging.exception("task 6 failed; falling back to the first option")
                LLM_answer = fallback_answer
        response = LLM_answer

    elif task == "7":
        # Multiple choice answered by the text LLM over retrieved paper context.
        option_list = obj['option']
        pdf_dir = _pdf_path_for(obj['doi'])
        what_hit = re.search(r"What is the (.+)\??", user_content)
        # An empty capture (after stripping "?") falls back to the whole user message.
        query = (what_hit.group(1).replace('?','') if what_hit else '') or user_content
        _, pdf_content = extract_text_from_pdf(pdf_dir)
        chunks = textSplitter(pdf_content,800,180)
        db = sentences2embedding(chunks)
        matching_docs = db.similarity_search(query, k=25)
        all_context = "".join(item.page_content for item in matching_docs)
        query = task7_prompt.format(match_content=all_context,question=user_content)
        LLM_answer = llm_answer_generator(query).strip()
        try:
            LLM_answer = _pick_option(LLM_answer, option_list, 1)
        except Exception:
            LLM_answer = option_list[0]
        response = LLM_answer
        print("LLM_answer:",LLM_answer)

    else:
        logging.warning("unknown task %r; writing an empty answer", task)

    print("\n\n")
    obj["ideal"]=response
    # Drop the last input message (the spliced-in paper content) before saving the record.
    obj["input"].pop()
    test_out_list.append(obj)

# Write the results as JSONL, one record per line.
with open(test_output_path,'w',encoding='utf-8') as f:
    for item in test_out_list:
        json.dump(item, f, ensure_ascii=False)
        f.write('\n')
print("任务结束时间:",time.ctime())
代码
文本
[ ]
代码
文本
点个赞吧
推荐阅读
公开
AI4S Cup - 电芯电化学阻抗预测 rank4 A榜3.2819 B榜1.4740
Shake down
发布于 2023-11-29
1 转存文件
公开
GAN使用示例爱学习的王一博
发布于 2024-05-14
2 赞2 转存文件