新建
AI4S Cup - LLM挑战赛 - 多模态表格识别与理解-解决方案-kwe3n7s-训练代码-推理代码


更新于 2024-12-19
推荐镜像 :qwen2vl-latex-swift:0.2
推荐机型 :c12_m92_1 * NVIDIA V100
赞
1
目录
数据集
qwen2-vl-7b-instruct-vqa-v5(v2)
lora(v2)
tableDataset(v6)
表格识别A榜测试集(v4)
[ ]
# 训练代码脚本(参考下方的github)
# https://github.com/Copilot-X/MMFR
代码
文本
1、分类任务用qwen2.5-7b-instruct进行微调
# 训练数据来源(arxivQA采样2k)
# 受测试集类别比例的影响,训练采样时CS(ComputerScience)类别的占比偏重
# 复赛分类表现受这个采样的影响
2、表格行列识别用StructEqTable进行处理
# https://hf-mirror.com/U4R/StructTable-base
# 将base模型转换为trt推理加速
# 推理脚本(当前推理所用的GPU不适用于trt环境的编译安装,V100不支持bf16)
3、表格问答采用的是qwen2-vl
# 用glm4v对随机采样的50条数据打伪标签
# 微调3epoch,使模型服从指令输出
# query = f"You are an expert in diagram comprehension, combining questions and images to answer.\n##question: {item['question']}\n##options:\n{new_options}"
代码
文本
[ ]
# 推理脚本(当前推理所用的GPU不适用于trt环境的编译安装,V100不支持bf16)
代码
文本
[ ]
! cp -r /bohr/tableDataset-f1ba/v6/StructEqTable-Deploy/ .
代码
文本
[ ]
%%bash
cd StructEqTable-Deploy/tools/
bash scripts/build_tensorrt.sh
代码
文本
[5]
import os
import sys
import json
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
import re
import random
import pandas as pd
import argparse
代码
文本
[ ]
# 分类模型识别
代码
文本
[ ]
# --- Subject-classification model setup ------------------------------------
# Loads the text classifier and defines the label set and prompt used by
# generate_category(). Statement order matters in this cell: the sys.path and
# environment changes must run before the swift imports.
import os
import json
import sys
# Add the ms-swift checkout to the import path; presumably the `swift`
# imports below resolve from here -- confirm against the runtime image.
sys.path.append('/bohr/tableDataset-f1ba/v6/ms-swift')
# Expose only GPU 0 to CUDA; set before the swift imports below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from tqdm import tqdm
from swift.llm import (
get_model_tokenizer, get_template, inference, ModelType, get_default_template_type,
)
from swift.utils import seed_everything
from swift.tuners import Swift
# Fixed seed passed to swift's seeding helper.
seed_everything(42)
# ckpt_dir = '/bohr/tableDataset-f1ba/v6/Qwen2.5-7B-Instruct_category/'
model_type = ModelType.qwen2_5_7b_instruct
template_type = get_default_template_type(model_type)
# Checkpoint directory; the path name suggests LoRA weights already merged
# into the base model -- TODO confirm.
model_id_or_path = "/bohr/Qwen2-5-7B-Instruct-category-merge-lora-5wny/v2/"
model, tokenizer = get_model_tokenizer(model_type, model_id_or_path=model_id_or_path, model_kwargs={'device_map': 'auto'})
# The expected answer is a single label name, so cap generation at 16 tokens.
model.generation_config.max_new_tokens = 16
# model = Swift.from_pretrained(model, ckpt_dir, inference_mode=True)
template = get_template(template_type, tokenizer)
# Closed set of subject labels; generate_category() rejects anything else.
labels = ['Physics', 'Mathematics', 'ComputerScience', 'QuantitativeBiology', 'QuantitativeFinance', 'Statistics', 'ElectricalEngineeringandSystemsScience', 'Economics']
# NOTE(review): this prompt is a runtime string sent to the model. Its wording
# (typos included) presumably matches the fine-tuning prompt -- do not
# "correct" it without confirming against the training data.
category_prompt = f"You're an expert in classifying papers. According to the provided table associated with the textual description, which subjuct does the content most likely belong to? Choice one from {labels}."
def generate_category(content):
    """Classify a paper description into one of the known subject labels.

    Args:
        content: Text describing the table (the caller passes caption and
            question joined by " [SEP] ").

    Returns:
        A member of ``labels``; "ComputerScience" when the model's stripped
        reply is not exactly one of the labels.
    """
    prompt = "{}\n\n{}".format(category_prompt, content)
    answer, _history = inference(model, template, prompt)
    answer = answer.strip()
    # Replies outside the label set fall back to ComputerScience.
    return answer if answer in labels else "ComputerScience"
代码
文本
[ ]
# --- Table question-answering (VQA) model setup ----------------------------
# Directory later added to sys.path to import struct_eqtable and joined with
# "StructTable-base" to locate the table-to-LaTeX weights.
latex_model_path = "/bohr/tableDataset-f1ba/v6/tableDataset/"
# Checkpoint directory of the vision-language model used for question answering.
vqa_path = "/bohr/qwen2-vl-7b-instruct-vqa-v5-m0lm/v2"
# NOTE: rebinds the module-level model_type / template_type that were set for
# the classifier; the classifier's own template object was already built.
model_type = ModelType.qwen2_vl_7b_instruct
template_type = get_default_template_type(model_type)
print(f'template_type: {template_type}')
# NOTE(review): the notebook text states that V100 does not support bf16, yet
# bfloat16 is requested here -- confirm on the target GPU.
vqa_model, vqa_tokenizer = get_model_tokenizer(model_type, torch.bfloat16,model_id_or_path=vqa_path,
model_kwargs={'device_map': 'auto'})
# Upper bound on generated tokens per answer.
vqa_model.generation_config.max_new_tokens = 256
vqa_template = get_template(template_type, vqa_tokenizer)
# Re-seed after loading the second model.
seed_everything(42)
def generate_vqa(image_path, query):
    """Ask the vision-language model a question about a single image.

    Args:
        image_path: Path of the image handed to the model.
        query: Full text prompt (question plus formatted options).

    Returns:
        The model's reply with surrounding whitespace removed.
    """
    reply, _history = inference(
        vqa_model, vqa_template, query, images=[image_path]
    )
    return reply.strip()
代码
文本
直接推理
代码
文本
[ ]
# Resolve the dataset directory.
# On submission DATA_PATH_B points at the hidden test set (leaderboards A+B):
# same format as the leaderboard-A data but a different size (5360 images).
# Otherwise fall back to the leaderboard-A test set, which is only mounted
# during development for debugging.
base_dir = os.environ.get('DATA_PATH_B') or '/bohr/form-recognition-train-b6y2/v4'

# JSON is UTF-8 by specification; pin the encoding instead of relying on the
# locale default so non-ASCII captions decode identically everywhere.
with open(os.path.join(base_dir, 'dataset.json'), 'r', encoding='utf-8') as f:
    data = json.load(f)
with open(os.path.join(base_dir, 'sample_submission.json'), 'r', encoding='utf-8') as f:
    sub = json.load(f)
代码
文本
[ ]
# Make the struct_eqtable package importable; this must run before the import
# on the next line.
sys.path.append(latex_model_path)
from struct_eqtable import build_model as build_latex_model
# Weights directory of the table-to-LaTeX model.
latex_model_dir = os.path.join(latex_model_path, "StructTable-base")
# Relative path inside the copied StructEqTable-Deploy tree; presumably the
# output of scripts/build_tensorrt.sh -- TODO confirm.
tensorrt_path = "StructEqTable-Deploy/ckpts/StructTable-base-TensorRT/"
# Column separator: an ampersand that is NOT escaped as ``\&`` (which LaTeX
# renders as a literal "&" inside a cell).
_UNESCAPED_AMP = re.compile(r'(?<!\\)&')


def calc_rows_cols(latex_code):
    """Estimate the number of rows and columns of a LaTeX table.

    The code is split into rows on the LaTeX row terminator (a double
    backslash). A piece counts as a row when it contains at least one column
    separator or a ``\\rule`` command. The column count is the largest number
    of cells found in any piece.

    Args:
        latex_code: LaTeX source of a table.

    Returns:
        Tuple ``(num_rows, num_cols)``. ``num_cols`` is at least 1, even for
        empty input.
    """
    # Literal two-character row terminator "\\".
    rows = latex_code.split('\\\\')
    num_rows = sum(
        1 for row in rows
        if _UNESCAPED_AMP.search(row) or '\\rule' in row
    )
    # str.split always yields at least one piece, so max() is safe.
    num_cols = max(len(_UNESCAPED_AMP.split(row)) for row in rows)
    return num_rows, num_cols
# Build the table-to-LaTeX model once at module level; parse_latex() calls it
# for every image. max_new_tokens / max_time are passed through to
# struct_eqtable's build_model (max_time is presumably in seconds -- confirm).
# NOTE(review): the notebook text says TensorRT cannot be compiled on this
# GPU, yet tensorrt_path is still supplied -- verify which backend is used.
latex_model = build_latex_model(
latex_model_dir,
max_new_tokens=2048,
max_time=60,
tensorrt_path=tensorrt_path
)
def parse_latex(image_path):
    """Predict the row and column count of the table in an image.

    The image is converted to LaTeX by ``latex_model`` and the result is
    measured with ``calc_rows_cols``.

    Args:
        image_path: Path of the table image.

    Returns:
        Tuple ``(num_rows, num_cols)``; ``(4, 4)`` when the model returns no
        output.
    """
    # The context manager releases the image's file handle even when the
    # model raises; inference runs without gradient tracking.
    with Image.open(image_path) as raw_image, torch.no_grad():
        output = latex_model(raw_image)
    if len(output) > 0:
        return calc_rows_cols(output[0])
    # Fallback guess when nothing was generated.
    return 4, 4
代码
文本
[ ]
# prompt4 = f"<image>Table caption: ##caption\n\n question:##question\n\n options:##options\n\n please choice the correct answer from options. think step bu step, Just output the answer"
# Build submission.json: for every test item predict the paper category, the
# table's row/column counts and the multiple-choice answer index. Every stage
# is best-effort: a failure is printed and replaced by a fixed default so a
# single bad item can never abort the whole run.
idx2opt = {0: "A", 1: "B", 2: "C", 3: "D"}
opt2idx = {v: k for k, v in idx2opt.items()}
submission = []
# Wrap the list itself (not an enumerate iterator) so tqdm can show the total.
for item in tqdm(data):
    image_path = os.path.join(base_dir, 'test_images', item["image_path"])

    # 1) Subject classification from caption + question.
    try:
        content = item['caption'] + " [SEP] " + item['question']
        response1 = generate_category(content)
    except Exception as e:
        print(e)
        response1 = "ComputerScience"

    # 2) Table structure (rows / columns) from the image.
    try:
        rows, cols = parse_latex(image_path)
        rows, cols = int(rows), int(cols)
    except Exception as e:
        print(e)
        rows, cols = 6, 4

    # 3) Multiple-choice question answering. Option formatting lives inside
    # the try so an unexpected option list cannot crash the loop.
    try:
        new_options = "".join(
            f"{idx2opt[opt_idx]}. {op}\n"
            for opt_idx, op in enumerate(item['options'])
        )
        query = f"You are an expert in diagram comprehension, combining questions and images to answer.\n##question: {item['question']}\n##options:\n{new_options}"
        res3 = generate_vqa(image_path, query)
        # A reply that is not exactly one option letter maps to index 2.
        answer = opt2idx.get(res3, 2)
    except Exception as e:
        print(e)
        answer = 0

    sub_item = {
        "image_path": item["image_path"],
        "category": str(response1),
        "cols": cols,
        "rows": rows,
        "answer": answer,
    }
    submission.append(sub_item)
    print(sub_item)

with open('submission.json', 'w', encoding='utf-8') as f:
    json.dump(submission, f)
代码
文本
[ ]
代码
文本
[ ]
代码
文本
点个赞吧