使用 Rust 和 FasterR-CNN 進行目標檢測
FasterR-CNN 是目標檢測領域廣泛使用的深度學習模型。Rust 生態中可以通過 tch-rs
(Torch 綁定)調用預訓練的 PyTorch 模型實現。以下為完整實現步驟:
環境準備
安裝 Rust 和必要的依賴:
cargo add tch
cargo add anyhow # 錯誤處理
下載預訓練的 FasterR-CNN 模型(需 PyTorch 格式 .pt
文件),或使用 TorchScript 格式模型。示例中使用 fasterrcnn_resnet50_fpn
。
加載預訓練模型
use tch::{nn, Device, Tensor, Kind};fn load_model(model_path: &str) -> anyhow::Result<nn::Module> {let device = Device::cuda_if_available();let model = nn::Module::load(model_path, device)?;Ok(model)
}
圖像預處理
將輸入圖像轉換為模型需要的格式(歸一化 + 標準化):
use tch::vision::image;fn preprocess_image(img_path: &str) -> anyhow::Result<Tensor> {let image = image::load(img_path)?;let resized = image.resize(800, 800); // FasterR-CNN 典型輸入尺寸let tensor = resized.to_kind(Kind::Float) / 255.0;let mean = Tensor::of_slice(&[0.485, 0.456, 0.406]).view([3, 1, 1]);let std = Tensor::of_slice(&[0.229, 0.224, 0.225]).view([3, 1, 1]);Ok((tensor - mean) / std)
}
運行推理
執行目標檢測并獲取結果:
fn run_detection(model: &nn::Module, input_tensor: &Tensor) -> anyhow::Result<(Tensor, Tensor)> {let output = model.forward_ts(&[input_tensor.unsqueeze(0)])?;let boxes = output.get(0).unwrap();let scores = output.get(1).unwrap();Ok((boxes, scores))
}
后處理與可視化
過濾低置信度檢測結果并繪制邊框:
use tch::IndexOp;fn filter_results(bboxes: &Tensor, scores: &Tensor, threshold: f64) -> Vec<(Vec<f64>, f64)> {let mut detections = Vec::new();for i in 0..scores.size()[0] {if scores.double_value(&[i]) > threshold {let bbox = bboxes.i(i).to_kind(Kind::Double).to_vec::<f64>().unwrap();detections.push((bbox, scores.double_value(&[i])));}}detections
}
使用 imageproc
或 opencv-rust
繪制檢測框(需額外安裝依賴)。
完整流程示例
fn main() -> anyhow::Result<()> {let model = load_model("fasterrcnn.pt")?;let input = preprocess_image("input.jpg")?;let (bboxes, scores) = run_detection(&model, &input)?;let detections = filter_results(&bboxes, &scores, 0.7);for (bbox, score) in detections {println!("Detected: {:?} with score {:.2}", bbox, score);}Ok(())
}
注意事項
- 模型需提前轉換為 TorchScript 格式(通過 Python 的
torch.jit.script
) - GPU 加速需配置 CUDA 環境
- 輸入圖像尺寸應與模型訓練時一致
- COCO 數據集的類別標簽需單獨加載
Rust 生態的計算機視覺庫(如 cv
)可進一步簡化圖像操作,但 tch-rs
目前是調用 PyTorch 模型的最成熟方案。
Polars 支持各種文件格式
Polars 支持各種文件格式、包括 CSV、Parquet 和 JSON
use polars::prelude::*;fn main() -> Result<()> {// Create a DataFrame with 4 names, ages, and citieslet df = df!["name" => &["周杰倫", "力辣", "張慧費", "王菲"],"age" => &[55, 60, 70, 67],"city" => &["New York", "Los Angeles", "Chicago", "San Francisco"]]?;// Display the DataFrameprintln!("{:?}", df);Ok(())
}
集成Polars和Pyo3構建
在Rust中集成Polars(數據框庫)和Pyo3(Python綁定)構建Web服務,可以通過以下方法實現:
創建基礎Rust項目
使用Cargo初始化新項目,添加必要的依賴。Cargo.toml
需要包含以下依賴項:
[dependencies]
actix-web = "4" # Web框架
polars = { version = "0.28", features = ["lazy"] } # 數據處理
pyo3 = { version = "0.18", features = ["extension-module"] } # Python集成
tokio = { version = "1", features = ["full"] } # 異步運行時