[AI in Practice] Fine-Tuning Qwen2-VL from Scratch: Building an Intelligent Safety Inspection System for Manufacturing
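- 🎯 Project Background and Goals
- 🛠 Environment Setup
  - Hardware Requirements
  - Software Environment
- 📊 Data Preparation: Building a High-Quality Training Set
  - Step 1: Extracting Knowledge from Rules and Regulations
  - Step 2: Creating the Annotated Dataset
  - Step 3: Converting the Dataset Format
- 🤖 Model Fine-Tuning
  - Loading the Pretrained Model
  - Configuring LoRA Fine-Tuning
  - Data Preprocessing Pipeline
  - Training Configuration and Launch
- 🔍 Model Inference and Testing
  - Loading the Trained Model
  - Implementing the Inference Function
  - Batch Testing
- 🚀 Deployment Options
  - Option 1: Flask API Service
  - Option 2: High-Performance FastAPI Service
  - Option 3: Containerized Deployment with Docker
- 📊 Performance Evaluation and Optimization
  - Evaluation Metrics
  - Performance Optimization Strategies
- 📈 Suggestions for Continuous Improvement
  - Data Augmentation Strategy
  - Active Learning Framework
- 🎉 Summary and Outlook
  - 🏆 Key Results
  - 💡 Key Advantages
  - 🔮 Future Directions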
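Abstract: This article shows, step by step, how to fine-tune the Qwen2-VL multimodal large model. The result is an intelligent inspection system that automatically identifies safety violations in manufacturing. It covers the full path from environment setup to deployment, with complete code and practical lessons.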
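🎯 Project Background and Goals
In manufacturing, production safety comes first. Traditional manual inspection is slow, applies inconsistent standards, and misses hazards. This project fine-tunes the Qwen2-VL multimodal model to deliver:
- 📸 Intelligent image analysis: automatically identify violations on the shop floor
- 📖 Regulation awareness: judge each scene against the company's own rules and regulations
- ⚡ Real-time detection: respond quickly and raise alerts promptly
- 🎯 Precise identification: state exactly which problem constitutes the violation
Tech stack: Qwen2-VL + LoRA fine-tuning + PyTorch + Transformers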
🛠 Environment Setup
Hardware Requirements
Component | Minimum | Recommended |
---|---|---|
GPU | 16 GB VRAM | RTX 4090 / A100 |
RAM | 32 GB | 64 GB |
Storage | 100 GB | 500 GB SSD |
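As a rough sanity check on the GPU row, the memory needed for the weights alone is the parameter count times the bytes per parameter. The sketch below uses round parameter counts as assumptions (the model card gives exact figures). It ignores activations, the KV cache, LoRA optimizer states, and framework overhead, all of which add to the total. By this estimate, a 7B-class model in 16-bit precision needs about 14 GB just for its weights. That makes 16 GB tight for the 7B variant. The 2B variant, or loading in 4-bit, leaves more headroom.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory for the model weights only, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Bytes per parameter for common storage precisions
BYTES_PER_PARAM = {"fp32": 4, "bf16/fp16": 2, "int8": 1, "4-bit": 0.5}

# Round, assumed parameter counts for illustration
for size_name, params in [("2B", 2e9), ("7B", 7e9)]:
    for precision, nbytes in BYTES_PER_PARAM.items():
        print(f"{size_name} @ {precision}: {weight_memory_gb(params, nbytes):.1f} GB")
```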
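Software Environment
# 🐍 Create a Python environment
conda create -n qwen_vl python=3.10
conda activate qwen_vl

# 🔥 Install PyTorch (adjust the index URL to match your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# 📚 Install core dependencies
# Qwen2VLForConditionalGeneration was added in transformers 4.45.0; 4.37.0 cannot load Qwen2-VL
pip install "transformers>=4.45.0"
pip install accelerate peft datasets pillow opencv-python
pip install scikit-learn wandb qwen-vl-utils
# Used by the evaluation and deployment sections
pip install pandas flask fastapi uvicorn aiofiles python-multipart
💡 Tip: Manage the environment with conda to avoid dependency conflicts.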
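Qwen2-VL support only exists in newer transformers releases, so it can pay to fail fast before loading anything. The minimal stdlib sketch below reads the installed version with `importlib.metadata.version`. It compares only the plain numeric `major.minor.patch` parts and ignores suffixes such as `.dev0` or `rc1`. Treat it as a rough check; the `packaging` library implements the full version specification.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """'4.45.2' becomes (4, 45, 2); parsing stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """Compare numeric versions, padding the shorter tuple with zeros."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(package: str) -> Optional[str]:
    """Installed version string, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


found = installed_version("transformers")
if found is None:
    print("transformers is not installed")
elif not meets_minimum(found, "4.45.0"):
    print(f"transformers {found} is too old for Qwen2-VL (need 4.45.0 or newer)")
else:
    print(f"transformers {found} looks fine")
```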
📊 Data Preparation: Building a High-Quality Training Set
Step 1: Extracting Knowledge from Rules and Regulations
import json
import re
from typing import Dict, List


class RuleExtractor:
    """Extracts numbered clauses from rules and regulations documents."""

    def __init__(self):
        # The patterns match Chinese regulation text, so they stay in Chinese
        self.rule_patterns = [
            r'第(\d+)條[\s::](.*?)(?=第\d+條|$)',  # "Article N" style: 第N條
            r'(\d+\.\d+)[\s::](.*?)(?=\d+\.\d+|$)',  # dotted numbering: 3.1, 5.2, ...
        ]

    def extract_from_text(self, text_content: str) -> List[Dict]:
        """Extract rules from raw text."""
        rules = []
        for pattern in self.rule_patterns:
            matches = re.findall(pattern, text_content, re.DOTALL)
            for rule_id, content in matches:
                rules.append({
                    'id': rule_id,
                    'content': content.strip(),
                    'category': self._categorize_rule(content),
                })
        return rules

    def _categorize_rule(self, content: str) -> str:
        """Assign a category by keyword matching (keywords match the Chinese source text)."""
        categories = {
            '安全防護': ['安全帽', '防護服', '安全帶', '護目鏡'],  # protective equipment
            '操作規范': ['操作', '作業', '使用', '維護'],  # operating procedures
            '現場管理': ['整理', '清潔', '擺放', '標識'],  # site management
            '應急處理': ['應急', '事故', '故障', '報告'],  # emergency handling
        }
        for category, keywords in categories.items():
            if any(keyword in content for keyword in keywords):
                return category
        return '其他'  # "other"


# 💼 Usage example
extractor = RuleExtractor()

# Process several regulation files
rule_files = [
    'safety_regulations.txt',
    'operation_manual.txt',
    'quality_standards.txt',
]
all_rules = {}
for file_path in rule_files:
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    all_rules[file_path] = extractor.extract_from_text(content)

# Save the rule database
with open('rules_database.json', 'w', encoding='utf-8') as f:
    json.dump(all_rules, f, ensure_ascii=False, indent=2)

print(f"✅ Extracted {sum(len(rules) for rules in all_rules.values())} rules")
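The self-contained check below shows what the clause pattern actually captures. The two-clause snippet is made up for illustration and is not from a real regulation. The lazy `(.*?)` stops at the lookahead for the next `第N條` or at the end of the text, and `re.DOTALL` lets a clause span line breaks. `extract_from_text` runs both patterns over the same text, so a document that mixes numbering styles can yield overlapping rules. The hypothetical `dedupe_rules` helper keeps the first rule seen for each id.

```python
import re
from typing import Dict, List

CLAUSE_PATTERN = r'第(\d+)條[\s::](.*?)(?=第\d+條|$)'


def extract_clauses(text: str) -> List[Dict[str, str]]:
    """Apply the '第N條' pattern and strip whitespace from each clause body."""
    return [
        {"id": rule_id, "content": content.strip()}
        for rule_id, content in re.findall(CLAUSE_PATTERN, text, re.DOTALL)
    ]


def dedupe_rules(rules: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Hypothetical helper: keep only the first rule for each id."""
    seen, unique = set(), []
    for rule in rules:
        if rule["id"] not in seen:
            seen.add(rule["id"])
            unique.append(rule)
    return unique


sample = "第1條:進入作業區域必須佩戴安全帽。\n第2條:維修電氣設備前必須切斷電源。"
rules = dedupe_rules(extract_clauses(sample))
for rule in rules:
    print(rule["id"], rule["content"])
```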
Step 2: Creating the Annotated Dataset
import json
import os
from typing import Dict, List


class ViolationDatasetBuilder:
    """Builds the violation-detection training set."""

    def __init__(self, rules_db: Dict, image_dir: str):
        self.rules_db = rules_db
        self.image_dir = image_dir
        self.dataset = []

    def create_training_sample(self, image_path: str, violations: List[Dict]) -> Dict:
        """Create one training sample in conversation format."""
        # Target answer (Chinese): intro line, then numbered items with "violated rule" and "risk level"
        violation_text = "根據制造業安全規章制度分析,發現以下問題:\n"
        for i, violation in enumerate(violations, 1):
            violation_text += f"{i}. {violation['description']}\n"
            violation_text += f"   違反規定:{violation['rule_reference']}\n"
            violation_text += f"   風險等級:{violation['risk_level']}\n"

        return {
            # Prefix with image_dir so later steps can open the file directly
            "image": os.path.join(self.image_dir, image_path),
            "conversations": [
                {
                    "from": "human",
                    # "Based on manufacturing regulations, identify the safety violations or improper
                    # operations in the image and state which rule each one violates."
                    "value": "<image>\n請根據制造業規章制度,識別圖片中的安全違規或操作不當之處,并說明具體違反了哪條規定。",
                },
                {"from": "assistant", "value": violation_text.strip()},
            ],
        }

    def build_dataset(self):
        """Build the full dataset."""
        # 🏷️ Sample annotations (a real project needs a large amount of manual labeling)
        annotations = [
            {
                "image": "worker_no_helmet.jpg",  # no helmet; tools scattered on the floor
                "violations": [
                    {
                        "description": "作業人員未佩戴安全帽,頭部缺乏有效防護",
                        "rule_reference": "《安全作業規程》第3.1條:進入作業區域必須佩戴安全帽",
                        "risk_level": "高風險",
                    },
                    {
                        "description": "工作區域地面散落工具,存在絆倒風險",
                        "rule_reference": "《現場管理標準》第5.2條:作業現場應保持整潔有序",
                        "risk_level": "中風險",
                    },
                ],
            },
            {
                "image": "improper_lifting.jpg",  # incorrect lifting posture
                "violations": [
                    {
                        "description": "重物搬運姿勢錯誤,可能導致腰部損傷",
                        "rule_reference": "《人體工程學標準》第2.3條:搬運重物應采用正確姿勢",
                        "risk_level": "中風險",
                    },
                ],
            },
            {
                "image": "electrical_safety.jpg",  # working on live electrical equipment
                "violations": [
                    {
                        "description": "電氣設備操作時未斷電,存在觸電危險",
                        "rule_reference": "《電氣安全規程》第1.2條:維修電氣設備前必須切斷電源",
                        "risk_level": "高風險",
                    },
                ],
            },
        ]

        # Generate training samples
        for ann in annotations:
            sample = self.create_training_sample(ann["image"], ann["violations"])
            self.dataset.append(sample)
        return self.dataset


# 🔧 Build the dataset (all_rules comes from Step 1)
builder = ViolationDatasetBuilder(all_rules, "training_images/")
training_dataset = builder.build_dataset()

# Save the dataset
with open('manufacturing_safety_dataset.json', 'w', encoding='utf-8') as f:
    json.dump(training_dataset, f, ensure_ascii=False, indent=2)

print(f"✅ Dataset built with {len(training_dataset)} samples")
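📝 Annotation tips:
- Quality over quantity: 300 high-quality annotations beat 1,000 sloppy ones
- Cross-check between annotators: have several people label the same images to keep labels consistent and accurate
- Cover typical scenarios: include every type of violation you expect to see
- Balance the distribution: keep the number of samples per violation type roughly even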
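The cross-checking and balance tips above can be made measurable. The minimal sketch below uses Cohen's kappa to compare two annotators' labels on the same images. Kappa corrects for the agreement you would expect by chance: κ = (p_o - p_e) / (1 - p_e). A simple count ratio shows how skewed the label distribution is. The risk-level labels are made up for illustration. What counts as acceptable agreement or balance is a project decision.

```python
from collections import Counter
from typing import Sequence


def cohens_kappa(labels_a: Sequence[str], labels_b: Sequence[str]) -> float:
    """Chance-corrected agreement between two annotators on the same items."""
    if len(labels_a) != len(labels_b) or not labels_a:
        raise ValueError("need two non-empty label lists of equal length")
    n = len(labels_a)
    observed = sum(a == b for a, b in zip(labels_a, labels_b)) / n
    counts_a, counts_b = Counter(labels_a), Counter(labels_b)
    expected = sum(counts_a[label] * counts_b[label] for label in counts_a) / (n * n)
    if expected == 1.0:
        return 1.0  # both annotators used the same single label throughout
    return (observed - expected) / (1 - expected)


def imbalance_ratio(labels: Sequence[str]) -> float:
    """Largest class count divided by smallest class count (1.0 = perfectly balanced)."""
    counts = Counter(labels)
    return max(counts.values()) / min(counts.values())


annotator_1 = ["high", "high", "medium", "low"]
annotator_2 = ["high", "medium", "medium", "low"]
print(f"kappa = {cohens_kappa(annotator_1, annotator_2):.3f}")
print(f"imbalance ratio = {imbalance_ratio(annotator_1):.1f}")
```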
Step 3: Converting the Dataset Format
import json

from datasets import Dataset
from sklearn.model_selection import train_test_split


def prepare_huggingface_dataset(json_path: str):
    """Convert the JSON dataset into Hugging Face Dataset objects."""
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 🔄 Flatten each sample: keep the image path and serialize the conversation to a string
    hf_data = []
    for item in data:
        hf_data.append({
            'image_path': item['image'],
            'conversations': json.dumps(item['conversations'], ensure_ascii=False),
        })

    # ✂️ Train/validation split
    train_data, eval_data = train_test_split(
        hf_data,
        test_size=0.2,
        random_state=42,
        stratify=None,  # could stratify by violation type instead
    )

    # 📦 Create Dataset objects
    train_dataset = Dataset.from_list(train_data)
    eval_dataset = Dataset.from_list(eval_data)
    return train_dataset, eval_dataset


# Create the training and validation sets
train_ds, eval_ds = prepare_huggingface_dataset('manufacturing_safety_dataset.json')
print(f"📈 Training set: {len(train_ds)} samples")
print(f"📊 Validation set: {len(eval_ds)} samples")
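The `stratify=None` comment points at stratified sampling. With scikit-learn you can pass a label list as `train_test_split(..., stratify=labels)`. The stdlib sketch below does the same thing so the logic is visible. It derives each sample's stratum with a hypothetical rule: the highest risk level mentioned in the assistant answer, using the Chinese labels from the dataset. It then splits each stratum separately, so every risk level appears in validation.

```python
import json
import random
from collections import defaultdict
from typing import Dict, List, Tuple

RISK_ORDER = ["高風險", "中風險", "低風險"]  # high, medium, low risk


def risk_stratum(sample: Dict) -> str:
    """Hypothetical stratum: highest risk level in the assistant's answer, else 'none'."""
    answer = json.loads(sample["conversations"])[-1]["value"]
    for level in RISK_ORDER:
        if level in answer:
            return level
    return "none"


def stratified_split(samples: List[Dict], test_size: float = 0.2,
                     seed: int = 42) -> Tuple[List[Dict], List[Dict]]:
    """Split each stratum separately so every stratum is represented in validation."""
    groups = defaultdict(list)
    for sample in samples:
        groups[risk_stratum(sample)].append(sample)
    rng = random.Random(seed)
    train, evaluation = [], []
    for key in sorted(groups):
        group = groups[key][:]
        rng.shuffle(group)
        n_eval = round(len(group) * test_size)
        evaluation.extend(group[:n_eval])
        train.extend(group[n_eval:])
    return train, evaluation


def make_demo_sample(level: str, i: int) -> Dict:
    """Toy row in the same shape as hf_data above."""
    conv = [{"from": "human", "value": "q"},
            {"from": "assistant", "value": f"1. x\n   風險等級:{level}"}]
    return {"image_path": f"{i}_{level}.jpg", "conversations": json.dumps(conv, ensure_ascii=False)}


demo = [make_demo_sample(level, i) for level in ("高風險", "中風險") for i in range(5)]
train_part, eval_part = stratified_split(demo)
print(len(train_part), len(eval_part))
```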
🤖 Model Fine-Tuning
Loading the Pretrained Model
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Trainer,
    TrainingArguments,
)

# 🚀 Load the model
model_name = "Qwen/Qwen2-VL-7B-Instruct"  # or the smaller "Qwen/Qwen2-VL-2B-Instruct" (Qwen2-VL has no 3B variant)
print(f"🔄 Loading model: {model_name}")

processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bf16 has a wider numeric range than fp16; supported on RTX 4090 / A100
    device_map="auto",
)
print("✅ Model loaded")
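Configuring LoRA Fine-Tuning
# 🎛️ LoRA configuration: key parameters
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # causal language modeling task
    r=16,  # 🔧 rank of the low-rank matrices; controls parameter count and capacity
    lora_alpha=32,  # 🔧 scaling factor (the update is scaled by lora_alpha / r); often set to 2 × r
    lora_dropout=0.1,  # 🔧 dropout on the LoRA path to reduce overfitting
    target_modules=[  # 🎯 language-model modules to adapt
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",  # MLP layers
    ],
    bias="none",  # do not train bias parameters
    inference_mode=False,  # training mode
)

# 🔗 Wrap the model with LoRA adapters
model = get_peft_model(model, lora_config)

# 📊 Report the trainable parameters
trainable_params, all_params = model.get_nb_trainable_parameters()
print(f"🎯 Trainable parameters: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")
print(f"📊 Total parameters: {all_params:,}")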
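The comments on `r` and `lora_alpha` can be made concrete. Take a linear layer with weight W of shape d_out × d_in. LoRA trains two small matrices, B (d_out × r) and A (r × d_in), and adds (lora_alpha / r) · B A x to the layer output. Each adapted layer therefore costs r · (d_in + d_out) extra parameters. The dimensions below are illustrative placeholders, not Qwen2-VL's actual sizes; read those from the model's `config.json`. The toy forward pass shows why training starts from the base model's behavior: standard LoRA initializes B to zero.

```python
from typing import List

Matrix = List[List[float]]


def matvec(m: Matrix, x: List[float]) -> List[float]:
    """Plain matrix-vector product."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: List[float], alpha: float, r: int) -> List[float]:
    """y = W x + (alpha / r) * B (A x)."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Illustrative size: one 4096 x 4096 projection with r = 16
print(lora_extra_params(4096, 4096, 16), "trainable vs", 4096 * 4096, "frozen")

# Toy example with r = 1: a zero B leaves the base output unchanged
W = [[1.0, 0.0], [0.0, 2.0]]
A = [[0.5, 0.5]]
B_zero = [[0.0], [0.0]]
x = [1.0, 3.0]
print(lora_forward(W, A, B_zero, x, alpha=2.0, r=1))
```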
Data Preprocessing Pipeline
import json

import torch
from PIL import Image


class SafetyViolationProcessor:
    """Builds training batches on the fly; used as the Trainer's data_collator.

    Qwen2-VL's processor concatenates pixel_values across all images in a batch,
    so batches are built at training time instead of being cached with Dataset.map.
    """

    def __init__(self, processor):
        self.processor = processor

    def build_messages(self, conversations_str: str) -> list:
        """Convert a stored conversation into Qwen2-VL chat messages."""
        messages = []
        for conv in json.loads(conversations_str):
            text = conv["value"].replace("<image>", "").strip()
            if conv["from"] == "human":
                # Put the image in the user turn so the chat template inserts the vision tokens
                content = [{"type": "image"}, {"type": "text", "text": text}]
                messages.append({"role": "user", "content": content})
            else:
                messages.append({"role": "assistant", "content": [{"type": "text", "text": text}]})
        return messages

    def __call__(self, examples):
        """Turn a list of dataset rows into one model-ready batch."""
        images, texts = [], []
        for example in examples:
            # 🖼️ Load the image (fail loudly instead of training on a blank placeholder)
            images.append(Image.open(example["image_path"]).convert("RGB"))
            messages = self.build_messages(example["conversations"])
            texts.append(self.processor.apply_chat_template(messages, tokenize=False))

        # 🔧 No truncation: cutting the sequence could drop image tokens and break image/token alignment
        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        # 🏷️ Labels: copy input_ids and exclude padding from the loss
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels
        return batch


# 🛠️ Create the collator
data_processor = SafetyViolationProcessor(processor)

# 📦 The datasets keep their raw columns; the collator loads and tokenizes each batch
train_dataset = train_ds
eval_dataset = eval_ds
Training Configuration and Launch
# 🎛️ Training arguments
training_args = TrainingArguments(
    # 📁 Output
    output_dir="./qwen2-vl-safety-detector",
    run_name="safety-violation-detection",
    # 🔄 Training
    num_train_epochs=5,
    per_device_train_batch_size=1,  # per-GPU batch size (adjust to your VRAM)
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size = 1 × 8 × number of GPUs
    # 📈 Learning rate
    learning_rate=5e-5,
    lr_scheduler_type="cosine",  # cosine decay
    warmup_ratio=0.1,  # warm up over the first 10% of steps
    # 💾 Checkpointing
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,  # keep the 3 most recent checkpoints
    # 📊 Evaluation ("eval_strategy" replaces the deprecated "evaluation_strategy")
    eval_strategy="steps",
    eval_steps=50,
    # 📝 Logging
    logging_steps=10,
    logging_first_step=True,
    # 🚀 Performance
    bf16=True,  # bf16 mixed precision (Ampere or newer GPU)
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    remove_unused_columns=False,  # keep image_path/conversations for the collator
    # 🎯 Best-model selection
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # 📊 Experiment tracking (optional)
    report_to="none",  # set to "wandb" to use Weights & Biases
)


# 🎓 Custom trainer
class SafetyDetectionTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Next-token cross-entropy that ignores positions labeled -100.

        **kwargs absorbs extra arguments (such as num_items_in_batch) that newer
        transformers versions pass to compute_loss.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        if labels is not None:
            # 🎯 Shift so that position t predicts token t + 1
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        else:
            loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


# 🚀 Create the trainer
trainer = SafetyDetectionTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_processor,  # builds each batch: images + chat template + labels
)

# 🎯 Start training
print("🚀 Starting fine-tuning...")
print("=" * 50)
try:
    trainer.train()
    # 💾 Save the final model (LoRA adapter) and the processor
    trainer.save_model()
    processor.save_pretrained(training_args.output_dir)
    print("🎉 Training complete!")
    print(f"📁 Model saved to: {training_args.output_dir}")
except Exception as e:
    print(f"❌ Training failed: {e}")
    raise
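The preprocessing above sets `labels` to a copy of `input_ids`. The loss therefore also covers the prompt tokens (system prompt, question, image placeholders), not only the answer. A common refinement, which is an assumption here and not part of the code above, is to set the prompt positions to -100 as well, so that only the assistant's answer is learned. The custom `compute_loss` then shifts logits and labels by one, so position t is scored against token t + 1. The plain-Python sketch below shows both steps on toy token ids. It also computes the arithmetic behind `gradient_accumulation_steps`. The Trainer's own step count can differ by one depending on how it handles a final partial accumulation.

```python
import math
from typing import List, Tuple

IGNORE_INDEX = -100  # skipped by torch.nn.CrossEntropyLoss(ignore_index=-100)


def mask_prompt(input_ids: List[int], prompt_len: int, pad_id: int) -> List[int]:
    """Labels that keep only answer tokens: prompt and padding positions become -100."""
    return [
        IGNORE_INDEX if pos < prompt_len or tok == pad_id else tok
        for pos, tok in enumerate(input_ids)
    ]


def shifted_targets(labels: List[int]) -> List[Tuple[int, int]]:
    """(logit position, target token) pairs that still count after the one-token shift."""
    return [(pos, tgt) for pos, tgt in enumerate(labels[1:]) if tgt != IGNORE_INDEX]


def optimizer_steps_per_epoch(num_samples: int, per_device_batch: int,
                              grad_accum: int, num_gpus: int = 1) -> int:
    """Approximate optimizer steps per epoch for a given effective batch size."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    return math.ceil(num_samples / effective_batch)


ids = [11, 12, 13, 21, 22, 0]  # 3 prompt tokens, 2 answer tokens, 1 pad (toy pad_id = 0)
labels = mask_prompt(ids, prompt_len=3, pad_id=0)
print(labels)
print(shifted_targets(labels))  # the last prompt position predicts the first answer token
print(optimizer_steps_per_epoch(300, per_device_batch=1, grad_accum=8))
```

For example, with 300 training samples and the settings above, each epoch takes roughly 38 optimizer steps. `eval_steps=50` therefore first fires during the second epoch.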
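🔍 Model Inference and Testing
Loading the Trained Model
import torch
from peft import PeftModel
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


def load_finetuned_model(model_path: str, base_model_name: str = "Qwen/Qwen2-VL-7B-Instruct"):
    """Load the base model and attach the fine-tuned LoRA weights."""
    print(f"🔄 Loading fine-tuned model: {model_path}")
    # 🤖 Load the base model (must be the same one used for training)
    base_model = Qwen2VLForConditionalGeneration.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # 🔗 Load the LoRA weights
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()
    # 📝 Load the processor saved during training
    processor = AutoProcessor.from_pretrained(model_path)
    print("✅ Model loaded")
    return model, processor


# Load the model
finetuned_model, finetuned_processor = load_finetuned_model("./qwen2-vl-safety-detector")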
Implementing the Inference Function
from datetime import datetime
from typing import Dict

import torch
from PIL import Image


class SafetyViolationDetector:
    """Safety violation detector."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def detect(self, image_path: str, return_details=True) -> Dict:
        """Detect safety violations in one image."""
        try:
            # 🖼️ Load the image
            image = Image.open(image_path).convert('RGB')

            # 💭 Prompt (Chinese). It asks for: 1. violation description, 2. violated rule,
            # 3. risk level (high/medium/low), 4. corrective action, or
            # "未發現明顯違規行為" ("no obvious violation found") when the scene is safe.
            # Results are usually best when this stays close to the prompt used in training.
            prompt = """請仔細觀察這張制造業現場圖片,根據安全生產規章制度,識別其中的違規行為。請按以下格式回答:
1. 違規描述:[具體描述違規行為]
2. 違反規定:[引用具體的規章條款]
3. 風險等級:[高風險/中風險/低風險]
4. 改進建議:[具體的整改措施]
如果沒有發現違規,請說明"未發現明顯違規行為"。"""

            # 🔧 Build the chat messages
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }]

            # 📝 Apply the chat template (inserts the image placeholder tokens)
            text_input = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # 🔧 Tokenize the text and preprocess the image
            inputs = self.processor(
                text=[text_input], images=[image], return_tensors="pt", padding=True
            ).to(self.device)

            # 🎯 Generate the answer
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.1,
                    top_p=0.9,
                    repetition_penalty=1.1,
                )

            # 📤 Decode only the newly generated tokens
            generated_ids = outputs[0][inputs.input_ids.shape[1]:]
            response = self.processor.decode(generated_ids, skip_special_tokens=True)

            # 🔍 Parse the result
            if return_details:
                return self._parse_detection_result(response, image_path)
            return {"raw_response": response.strip()}
        except Exception as e:
            print(f"❌ Detection failed: {e}")
            return {"error": str(e)}

    def _parse_detection_result(self, response: str, image_path: str) -> Dict:
        """Wrap the raw response in a result record."""
        result = {
            "image_path": image_path,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "raw_response": response,
            "violations": [],
            "summary": {"total_violations": 0, "high_risk": 0, "medium_risk": 0, "low_risk": 0},
        }
        # 🔍 Simple keyword check (adapt it to the model's actual output format)
        if "未發現明顯違規" in response:  # "no obvious violation found"
            result["status"] = "safe"
        else:
            result["status"] = "violation_detected"
            # More detailed parsing (filling violations and summary) can go here
        return result


# 🎯 Create the detector
detector = SafetyViolationDetector(finetuned_model, finetuned_processor)
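`_parse_detection_result` leaves `violations` and `summary` empty. One way to fill them is a small line-based parser. It assumes the model answers in the same layout as the training targets from Step 2: a numbered description line followed by `違反規定:` (violated rule) and `風險等級:` (risk level) lines. This is a sketch. If the model drifts from that layout, for example by numbering every field as the inference prompt suggests, items will be misread, so check it against real outputs.

```python
import re
from typing import Dict, List

ITEM_RE = re.compile(r'^\s*(\d+)\.\s*(.+)$')
RULE_RE = re.compile(r'^\s*違反規定[::]\s*(.+)$')  # "violated rule"
RISK_RE = re.compile(r'^\s*風險等級[::]\s*(.+)$')  # "risk level"
RISK_KEYS = {"高風險": "high_risk", "中風險": "medium_risk", "低風險": "low_risk"}


def parse_violations(response: str) -> Dict:
    """Parse numbered violations and count them by risk level."""
    violations: List[Dict] = []
    for line in response.splitlines():
        rule, risk, item = RULE_RE.match(line), RISK_RE.match(line), ITEM_RE.match(line)
        if rule and violations:
            violations[-1]["rule_reference"] = rule.group(1).strip()
        elif risk and violations:
            violations[-1]["risk_level"] = risk.group(1).strip()
        elif item:
            violations.append({"description": item.group(2).strip(),
                               "rule_reference": "", "risk_level": ""})
    summary = {"total_violations": len(violations), "high_risk": 0, "medium_risk": 0, "low_risk": 0}
    for violation in violations:
        key = RISK_KEYS.get(violation["risk_level"])
        if key:
            summary[key] += 1
    return {"violations": violations, "summary": summary}


sample_response = (
    "根據制造業安全規章制度分析,發現以下問題:\n"
    "1. 作業人員未佩戴安全帽,頭部缺乏有效防護\n"
    "   違反規定:《安全作業規程》第3.1條:進入作業區域必須佩戴安全帽\n"
    "   風險等級:高風險\n"
    "2. 工作區域地面散落工具,存在絆倒風險\n"
    "   違反規定:《現場管理標準》第5.2條:作業現場應保持整潔有序\n"
    "   風險等級:中風險"
)
parsed = parse_violations(sample_response)
print(parsed["summary"])
```

Inside `_parse_detection_result`, the returned `violations` and `summary` could then replace the empty defaults.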
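Batch Testing
from typing import List


def batch_test_detector(test_images: List[str], detector: SafetyViolationDetector):
    """Run the detector over a list of images."""
    results = []
    print(f"🧪 Starting batch test on {len(test_images)} images")
    for i, image_path in enumerate(test_images, 1):
        print(f"📸 Processing {i}/{len(test_images)}: {image_path}")
        # 🔍 Run detection
        result = detector.detect(image_path)
        results.append(result)
        # 📊 Print a short summary
        if "error" in result:
            print("   ❌ Detection error")
        elif result.get("status") == "violation_detected":
            print("   ⚠️ Violation detected")
        else:
            print("   ✅ No violation found")
    return results


# 🧪 Test images
test_images = [
    "test_images/worker_no_helmet.jpg",
    "test_images/proper_operation.jpg",
    "test_images/messy_workplace.jpg",
]

# Run the batch test
test_results = batch_test_detector(test_images, detector)

# 📊 Summarize the results (errors are counted separately, not as "safe")
violation_count = sum(1 for r in test_results if r.get("status") == "violation_detected")
safe_count = sum(1 for r in test_results if r.get("status") == "safe")
error_count = sum(1 for r in test_results if "error" in r)
print("\n📈 Test finished!")
print(f"🔍 Images tested: {len(test_results)}")
print(f"⚠️ Violations found: {violation_count}")
print(f"✅ Safe images: {safe_count}")
print(f"❌ Errors: {error_count}")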
🚀 Deployment Options
Option 1: Flask API Service
import base64
import os
import uuid

from flask import Flask, jsonify, request

app = Flask(__name__)

# 🌐 Global detector, created once at startup
global_detector = None


def init_detector():
    """Initialize the detector (load_finetuned_model and SafetyViolationDetector are defined above)."""
    global global_detector
    print("🚀 Initializing the safety violation detector...")
    model, processor = load_finetuned_model("./qwen2-vl-safety-detector")
    global_detector = SafetyViolationDetector(model, processor)
    print("✅ Detector ready")


@app.route('/', methods=['GET'])
def home():
    """Home page with a simple upload form."""
    html_template = """<!DOCTYPE html>
<html>
<head>
  <title>🏭 Manufacturing Safety Inspection System</title>
  <meta charset="utf-8">
  <style>
    body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
    .container { background: #f5f5f5; padding: 20px; border-radius: 10px; }
    .result { margin-top: 20px; padding: 15px; background: white; border-radius: 5px; }
    .violation { border-left: 4px solid #ff4444; }
    .safe { border-left: 4px solid #44ff44; }
    input[type="file"] { margin: 10px 0; }
    button { background: #007cba; color: white; padding: 10px 20px; border: none; border-radius: 5px; cursor: pointer; }
    button:hover { background: #005a8a; }
  </style>
</head>
<body>
  <div class="container">
    <h1>🏭 Manufacturing Safety Inspection System</h1>
    <p>Upload a photo of the shop floor and the AI will identify safety violations automatically.</p>
    <form id="uploadForm" enctype="multipart/form-data">
      <input type="file" id="imageFile" accept="image/*" required>
      <button type="submit">🔍 Start detection</button>
    </form>
    <div id="result" style="display:none;"></div>
  </div>
  <script>
    document.getElementById('uploadForm').onsubmit = async function (e) {
      e.preventDefault();
      const file = document.getElementById('imageFile').files[0];
      if (!file) { alert('Please choose an image file'); return; }
      const resultDiv = document.getElementById('result');
      // Show a loading state
      resultDiv.innerHTML = '<p>🔄 Detecting, please wait...</p>';
      resultDiv.style.display = 'block';
      // Read the file as base64
      const reader = new FileReader();
      reader.onload = async function (e) {
        const base64 = e.target.result.split(',')[1];
        try {
          const response = await fetch('/detect', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ image: base64, filename: file.name })
          });
          const result = await response.json();
          if (result.success) {
            let html = '<h3>🔍 Detection result</h3>';
            if (result.data.status === 'violation_detected') {
              html += '<div class="result violation"><h4>⚠️ Safety violation found</h4>';
              html += '<pre>' + result.data.raw_response + '</pre></div>';
            } else {
              html += '<div class="result safe"><h4>✅ No violation found</h4>';
              html += '<p>Site safety conditions are good</p></div>';
            }
            resultDiv.innerHTML = html;
          } else {
            resultDiv.innerHTML = '<div class="result"><h4>❌ Detection failed</h4><p>' + result.error + '</p></div>';
          }
        } catch (error) {
          resultDiv.innerHTML = '<div class="result"><h4>❌ Network error</h4><p>' + error.message + '</p></div>';
        }
      };
      reader.readAsDataURL(file);
    };
  </script>
</body>
</html>"""
    return html_template


@app.route('/detect', methods=['POST'])
def detect_violations():
    """Detection endpoint: expects JSON {"image": <base64>, "filename": ...}."""
    data = request.get_json(silent=True)
    if not data or 'image' not in data:
        return jsonify({'success': False, 'error': 'Missing image data'}), 400

    # 💾 Write the decoded image to a temporary file
    os.makedirs("temp", exist_ok=True)
    temp_path = os.path.join("temp", f"temp_{uuid.uuid4().hex}.jpg")
    try:
        with open(temp_path, 'wb') as f:
            f.write(base64.b64decode(data['image']))

        # 🔍 Run detection
        result = global_detector.detect(temp_path)
        if "error" in result:
            return jsonify({'success': False, 'error': result['error']}), 500
        return jsonify({'success': True, 'data': result, 'message': 'Detection complete'})
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({'success': False, 'error': str(e)}), 500
    finally:
        # 🗑️ Clean up the temporary file
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.route('/health', methods=['GET'])
def health_check():
    """Health check."""
    return jsonify({'status': 'healthy', 'service': 'safety-violation-detector', 'version': '1.0.0'})


if __name__ == '__main__':
    # 🚀 Start the service
    init_detector()
    print("🌐 Starting the Flask service...")
    print("🔗 Open http://localhost:5000")
    app.run(
        host='0.0.0.0',
        port=5000,
        debug=False,  # keep False in production
        threaded=True,
    )
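A client only needs to base64-encode the image and POST the JSON body that the `/detect` route reads (`image` and `filename`). The stdlib sketch below assumes the Flask service is running on `localhost:5000` as configured above. `build_detect_payload` can be tested on its own; `post_detect` needs the live server. Note that `urlopen` raises `urllib.error.HTTPError` for non-2xx replies.

```python
import base64
import json
import urllib.request
from typing import Dict


def build_detect_payload(image_bytes: bytes, filename: str) -> bytes:
    """JSON body with the fields /detect expects: base64 image and filename."""
    body = {"image": base64.b64encode(image_bytes).decode("ascii"), "filename": filename}
    return json.dumps(body).encode("utf-8")


def post_detect(image_path: str, url: str = "http://localhost:5000/detect",
                timeout: float = 120.0) -> Dict:
    """Send one image to the running service and return the decoded JSON reply."""
    with open(image_path, "rb") as f:
        payload = build_detect_payload(f.read(), image_path)
    req = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(req, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage, with the server running:
# print(post_detect("test_images/worker_no_helmet.jpg"))
```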
Option 2: High-Performance FastAPI Service
import asyncio
import base64
import os
import uuid
from typing import Optional

import aiofiles
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel

# Assumes load_finetuned_model and SafetyViolationDetector (defined earlier) are defined in or imported into this main.py


# 📋 Data models
class DetectionRequest(BaseModel):
    image: str  # base64-encoded image
    options: dict = {}


class DetectionResponse(BaseModel):
    success: bool
    data: Optional[dict] = None
    error: Optional[str] = None


# 🚀 Create the FastAPI app
app = FastAPI(
    title="🏭 Manufacturing Safety Inspection API",
    description="Intelligent safety violation detection based on Qwen2-VL",
    version="1.0.0",
)

# 🌐 Global detector
detector_instance = None


@app.on_event("startup")
async def startup_event():
    """Initialize the detector when the app starts."""
    global detector_instance
    print("🚀 Initializing the detector...")
    os.makedirs("temp", exist_ok=True)
    model, processor = load_finetuned_model("./qwen2-vl-safety-detector")
    detector_instance = SafetyViolationDetector(model, processor)
    print("✅ Detector ready")


async def run_detection(image_bytes: bytes) -> dict:
    """Write the bytes to a temp file and run the blocking detector in a worker thread."""
    temp_path = os.path.join("temp", f"temp_{uuid.uuid4().hex}.jpg")
    try:
        async with aiofiles.open(temp_path, 'wb') as f:
            await f.write(image_bytes)
        # Run the blocking GPU call in the default thread pool so the event loop stays responsive
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, detector_instance.detect, temp_path)
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.post("/api/detect", response_model=DetectionResponse)
async def api_detect(request: DetectionRequest):
    """Async detection endpoint (JSON body with a base64 image)."""
    try:
        image_data = base64.b64decode(request.image)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")
    try:
        result = await run_detection(image_data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return DetectionResponse(success="error" not in result, data=result, error=result.get("error"))


@app.post("/api/upload", response_model=DetectionResponse)
async def upload_image(file: UploadFile = File(...)):
    """File-upload detection endpoint (multipart/form-data, requires python-multipart)."""
    # Validate outside the try block so the 400 is not converted into a 500
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="Only image files are supported")
    contents = await file.read()
    try:
        result = await run_detection(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return DetectionResponse(success="error" not in result, data=result, error=result.get("error"))


@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy", "service": "safety-violation-detector", "version": "1.0.0"}


if __name__ == "__main__":
    # 🚀 Start the service (this file is assumed to be saved as main.py)
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # the model is large, so a single process is recommended
        reload=False,
    )
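`run_in_executor(None, ...)` uses the default thread pool, so two simultaneous requests can call `generate` on the same GPU model at once. One option, which is an assumption and not part of the service above, is a dedicated single-worker `ThreadPoolExecutor`. GPU calls then queue up one at a time while the event loop keeps serving other requests such as `/health`. The sketch uses `time.sleep` as a stand-in for the blocking detector call and counts heartbeat ticks to show that the loop stays responsive.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

GPU_EXECUTOR = ThreadPoolExecutor(max_workers=1)  # at most one model call at a time


def fake_detect(image_path: str) -> str:
    """Stand-in for detector.detect: blocks its thread for a moment."""
    time.sleep(0.05)
    return f"result for {image_path}"


async def heartbeat(stop: asyncio.Event, ticks: List[int]) -> None:
    """Counts how often the event loop runs while detections are queued."""
    while not stop.is_set():
        ticks[0] += 1
        await asyncio.sleep(0.01)


async def serve_requests(paths: List[str]) -> Tuple[List[str], int]:
    loop = asyncio.get_running_loop()
    stop, ticks = asyncio.Event(), [0]
    beat = asyncio.create_task(heartbeat(stop, ticks))
    results = await asyncio.gather(
        *(loop.run_in_executor(GPU_EXECUTOR, fake_detect, p) for p in paths)
    )
    stop.set()
    await beat
    return list(results), ticks[0]


results, ticks = asyncio.run(serve_requests(["a.jpg", "b.jpg", "c.jpg"]))
print(results)
print(f"the event loop ran {ticks} heartbeat ticks while the calls were serialized")
```

In the service, this would mean passing such an executor instead of `None` to `run_in_executor`.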
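Option 3: Containerized Deployment with Docker
# Dockerfile
# CUDA image tags use the full version (11.8.0); Ubuntu 22.04 ships Python 3.10, matching the environment above
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# Avoid interactive prompts during apt installs
ENV DEBIAN_FRONTEND=noninteractive

# 🐍 Install Python and basic tools
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    python3-dev \
    wget \
    curl \
    && rm -rf /var/lib/apt/lists/*

# 📁 Set the working directory
WORKDIR /app

# 📋 Copy the dependency list
COPY requirements.txt .

# 📦 Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# 📁 Copy the application code
COPY . .

# 📂 Create the required directories
RUN mkdir -p temp logs

# 🚀 Expose the port used by the FastAPI service
EXPOSE 8000

# 🎯 Start command: runs the FastAPI service from Option 2, saved as main.py
CMD ["python3", "main.py"]

# docker-compose.yml
version: '3.8'

services:
  safety-detector:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models  # point load_finetuned_model here if the adapter is not copied into the image
      - ./temp:/app/temp
      - ./logs:/app/logs
    environment:
      - CUDA_VISIBLE_DEVICES=0
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
      - ./ssl:/etc/nginx/ssl
    depends_on:
      - safety-detector
    restart: unless-stopped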
📊 Performance Evaluation and Optimization
Evaluation Metrics
import json

import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


class ModelEvaluator:
    """Model evaluator."""

    def __init__(self, detector):
        self.detector = detector
        self.results = []

    def evaluate_on_testset(self, test_data_path: str):
        """Evaluate on a held-out test set in the same JSON format as the training data."""
        print("📊 Starting model evaluation...")
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        predictions, ground_truths = [], []
        for i, item in enumerate(test_data):
            print(f"📸 Processing test sample {i + 1}/{len(test_data)}")
            # Prediction
            result = self.detector.detect(item['image'])
            pred_status = result.get('status', 'unknown')
            predictions.append(pred_status)

            # Ground truth from the assistant answer
            true_answer = item['conversations'][1]['value']
            # Check the "no obvious violation" phrase first: it also contains "違規" (violation)
            if "未發現明顯違規" in true_answer:
                true_status = "safe"
            elif "違反" in true_answer or "違規" in true_answer:  # "violates" / "violation"
                true_status = "violation_detected"
            else:
                true_status = "safe"
            ground_truths.append(true_status)

            # Keep per-sample details
            self.results.append({
                'image': item['image'],
                'prediction': pred_status,
                'ground_truth': true_status,
                'prediction_text': result.get('raw_response', ''),
                'ground_truth_text': true_answer,
            })

        metrics = self._calculate_metrics(predictions, ground_truths)
        self._generate_report(metrics)
        return metrics

    def _calculate_metrics(self, predictions, ground_truths):
        """Binary metrics: violation (1) vs. safe (0)."""
        y_true = [1 if gt == "violation_detected" else 0 for gt in ground_truths]
        y_pred = [1 if pred == "violation_detected" else 0 for pred in predictions]

        accuracy = accuracy_score(y_true, y_pred)
        # zero_division=0 avoids warnings when a class is never predicted
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='binary', zero_division=0
        )
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'total_samples': len(y_true),
            'violation_samples': sum(y_true),
            'safe_samples': len(y_true) - sum(y_true),
        }

    def _generate_report(self, metrics):
        """Print the report and save per-sample results."""
        print("\n" + "=" * 50)
        print("📈 Model Performance Report")
        print("=" * 50)
        print(f"🎯 Accuracy: {metrics['accuracy']:.3f}")
        print(f"🔍 Precision: {metrics['precision']:.3f}")
        print(f"📊 Recall: {metrics['recall']:.3f}")
        print(f"⚖️ F1 score: {metrics['f1_score']:.3f}")
        print(f"📊 Total test samples: {metrics['total_samples']}")
        print(f"⚠️ Violation samples: {metrics['violation_samples']}")
        print(f"✅ Safe samples: {metrics['safe_samples']}")

        # utf-8-sig keeps the Chinese text readable when the CSV is opened in Excel
        pd.DataFrame(self.results).to_csv('evaluation_results.csv', index=False, encoding='utf-8-sig')
        print("📁 Detailed results saved to: evaluation_results.csv")


# 🧪 Run the evaluation
evaluator = ModelEvaluator(detector)
metrics = evaluator.evaluate_on_testset('test_dataset.json')
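For reference, the binary metrics that `precision_recall_fscore_support(..., average='binary')` reports come straight from the confusion counts, with "violation" as the positive class:

- precision = TP / (TP + FP)
- recall = TP / (TP + FN)
- F1 = 2PR / (P + R)

In a safety setting, recall usually matters most, because a false negative is a missed hazard. The stdlib version below mirrors those definitions and returns 0.0 wherever a denominator is zero.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Accuracy, precision, recall, and F1 with 1 = violation as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    correct = sum(1 for t, p in pairs if t == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(pairs), "precision": precision, "recall": recall, "f1": f1}


metrics_demo = binary_metrics([1, 1, 0, 0, 1], [1, 0, 0, 1, 1])
print(metrics_demo)
```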
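Performance Optimization Strategies
import hashlib
from typing import List

import torch
from peft import PeftModel
from PIL import Image
from transformers import BitsAndBytesConfig, Qwen2VLForConditionalGeneration


class ModelOptimizer:
    """Inference optimizations: quantization, batching, and caching."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor

    def optimize_inference(self, adapter_path: str, base_model_name: str = "Qwen/Qwen2-VL-7B-Instruct"):
        """🚀 1. Quantization: reload the base model in 4-bit with bitsandbytes (pip install bitsandbytes).

        torch.quantization.quantize_dynamic only provides CPU kernels, so a model quantized
        that way cannot be moved back to the GPU. This loads a second copy of the model;
        free the original first if VRAM is tight.
        """
        print("🔧 Loading a 4-bit quantized model...")
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                base_model_name, quantization_config=bnb_config, device_map="auto"
            )
            quantized_model = PeftModel.from_pretrained(base_model, adapter_path)
            quantized_model.eval()
            print("✅ Quantization done: quantized linear-layer weights need about a quarter of their fp16 memory")
            return quantized_model
        except Exception as e:
            print(f"⚠️ Quantization failed: {e}")
            return self.model

    def enable_batch_processing(self, batch_size=4):
        """2. Batching: return a function that generates for several images at once."""
        # Decoder-only generation needs left padding so every prompt ends at the same position
        self.processor.tokenizer.padding_side = "left"
        prompt = "請檢測圖片中的安全違規行為"  # "Detect the safety violations in the image"

        def batch_detect(image_paths: List[str]):
            """Batched detection."""
            results = []
            for i in range(0, len(image_paths), batch_size):
                batch_paths = image_paths[i:i + batch_size]
                batch_images = [Image.open(path).convert('RGB') for path in batch_paths]
                # Each prompt needs the image placeholder tokens, so build it with the chat template
                messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
                text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs = self.processor(
                    text=[text] * len(batch_images),
                    images=batch_images,
                    return_tensors="pt",
                    padding=True,
                ).to(self.model.device)

                with torch.no_grad():
                    outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)

                # With left padding all prompts share one length, so slice once
                prompt_len = inputs.input_ids.shape[1]
                for path, output in zip(batch_paths, outputs):
                    response = self.processor.decode(output[prompt_len:], skip_special_tokens=True)
                    results.append({'image_path': path, 'response': response.strip()})
            return results

        return batch_detect

    def setup_caching(self, detector):
        """3. Caching: reuse results for byte-identical images (unbounded in-memory dict)."""
        image_cache = {}

        def cached_detect(image_path: str):
            """Detection with a cache keyed on the file contents."""
            with open(image_path, 'rb') as f:
                image_hash = hashlib.md5(f.read()).hexdigest()
            if image_hash in image_cache:
                print(f"🎯 Cache hit: {image_path}")
                return image_cache[image_hash]
            result = detector.detect(image_path)
            if "error" not in result:  # do not cache failures
                image_cache[image_hash] = result
            return result

        return cached_detect


# 🚀 Apply the optimizations
optimizer = ModelOptimizer(finetuned_model, finetuned_processor)

# Quantization (reloads the base model in 4-bit and attaches the adapter)
optimized_model = optimizer.optimize_inference("./qwen2-vl-safety-detector")

# Batching
batch_detector = optimizer.enable_batch_processing(batch_size=2)

# Caching
cached_detector = optimizer.setup_caching(detector)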
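The cache in `setup_caching` is a plain dict, so on a long-running service it grows without limit. The bounded alternative sketched below uses `collections.OrderedDict`: once a capacity is reached, it evicts the least recently used entry. The capacity of 2 in the demo is only for illustration. In the service, the keys would be the MD5 content hashes computed above.

```python
from collections import OrderedDict
from typing import Any, Hashable, Optional


class LRUCache:
    """Least-recently-used cache with a fixed capacity."""

    def __init__(self, capacity: int):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self._data: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: Hashable, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # drop the least recently used entry

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data


cache = LRUCache(capacity=2)
cache.put("hash_a", {"status": "safe"})
cache.put("hash_b", {"status": "violation_detected"})
cache.get("hash_a")  # touch a, so b becomes the oldest entry
cache.put("hash_c", {"status": "safe"})
print("hash_b" in cache)
```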
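📈 Suggestions for Continuous Improvement
Data Augmentation Strategy
import os
import random

from PIL import Image, ImageEnhance, ImageFilter


class DataAugmentation:
    """Image augmenter for the training set."""

    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self.transforms = [
            self.brightness_adjustment,
            self.contrast_adjustment,
            self.gaussian_blur,
            self.random_rotation,
            self.random_crop,
        ]

    def brightness_adjustment(self, image, factor_range=(0.7, 1.3)):
        """Brightness adjustment."""
        return ImageEnhance.Brightness(image).enhance(self.rng.uniform(*factor_range))

    def contrast_adjustment(self, image, factor_range=(0.8, 1.2)):
        """Contrast adjustment."""
        return ImageEnhance.Contrast(image).enhance(self.rng.uniform(*factor_range))

    def gaussian_blur(self, image, radius_range=(0.5, 2.0)):
        """Gaussian blur."""
        return image.filter(ImageFilter.GaussianBlur(radius=self.rng.uniform(*radius_range)))

    def random_rotation(self, image, max_degrees=10):
        """Small random rotation."""
        return image.rotate(self.rng.uniform(-max_degrees, max_degrees), expand=False)

    def random_crop(self, image, keep_ratio=0.9):
        """Mild random crop; keep it gentle so the violation itself is not cropped out."""
        width, height = image.size
        new_w, new_h = int(width * keep_ratio), int(height * keep_ratio)
        left = self.rng.randint(0, width - new_w)
        top = self.rng.randint(0, height - new_h)
        return image.crop((left, top, left + new_w, top + new_h))

    def augment_dataset(self, original_dataset, augment_factor=3):
        """Return the original samples plus augment_factor augmented copies of each."""
        augmented_data = []
        for item in original_dataset:
            augmented_data.append(item)  # keep the original sample
            original_image = Image.open(item['image']).convert('RGB')
            directory, filename = os.path.split(item['image'])
            for i in range(augment_factor):
                transform = self.rng.choice(self.transforms)  # pick a random transform
                augmented_image = transform(original_image.copy())
                # Save next to the original, e.g. training_images/aug_0_xxx.jpg
                aug_path = os.path.join(directory, f"aug_{i}_{filename}")
                augmented_image.save(aug_path)
                # New sample: same conversation (label), new image
                aug_item = dict(item)
                aug_item['image'] = aug_path
                augmented_data.append(aug_item)
        return augmented_data


# 🔄 Apply data augmentation
augmenter = DataAugmentation(seed=42)
enhanced_dataset = augmenter.augment_dataset(training_dataset, augment_factor=2)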
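A caution when combining augmentation with the train/validation split from Step 3: if `enhanced_dataset` itself is split, augmented copies of the same photo can land on both sides. Validation loss then looks better than it should. A safer order is to split by source image first and augment only the training part. The sketch below recovers each copy's source image from the `aug_{i}_` prefix that `augment_dataset` puts in front of the file name. It is a hypothetical helper, and a source file that is itself named `aug_...` would confuse it.

```python
import os
import random
import re
from collections import defaultdict
from typing import Dict, List, Tuple


def source_key(image_path: str) -> str:
    """Map 'dir/aug_1_x.jpg' and the original 'dir/x.jpg' to the same key 'x.jpg'."""
    return re.sub(r"^aug_\d+_", "", os.path.basename(image_path))


def group_split(samples: List[Dict], test_size: float = 0.2,
                seed: int = 42) -> Tuple[List[Dict], List[Dict]]:
    """Split so that every copy of one source image ends up on the same side."""
    groups = defaultdict(list)
    for sample in samples:
        groups[source_key(sample["image"])].append(sample)
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)
    eval_keys = set(keys[:max(1, round(len(keys) * test_size))])
    train = [s for k in keys if k not in eval_keys for s in groups[k]]
    evaluation = [s for k in keys if k in eval_keys for s in groups[k]]
    return train, evaluation


demo = []
for name in ["p1.jpg", "p2.jpg", "p3.jpg", "p4.jpg", "p5.jpg"]:
    demo.append({"image": f"training_images/{name}"})
    demo += [{"image": f"training_images/aug_{i}_{name}"} for i in range(2)]
train_part, eval_part = group_split(demo)
print(len(train_part), len(eval_part))
```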
Active Learning Framework
from typing import List

import numpy as np


class ActiveLearning:
    """Active learning: find the images the model is unsure about so they get labeled first."""

    def __init__(self, model, processor, uncertainty_threshold=0.7, num_samples=10):
        self.model = model
        self.processor = processor
        # Reuses SafetyViolationDetector from the inference section
        self.detector = SafetyViolationDetector(model, processor)
        self.uncertainty_threshold = uncertainty_threshold
        self.num_samples = num_samples
        self.uncertain_samples = []

    def calculate_uncertainty(self, image_path: str):
        """Estimate uncertainty from the spread of several sampled predictions."""
        # Predictions only vary when generation samples; at temperature=0.1 the spread will be small
        predictions = []
        for _ in range(self.num_samples):
            result = self.detector.detect(image_path)
            # Simplified: score each response by confident vs. hedging keywords
            predictions.append(self._estimate_confidence(result.get('raw_response', '')))
        # Uncertainty = variance of the scores
        return float(np.var(predictions))

    def _estimate_confidence(self, response_text: str):
        """Rough confidence score: confident words minus hedging words (Chinese keywords)."""
        confidence_words = ['明顯', '清楚', '確實', '肯定']  # obvious, clear, indeed, certain
        uncertainty_words = ['可能', '似乎', '疑似', '不確定']  # possibly, seems, suspected, uncertain
        confidence_score = sum(word in response_text for word in confidence_words)
        uncertainty_score = sum(word in response_text for word in uncertainty_words)
        return confidence_score - uncertainty_score

    def identify_hard_samples(self, image_paths: List[str]):
        """Return images whose uncertainty exceeds the threshold, most uncertain first."""
        hard_samples = []
        for path in image_paths:
            uncertainty = self.calculate_uncertainty(path)
            if uncertainty > self.uncertainty_threshold:
                hard_samples.append({'image_path': path, 'uncertainty': uncertainty})
        hard_samples.sort(key=lambda x: x['uncertainty'], reverse=True)
        return hard_samples

    def suggest_annotation_priority(self, candidate_images: List[str], budget=50):
        """Suggest which images to label first, up to the labeling budget."""
        print(f"🎯 Analyzing {len(candidate_images)} candidate images...")
        hard_samples = self.identify_hard_samples(candidate_images)
        priority_samples = hard_samples[:budget]  # hardest samples first
        print(f"📋 Suggested images to label first ({len(priority_samples)}):")
        for i, sample in enumerate(priority_samples, 1):
            print(f"  {i}. {sample['image_path']} (uncertainty: {sample['uncertainty']:.3f})")
        return priority_samples


# 🎯 Active learning in practice
active_learner = ActiveLearning(finetuned_model, finetuned_processor)

# Find the hard samples that should be labeled next
candidate_images = ["unlabeled_1.jpg", "unlabeled_2.jpg"]  # ... plus the rest of your unlabeled images
priority_samples = active_learner.suggest_annotation_priority(candidate_images)
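The keyword-based confidence score is a rough proxy. Another option is to measure disagreement directly, assuming the detector's `status` field (`safe` / `violation_detected`) is what you care about. Run several sampled predictions and compute the vote entropy of their statuses. The entropy is 0 when every run agrees and 1 bit when a binary decision splits evenly. This only helps if generation is actually stochastic (`do_sample=True` with a temperature well above 0.1). The statuses below are made-up inputs.

```python
import math
from collections import Counter
from typing import Sequence


def vote_entropy(statuses: Sequence[str]) -> float:
    """Entropy (in bits) of the label distribution across repeated predictions."""
    if not statuses:
        raise ValueError("need at least one prediction")
    n = len(statuses)
    return sum(-p * math.log2(p) for p in (c / n for c in Counter(statuses).values()))


print(vote_entropy(["safe"] * 10))
print(vote_entropy(["safe"] * 5 + ["violation_detected"] * 5))
```

Images could then be ranked by this entropy instead of by the keyword-score variance.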
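🎉 Summary and Outlook
This article walked through building a complete safety inspection system for manufacturing. The main pieces are:
🏆 Key Results
- 📚 Knowledge extraction: automatically extract structured rules from regulation documents
- 🎯 Model fine-tuning: fine-tune Qwen2-VL efficiently with LoRA
- 🚀 System deployment: several deployment options suitable for production
- 📊 Performance optimization: improve efficiency with quantization, batching, and caching
- 🔄 Continuous improvement: data augmentation and active learning
💡 Key Advantages
- 💰 Cost-effective: compared with training from scratch, fine-tuning can cut costs by 90% or more
- ⚡ Fast deployment: going from data preparation to launch can take as little as 1 to 2 weeks
- 🎯 Highly customizable: fully adapted to the company's own regulations
- 📈 Continuous optimization: supports online learning and model iteration
🔮 Future Directions
- Multimodal fusion: combine text, images, video, and sensor data
- Real-time detection: analyze video streams and raise alerts in real time
- Edge computing: lightweight models that can run on mobile devices
- Federated learning: train collaboratively across factories while protecting data privacy
- Explainable AI: give detailed explanations for each detection decision
🙏 Acknowledgements: Thanks to Alibaba's Tongyi Qianwen (Qwen) team for their excellent open-source models, which give industrial AI applications a strong technical foundation.
Disclaimer: This document is intended for technical learning and research only. Before using it in a real production environment, test thoroughly and follow the relevant safety regulations.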