YOLOv5 + SE Attention: Improving Object Detection Performance in Practice

1. Introduction

Object detection is a core task in computer vision, with wide applications in autonomous driving, security surveillance, industrial inspection, and more. YOLOv5, one of the most widely used models in the YOLO family, performs remarkably well in practice thanks to its efficiency and accuracy. However, as application scenarios become more complex, plain convolutional networks can hit a performance ceiling when dealing with cluttered backgrounds and multi-scale objects. Introducing an attention mechanism is an effective way to address this. This article describes in detail how to add the SE (Squeeze-and-Excitation) attention mechanism to YOLOv5: modifying the model configuration file and the code, improving model performance, and comparing training results before and after.

Compared with earlier versions in the YOLO series, YOLOv5 improves the model structure, training strategy, and data augmentation, which noticeably raises both performance and efficiency. Its main features include:

  • Model structure optimization: YOLOv5 uses an updated backbone and path-aggregation neck, improving feature extraction and feature fusion.
  • Data augmentation: methods such as Mosaic and MixUp improve the model's generalization ability.
  • Training strategy improvements: techniques such as automatic anchor fitting (AutoAnchor) and an expanded positive-sample assignment scheme improve training efficiency and detection accuracy.

However, as task complexity increases, plain convolutional networks do not handle multi-scale objects particularly well. Introducing the SE attention mechanism offers a new way to improve detection accuracy.

2. YOLOv5 and the SE Attention Mechanism

2.1 Overview of YOLOv5

Thanks to its efficiency and accuracy, YOLOv5 is widely used for object detection. Its main structural components are:

  • Backbone: extracts features from the input image.
  • Neck: fuses features from different levels to improve multi-scale perception.
  • Head: produces the detection predictions from the fused features.

2.2 Overview of the SE attention mechanism

SE (Squeeze-and-Excitation) is a lightweight attention module that explicitly models inter-channel dependencies to strengthen the network's representational power. It consists of two key steps:

  • Squeeze: global average pooling collapses the spatial dimensions of each feature map to 1 x 1, producing a per-channel descriptor.
  • Excitation: two fully connected layers followed by a Sigmoid activation generate per-channel weights, which are used to recalibrate the channel responses of the feature map.

By adding the SE module, YOLOv5 can attend more strongly to informative feature channels and suppress less useful ones, which improves overall performance.
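
In formula form (following the original SE paper by Hu et al.; H and W are the spatial dimensions, C the number of channels, and r the channel reduction ratio, set to 16 in the implementation in section 3.2):

z_c = \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)    (Squeeze)
s = \sigma\big(W_2 \, \delta(W_1 z)\big), \quad W_1 \in \mathbb{R}^{(C/r) \times C}, \; W_2 \in \mathbb{R}^{C \times (C/r)}    (Excitation)
\tilde{x}_c = s_c \cdot x_c    (channel-wise rescaling)

where \delta is the ReLU activation and \sigma is the Sigmoid function.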

3. Implementing YOLOv5 + SE Attention

3.1 Modifying the model configuration file

To introduce the SE attention mechanism into YOLOv5, three files need to be modified: common.py, yolo.py, and the model configuration file. Based on yolov5s.yaml, a new configuration file (yolov5_se.yaml) is created with the SE module inserted into the Backbone (optional positions in the head are left commented out). Note that once the SE module is inserted, the layer indices referenced later in the file must be renumbered accordingly. The SE module can also be added at other positions, for example in the head before the P3 output. The modified configuration file is shown below:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SENet, [1024]],  # 9 SEAttention
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   # [-1, 1, SENet, [1024]],  # optional SEAttention before the P3 output
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   # [-1, 1, SENet, [1024]],  # optional SEAttention
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   # [-1, 1, SENet, [1024]],  # optional SEAttention
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)
   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
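
Before training, the modified configuration can be sanity-checked by building the model once; yolo.py prints the parsed layer table, which makes it easy to confirm that the SENet layer sits at index 9 and that the renumbered references (for example the Detect layer's [18, 21, 24]) are consistent. The command below assumes the stock YOLOv5 repository layout:

python models/yolo.py --cfg models/yolov5_se.yaml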

3.2 Implementing the SE attention module

The SE module must be implemented in the YOLOv5 code base, typically in models/common.py. The SENet implementation is shown below:

import torch
import torch.nn as nn


class SENet(nn.Module):
    # Squeeze-and-Excitation attention block; c1 is the number of input channels,
    # the remaining arguments keep the signature compatible with parse_model
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super(SENet, self).__init__()
        # Squeeze: global average pooling -> c1 x 1 x 1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two fully connected layers with reduction ratio 16
        self.l1 = nn.Linear(c1, c1 // 16, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // 16, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)   # squeeze to shape (b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)                  # per-channel weights in (0, 1)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)        # rescale each input channel by its weight
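
As a quick stand-alone sanity check (a minimal sketch with arbitrary tensor sizes, not from the original post), the module should preserve the input shape and only rescale the channels:

import torch

se = SENet(256, 256)                 # c1 = c2 = 256 channels
x = torch.randn(2, 256, 40, 40)      # a batch of 2 feature maps
y = se(x)
print(y.shape)                       # torch.Size([2, 256, 40, 40]), same shape as the input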

3.3 Using the SE attention module

For the Backbone and Neck to actually build the SE module, the parse_model function in yolo.py must be extended so that it recognizes SENet. The modified parse_model is shown below:

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, SENet}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            # Modules that take a repeat count n as their third argument; the extra attention
            # classes listed here (CBAMBottleneck, CABottleneck, CBAMC3, CANet, CAC3, CBAM,
            # ECANet, GAMNet) are other attention variants from the author's common.py --
            # remove any that you have not defined
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x, CBAMBottleneck, CABottleneck, CBAMC3, SENet, CANet,
                     CAC3, CBAM, ECANet, GAMNet}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
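
One detail worth checking, assumed from the stock repository layout rather than stated in the original post: parse_model resolves module names with eval(m), so the SENet class must be visible inside models/yolo.py. If SENet was added to models/common.py as in section 3.2, the existing wildcard import at the top of yolo.py (from models.common import *) already covers it; if it lives in a separate file, add an explicit import, for example:

from models.se_attention import SENet   # hypothetical module path, adjust to wherever SENet is defined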

3.4 Model training and comparison of results

After modifying the model configuration file and the code, the model can be trained. The COCO dataset or a custom dataset is recommended for training and validation. Here I trained for 100 epochs on a custom dataset, camel_elephant_training, which contains only two classes: camel and elephant.

After training, the AP (average precision) metric can be used to compare model performance before and after adding the SE attention mechanism. In general, with the SE module added, YOLOv5 performs better on cluttered backgrounds and multi-scale objects.

The results after training: since time was limited I trained for only 100 epochs, whereas 150 to 200 epochs would normally be used; judging from the train/obj_loss curve, the loss still has room to decrease.

3.5 Training steps

  1. Set up the training environment and make sure YOLOv5 and its dependencies are installed.
  2. Download the COCO dataset or prepare a custom dataset for training.
  3. Modify the training script (or its arguments) to load the modified model configuration file yolov5_se.yaml.
  4. Start training and monitor the loss and accuracy during training (see the example commands after this list).
  5. After training, evaluate the model on the validation set.
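
For reference, the commands below show one way to run training and validation with the modified configuration. The dataset yaml name camel_elephant.yaml and the run directory are placeholders for the custom dataset above, not paths from the original post:

python train.py --img 640 --batch-size 16 --epochs 100 --data data/camel_elephant.yaml --cfg models/yolov5_se.yaml --weights yolov5s.pt
python val.py --data data/camel_elephant.yaml --weights runs/train/exp/weights/best.pt --img 640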

3.6 Model deployment

The trained weights can be converted to .onnx format with export.py and then deployed on almost any platform. The ONNX-related portion of the standard export.py is shown below; the full script ships with the YOLOv5 repository.

import argparse
import contextlib
import json
import os
import platform
import re
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_mode

MACOS = platform.system() == 'Darwin'  # macOS environment


@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    check_requirements('onnx')
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
    if dynamic:
        dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
        if isinstance(model, SegmentationModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
            dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
        elif isinstance(model, DetectionModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic or None)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx


# The remaining contents of export.py (the @try_export decorator, the exporters for TorchScript,
# OpenVINO, TensorRT, CoreML, the TensorFlow family, PaddlePaddle, Edge TPU and TF.js, and the
# run() / parse_opt() / main() driver code) are unchanged from the upstream Ultralytics file and
# are omitted here for brevity.
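
A minimal export command and a quick ONNX Runtime smoke test are sketched below. The weight path, the two-class output shape, and the onnxruntime usage are assumptions based on the custom dataset above, not part of the original post:

python export.py --weights runs/train/exp/weights/best.pt --include onnx --opset 12

import numpy as np
import onnxruntime as ort

# load the exported model and run one dummy image through it
sess = ort.InferenceSession('runs/train/exp/weights/best.onnx', providers=['CPUExecutionProvider'])
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)   # NCHW input, normalised to [0, 1] in real use
pred = sess.run(None, {'images': dummy})[0]
print(pred.shape)  # expected (1, 25200, 7): 25200 candidate boxes x (x, y, w, h, obj, 2 class scores)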

4. Conclusion

This article showed how to add the SE attention mechanism to YOLOv5: modifying the model configuration file, implementing the module, training, and comparing results. With the SE module, YOLOv5's detection accuracy on multi-scale objects and cluttered backgrounds improves. In the future, other attention mechanisms such as CBAM and ECA could be explored to further improve YOLOv5's performance. Thank you for reading.
