基于Qwen2-VL模型針對LaTeX OCR任務進行微調訓練 - 多圖推理
flyfish
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_LoRA配置如何寫
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_單圖推理
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_原模型_單圖推理
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_原模型_多圖推理
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_多圖推理
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_數據處理
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_訓練
基于Qwen2-VL模型針對LaTeX_OCR任務進行微調訓練_-_訓練過程
輸入兩張圖像
輸出
可視化
Image 1:
E m m ˉ = 2 7 Q c π 1 / 2 Γ ( 1 / 4 ) 2 log ? ( L 0 / L ) L ∫ 1 ∞ d y y 2 y 4 ? 1 . E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } . Emmˉ?=Γ(1/4)227Qc??π1/2?Llog(L0?/L)?∫1∞?dyy4?1?y2?.
Image 2:
u ( τ )  ̄ = u ( ? τ ˉ ) , u ( τ + 1 ) = ? u ( τ ) , \overline { { u ( \tau ) } } = u ( - \bar { \tau } ) , \qquad \qquad u ( \tau + 1 ) = - u ( \tau ) , u(τ)?=u(?τˉ),u(τ+1)=?u(τ),
import argparse
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from peft import PeftModel, LoraConfig, TaskType
import torchclass LaTeXOCR:def __init__(self, local_model_path, lora_model_path):self.local_model_path = local_model_pathself.lora_model_path = lora_model_pathself._load_model_and_processor()def _load_model_and_processor(self):config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",],inference_mode=True,r=64,lora_alpha=16,lora_dropout=0.05,bias="none",)self.model = Qwen2VLForConditionalGeneration.from_pretrained(self.local_model_path, torch_dtype=torch.float16, device_map="auto")self.model = PeftModel.from_pretrained(self.model, self.lora_model_path, config=config)self.processor = AutoProcessor.from_pretrained(self.local_model_path)def generate_latex_from_images(self, test_image_paths, prompt):"""根據給定的測試圖像路徑列表和提示信息,生成對應的LaTeX格式文本。參數:test_image_paths (list of str): 包含數學公式的測試圖像路徑列表。prompt (str): 提供給模型的提示信息。返回:list of str: 轉換后的LaTeX格式文本列表。"""results = []for image_path in test_image_paths:messages = [{"role": "user","content": [{"type": "image","image": image_path,"resized_height": 100,"resized_width": 500,},{"type": "text", "text": prompt},],}]text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)image_inputs, video_inputs = process_vision_info(messages)inputs = self.processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")with torch.no_grad():generated_ids = self.model.generate(**inputs, max_new_tokens=8192)generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]output_text = self.processor.batch_decode(generated_ids_trimmed,skip_special_tokens=True,clean_up_tokenization_spaces=False,)results.append(output_text[0])return resultsdef parse_arguments():parser = argparse.ArgumentParser(description="LaTeX OCR using Qwen2-VL")parser.add_argument("--local_model_path",type=str,default="./Qwen/Qwen2-VL-7B-Instruct",help='Path to the local model.',)parser.add_argument("--lora_model_path",type=str,default="./output/Qwen2-VL-7B-LatexOCR/checkpoint-1500",help='Path to the LoRA model checkpoint.',)parser.add_argument("--test_image_paths",nargs='+', # 接受多個參數type=str,default=["./LaTeX_OCR/987.jpg", "./LaTeX_OCR/986.jpg"], # 設置默認值為兩個圖像路徑help='Paths to the test images.',)return parser.parse_args()if __name__ == "__main__":args = parse_arguments()prompt = ("尊敬的Qwen2VL大模型,我需要你幫助我將一張包含數學公式的圖片轉換成LaTeX格式的文本。\n""請按照以下說明進行操作:\n""1. **圖像中的內容**: 圖像中包含的是一個或多個數學公式,請確保準確地識別并轉換為LaTeX代碼。\n""2. **公式識別**: 請專注于識別和轉換數學符號、希臘字母、積分、求和、分數、指數等數學元素。\n""3. **LaTeX語法**: 輸出時使用標準的LaTeX語法。確保所有的命令都是正確的,并且可以被LaTeX編譯器正確解析。\n""4. **結構保持**: 如果圖像中的公式有特定的結構(例如多行公式、矩陣、方程組),請在輸出的LaTeX代碼中保留這些結構。\n""5. **上下文無關**: 不要嘗試解釋公式的含義或者添加額外的信息,只需嚴格按照圖像內容轉換。\n""6. **格式化**: 如果可能的話,使輸出的LaTeX代碼易于閱讀,比如適當添加空格和換行。")latex_ocr = LaTeXOCR(args.local_model_path, args.lora_model_path)results = latex_ocr.generate_latex_from_images(args.test_image_paths, prompt)for i, result in enumerate(results):print(f"Image {i + 1}:")print(result)print("-" * 80)