Bohrium
robot
新建

空间站广场

论文
Notebooks
比赛
课程
Apps
我的主页
我的Notebooks
我的论文库
我的足迹

我的工作空间

任务
节点
文件
数据集
镜像
项目
数据库
公开
AI4S Cup -LLM挑战赛 - 多模态表格识别与理解-解决方案-nano团队-推理代码
AI4SCUP-LLMTable
多模态表格理解
AI4SCUP-LLMTable多模态表格理解
bohre684ec
John Yan
更新于 2024-09-30
推荐镜像 :table:v2
推荐机型 :c2_m4_cpu
赞 9
1
qwen-2.5-14b-instruct(v1)
表格识别A榜测试集(v4)

1. 文件路径设置

代码
文本
[1]
import os
import shutil

# Resolve the dataset directory. During submission the platform sets DATA_PATH_B
# to the hidden test set (A+B leaderboard, same format as the A set but 5360
# images); during development the mounted A-set test data is used for debugging.
if os.environ.get('DATA_PATH_B'):
    base_dir = os.environ.get('DATA_PATH_B')
else:
    base_dir = '/bohr/form-recognition-train-b6y2/v4'

data_path = os.path.join(base_dir, 'dataset.json')
sub_path = os.path.join(base_dir, 'sample_submission.json')
image_path = os.path.join(base_dir, 'test_images')

# Writable copies in the current working directory; later cells update these
# JSON files in place, and the mounted originals are read-only.
writable_data_path = os.path.join(os.getcwd(), 'dataset_copy.json')
writable_sub_path = os.path.join(os.getcwd(), 'sample_submission_copy.json')


def _copy_to_writable(src, dst, label):
    """Copy the read-only file ``src`` to ``dst`` and return the path to use.

    Returns ``dst`` when the copy succeeds. When it fails, the error is printed
    and ``src`` is returned, so later cells still read an existing file instead
    of a copy that was never created (writes to the read-only original will
    then fail loudly at the point of writing).
    """
    try:
        shutil.copyfile(src, dst)
    except OSError as e:
        print(f"Error copying {label}: {e}")
        return src
    print(f"Copied {label} to {dst}")
    return dst


data_path = _copy_to_writable(data_path, writable_data_path, 'dataset.json')
sub_path = _copy_to_writable(sub_path, writable_sub_path, 'sample_submission.json')

Copied dataset.json to /dataset_copy.json
Copied sample_submission.json to /sample_submission_copy.json
代码
文本

2. 将表格图像解析为 LaTeX

代码
文本

2.1 Build TensorRT models

代码
文本
[2]
# Build the TensorRT engine with the StructEqTable-Deploy helper script.
# NOTE(review): presumably this populates the directory later configured as
# TENSORRT_PATH (/StructEqTable-Deploy/ckpts/StructTable-base-TensorRT) — confirm.
!cd /StructEqTable-Deploy/tools && bash scripts/build_tensorrt.sh
代码
文本

2.2 Images to LaTeX

代码
文本
[3]
import json
import os
import time
import torch
from tqdm import tqdm
from PIL import Image
from struct_eqtable import build_model
from pypandoc import convert_text

# ---- StructEqTable model settings ---------------------------------------
CKPT_PATH = '/StructEqTable-Deploy/ckpts/StructTable-base'
TENSORRT_PATH = '/StructEqTable-Deploy/ckpts/StructTable-base-TensorRT'
MAX_NEW_TOKENS = 2048
MAX_WAITING_TIME = 60      # seconds; passed as max_time and used for the slow-inference warning
USE_CPU = False
OUTPUT_FORMAT = ['latex']  # target formats; 'latex' output is passed through unchanged

# ---- Input / output locations -------------------------------------------
DATASET_JSON_PATH = data_path                # dataset.json path chosen by the path-setup cell
OUTPUT_IMAGES_DIR = image_path               # directory containing the table images
STATE_FILE_PATH = './processing_state.json'  # resume checkpoint listing processed images

def load_json(filepath):
    """Return the Python object decoded from the JSON file at ``filepath``."""
    with open(filepath, 'r') as stream:
        content = json.load(stream)
    return content

def save_json(data, filepath):
    """Write ``data`` to ``filepath`` as 4-space-indented JSON, atomically.

    The JSON is first written to ``filepath + '.tmp'`` and then moved over
    ``filepath`` with ``os.replace``. ``main`` checkpoints the whole dataset
    through this function, so an interruption mid-write must not leave a
    truncated file behind; with the replace, the target holds either the
    previous complete content or the new one.
    """
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, filepath)
    except BaseException:
        # The target is untouched; just drop the partial temp file and re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_state(filepath):
    """Return the saved processing state, or a fresh empty state if none exists.

    The fresh state is ``{'processed_images': []}``.
    """
    if not os.path.exists(filepath):
        return {'processed_images': []}
    with open(filepath, 'r') as handle:
        return json.load(handle)

def save_state(state, filepath):
    """Persist the processing state as 4-space-indented JSON, atomically.

    Writes to ``filepath + '.tmp'`` and swaps it into place with
    ``os.replace``, so an interruption mid-write cannot leave a truncated
    state file that would make the next ``json.load`` of it fail.
    """
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(state, f, indent=4)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Leave the previous state file intact and remove the partial temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Process-wide cache so the StructEqTable model is built once, not per image.
_MODEL_CACHE = {}


def _get_model():
    """Return the StructEqTable model, building it on the first call only."""
    model = _MODEL_CACHE.get('model')
    if model is None:
        model = build_model(
            CKPT_PATH,
            max_new_tokens=MAX_NEW_TOKENS,
            max_time=MAX_WAITING_TIME,
            tensorrt_path=TENSORRT_PATH,
        )
        # Move to GPU only when running on GPU without a TensorRT engine path.
        if not USE_CPU and TENSORRT_PATH is None:
            model = model.cuda()
        _MODEL_CACHE['model'] = model
    return model


def process_image(image_path):
    """Recognise the table in ``image_path`` and return its code per output format.

    The model comes from a cache, so the checkpoint is loaded once per process
    instead of once per image. The image file is closed after inference.

    Returns:
        list[str]: for every sequence the model outputs, one string per entry
        of ``OUTPUT_FORMAT`` ('latex' is passed through unchanged; other
        formats are converted from LaTeX with pypandoc).
    """
    model = _get_model()

    with Image.open(image_path) as raw_image:
        start_time = time.time()
        with torch.no_grad():
            output = model(raw_image)
        cost_time = time.time() - start_time
    print(f"Total cost time: {cost_time:.2f}s")

    if cost_time >= MAX_WAITING_TIME:
        print(f"\033[93mWarning: Model inference time exceeds {MAX_WAITING_TIME} seconds.\033[0m")

    latex_codes = []
    for latex_code in output:
        for tgt_fmt in OUTPUT_FORMAT:
            tgt_code = convert_text(latex_code, tgt_fmt, format='latex') if tgt_fmt != 'latex' else latex_code
            latex_codes.append(tgt_code)
    return latex_codes

def main():
    """Recognise every table image in the dataset and store its LaTeX code.

    Reads ``DATASET_JSON_PATH``, runs ``process_image`` on each entry's image
    (resolved under ``OUTPUT_IMAGES_DIR``) and writes the first returned string
    into the entry's ``source_code`` field. The dataset and the list of
    processed images (``STATE_FILE_PATH``) are saved every ``save_interval``
    newly processed images and once more at the end.
    """
    dataset = load_json(DATASET_JSON_PATH)

    state = load_state(STATE_FILE_PATH)
    processed_images = set(state.get('processed_images', []))

    processed_count = 0
    save_interval = 50  # checkpoint after every 50 newly processed images

    for entry in tqdm(dataset, desc="Processing images"):
        image_name = entry['image_path']

        # The state file outlives the dataset copy: the path-setup cell
        # re-copies dataset.json from the read-only source on each run,
        # discarding earlier results. Trusting the state file alone would skip
        # those images and leave source_code empty, so only skip when the
        # result is actually present in this dataset.
        if image_name in processed_images and entry.get('source_code'):
            continue

        image_path = os.path.join(OUTPUT_IMAGES_DIR, image_name)
        if not os.path.exists(image_path):
            print(f"Image {image_name} not found in {OUTPUT_IMAGES_DIR}. Skipping.")
            continue

        try:
            latex_codes = process_image(image_path)
        except Exception as e:
            # Best effort: one failing image must not abort the whole run.
            print(f"Error processing image {image_name}: {e}")
            continue

        # Guard the [0] below: an empty output would otherwise raise
        # IndexError outside the try and stop the run.
        if not latex_codes:
            print(f"No LaTeX output for image {image_name}. Skipping.")
            continue

        entry['source_code'] = latex_codes[0]  # only the first output is used
        processed_images.add(image_name)

        processed_count += 1
        if processed_count % save_interval == 0:
            save_json(dataset, DATASET_JSON_PATH)
            save_state({'processed_images': list(processed_images)}, STATE_FILE_PATH)
            print(f"Progress saved after processing {processed_count} images.")

    # Final save so everything processed since the last checkpoint is kept.
    save_json(dataset, DATASET_JSON_PATH)
    save_state({'processed_images': list(processed_images)}, STATE_FILE_PATH)
    print(f"Final output saved to {DATASET_JSON_PATH}.")


if __name__ == '__main__':
    main()
/opt/mamba/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Processing images: 100%|██████████| 5/5 [00:00<00:00, 78840.30it/s]Final output saved to /dataset_copy.json.

代码
文本

3. 由latex code解析行数和列数(2-3问题)

代码
文本
[4]
import json
import re

def load_json(filepath):
    """Load and return the parsed JSON content of ``filepath``."""
    with open(filepath, mode='r') as src:
        parsed = json.load(src)
    return parsed

def save_json(data, filepath):
    """Dump ``data`` into ``filepath`` as JSON indented by four spaces."""
    with open(filepath, mode='w') as dst:
        json.dump(data, dst, indent=4)


import re

# Column types that each contribute exactly one column.
_COLUMN_TYPES = 'lcrX'
# Column types that take a width argument, e.g. p{3cm}; the argument is skipped.
_WIDTH_COLUMN_TYPES = 'pmb'
# Spec modifiers that take one brace argument but add no column.
_SPEC_MODIFIERS = '@!<>'


def _read_brace_group(text, pos):
    """Return ``(inner, end)`` for the balanced ``{...}`` group at/after ``pos``.

    Leading whitespace is skipped. ``end`` is the index just past the closing
    brace. If no (balanced) group starts there, returns ``(None, pos)``.
    """
    start = pos
    while pos < len(text) and text[pos].isspace():
        pos += 1
    if pos >= len(text) or text[pos] != '{':
        return None, start
    depth = 0
    for i in range(pos, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                return text[pos + 1:i], i + 1
    return None, start


def _count_spec_columns(spec):
    """Count the columns declared by a tabular column specification string."""
    count = 0
    i = 0
    while i < len(spec):
        ch = spec[i]
        if ch == '*':
            # *{n}{sub-spec}: the sub-spec repeated n times.
            times, after_times = _read_brace_group(spec, i + 1)
            if times is not None:
                body, after_body = _read_brace_group(spec, after_times)
                if body is not None and times.strip().isdigit():
                    count += int(times.strip()) * _count_spec_columns(body)
                    i = after_body
                    continue
            i += 1
        elif ch in _SPEC_MODIFIERS:
            _, i = _read_brace_group(spec, i + 1)
        elif ch in _WIDTH_COLUMN_TYPES:
            count += 1
            _, i = _read_brace_group(spec, i + 1)
        elif ch in _COLUMN_TYPES:
            count += 1
            i += 1
        else:
            i += 1
    return count


def count_columns_and_rows(latex_code):
    r"""Return ``(num_columns, num_rows)`` for the first tabular in ``latex_code``.

    Columns come from the column specification. Each ``l``/``c``/``r``/``X``
    counts once, and so does each ``p``/``m``/``b`` whose ``{width}`` argument
    is skipped as a balanced brace group. Letters inside widths such as
    ``p{3cm}`` or ``p{0.2\linewidth}`` are therefore not counted.
    ``*{n}{spec}`` repeats ``spec`` n times. ``@{..}``, ``!{..}``, ``>{..}``,
    ``<{..}`` and ``|`` add nothing. An optional position argument
    (``\begin{tabular}[t]{..}``) is accepted.

    Rows are the number of ``\\`` separators in the body, plus one if the text
    after the last separator contains ``&`` (a final row without a trailing
    ``\\``).

    A missing ``\end{tabular}`` (e.g. truncated model output) is tolerated.
    Returns ``(0, 0)`` when no tabular environment is found.
    """
    # Truncated output may lack the closing tag; add one so the body is bounded.
    if '\\end{tabular}' not in latex_code:
        latex_code += '\n\\end{tabular}'

    # Drop commands that would disturb parsing: colour commands, font-size
    # switches, spaces/vertical rules, and \tabcolsep adjustments. Spec
    # modifiers (@{}, >{} ...) are handled by the spec parser instead of a
    # regex, because a non-greedy regex breaks on nested braces.
    latex_code = re.sub(r'(\\(rowcolor|columncolor)(\[.*?\])?\{.*?\})|(\{\\columncolor\[.*?\]\{.*?\}\})', '', latex_code)
    latex_code = re.sub(r'\\(tiny|scriptsize|footnotesize|small|normalsize|large|Large|LARGE|huge|Huge)', '', latex_code)
    latex_code = latex_code.replace(' ', '').replace('|', '')
    latex_code = re.sub(r'\\(setlength|addtolength)\{\\tabcolsep\}\{.*?\}', '', latex_code)

    for match in re.finditer(r'\\begin\{tabular\}', latex_code):
        pos = match.end()
        # Optional vertical-position argument, e.g. \begin{tabular}[t]{...}.
        if latex_code.startswith('[', pos):
            close = latex_code.find(']', pos)
            if close == -1:
                continue
            pos = close + 1

        spec, body_start = _read_brace_group(latex_code, pos)
        if spec is None:
            continue

        body_end = latex_code.find('\\end{tabular}', body_start)
        if body_end == -1:
            body_end = len(latex_code)
        content = latex_code[body_start:body_end]

        col_count = _count_spec_columns(spec)

        row_count = content.count('\\\\')
        # A last row without a trailing \\ still counts if it has cells.
        if '&' in content.split('\\\\')[-1]:
            row_count += 1

        return int(col_count), int(row_count)

    return 0, 0


def update_cols_and_rows(data_path, sub_path):
    """Fill ``cols``/``rows`` of each submission entry from its table's LaTeX.

    Entries in ``sub_path`` are matched to entries in ``data_path`` by
    ``image_path``. Entries with no match or an empty ``source_code`` are left
    untouched. Parsing errors are printed and skipped. The updated submission
    list is written back to ``sub_path``.
    """
    records = load_json(data_path)
    submissions = load_json(sub_path)

    # image_path -> dataset record, for constant-time matching.
    record_by_image = {record['image_path']: record for record in records}

    for submission in submissions:
        key = submission.get("image_path", "")
        record = record_by_image.get(key)
        if not record:
            continue

        latex = record.get("source_code", "")
        if not latex:
            continue

        try:
            n_cols, n_rows = count_columns_and_rows(latex)
            submission["cols"] = int(n_cols)
            submission["rows"] = int(n_rows)
            print(f"Updated entry: image_path={key}, cols={n_cols}, rows={n_rows}")
        except Exception as exc:
            print(f"Error processing source_code for image_path {key}: {exc}")

    save_json(submissions, sub_path)
    print(f"Updated data saved to {sub_path}")


if __name__ == '__main__':
    # data_path / sub_path are the writable copies set up in the first cell.
    update_cols_and_rows(data_path, sub_path)

Updated data saved to /sample_submission_copy.json
代码
文本

4. 调用微调后的qwen2.5-14B回答分类问题和选择题

代码
文本
[5]
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

import re  # used to extract the answer index from free-form model output

# The submission environment may have no network access; force offline mode so
# that loading never tries to reach the Hugging Face hub.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
os.environ['HF_HUB_OFFLINE'] = '1'

path = "/bohr/qwen-2-5-14b-instruct-rk4r/v1/Qwen2___5-14B-Instruct/"  # pre-downloaded model
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Greedy decoding. Passed straight to generate() as keyword arguments so these
# values override the checkpoint's default generation settings (which may
# enable sampling and make answers non-deterministic).
generation_config_dict = dict(max_new_tokens=64, do_sample=False)

CATEGORIES = (
    'Physics', 'Mathematics', 'ComputerScience', 'QuantitativeBiology',
    'QuantitativeFinance', 'Statistics', 'ElectricalEngineeringandSystemsScience',
    'Economics',
)
DEFAULT_CATEGORY = 'ComputerScience'
SYSTEM_PROMPT = "You are an AI assistant specialized in analyzing tables and answering related questions."


def _generate(messages):
    """Run one chat turn over ``messages`` and return only the new reply text."""
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, **generation_config_dict)
    # Drop the prompt tokens so only the model's continuation is decoded.
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _format_options(options):
    """Render options as ``Option i: "..."`` lines, each ending in a newline."""
    return ''.join(f'Option {i}: "{option}"\n' for i, option in enumerate(options))


def _parse_category(response):
    """Map a free-form reply onto one of CATEGORIES.

    Exact match first, then the first category name contained in the reply;
    falls back to DEFAULT_CATEGORY when nothing matches.
    """
    response = response.strip()
    if response in CATEGORIES:
        return response
    for category in CATEGORIES:
        if category in response:
            return category
    return DEFAULT_CATEGORY


def _parse_answer(response, num_options):
    """Return the first integer in the reply that is a valid option index, else 0."""
    for match in re.finditer(r'\d+', response):
        value = int(match.group())
        if 0 <= value < num_options:
            return value
    return 0


with open(sub_path, 'r', encoding='utf-8') as f:
    sub_data = json.load(f)

with open(data_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Index submission entries by image path; the first occurrence wins if a path
# is duplicated.
sub_index = {}
for sub_entry in sub_data:
    sub_index.setdefault(sub_entry["image_path"], sub_entry)

for item in data:
    image_path = item['image_path']
    options = item["options"]
    options_text = _format_options(options)
    source_code = item.get("source_code", "")

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Turn 1: which subject does the table belong to?
    question = (
        f'This table is represented by the LaTeX code: "{source_code}" with the caption: "{item["caption"]}". '
        f'Based on the table, along with the following question and options:\n'
        f'Question: "{item["question"]}"\n'
        f'{options_text}'
        f'Which subject does the table\'s content most likely relate to? Choose one from '
        f'(Physics, Mathematics, ComputerScience, QuantitativeBiology, QuantitativeFinance, Statistics, '
        f'ElectricalEngineeringandSystemsScience, Economics).'
    )
    messages.append({"role": "user", "content": question})
    response1 = _generate(messages)
    print(response1)
    messages.append({"role": "assistant", "content": response1})

    # Turn 2: pick the correct option, with the first exchange as context.
    question = (
        f'Question: "{item["question"]}"\n'
        f'{options_text}'
        f'Based on the table and the related information, '
        f'select the correct answer from the options (0, 1, 2, or 3).'
    )
    messages.append({"role": "user", "content": question})
    response3 = _generate(messages)
    print(response3)

    category = _parse_category(response1)
    answer = _parse_answer(response3, len(options))

    sub_entry = sub_index.get(image_path)
    if sub_entry is not None:
        sub_entry["category"] = category
        sub_entry["answer"] = answer
        sub_entry.pop("source_code", None)

# ensure_ascii=False may emit non-ASCII text, so write UTF-8 explicitly.
with open('submission.json', 'w', encoding='utf-8') as f:
    json.dump(sub_data, f, ensure_ascii=False, indent=4)

print("Submission saved to submission.json")
Loading checkpoint shards: 100%|██████████| 16/16 [00:02<00:00,  7.73it/s]
ComputerScience
3
ComputerScience
1
Mathematics
0
ComputerScience
1
ComputerScience
0
Submission saved to submission.json
代码
文本
AI4SCUP-LLMTable
多模态表格理解
AI4SCUP-LLMTable多模态表格理解
已赞9
推荐阅读
公开
AI4S Cup -LLM挑战赛-大模型提取“基因-疾病-药物”知识图谱-解决方案-不知道对不队-推理代码
notebookAI4SAI4SCUP-LLMKG
notebookAI4SAI4SCUP-LLMKG
bohrac44ed
发布于 2024-04-23
公开
AI4S Cup -LLM挑战赛-大模型提取“基因-疾病-药物”知识图谱-解决方案-[取名好难]-训练代码
AI4S cup 训练代码AI4SCUP-LLMKG
AI4S cup 训练代码AI4SCUP-LLMKG
取名好难
发布于 2024-04-20
1 转存文件