如何增强小目标检测?SAHI与Yolov8完美结合

机器学习图像处理算法

picture.image

向AI转型的程序员都关注公众号 机器学习AI算法工程

在本文中,你将学习如何使用切片辅助超推理 (SAHI) 检测数据集中的小物体。我们将介绍以下内容:

  • 为什么很难检测小物体

  • SAHI 的工作原理

  • 如何将 SAHI 应用于你的数据集,以及

  • 如何评估这些预测的质量

picture.image

为什么检测小物体很难?

它们很小

首先,检测小物体很难,因为小物体很小。物体越小,检测模型需要处理的信息就越少。如果汽车在远处,它可能只占据我们图像中的几个像素。就像人类难以辨别远处的物体一样,我们的模型在没有车轮和车牌等视觉可辨别特征的情况下更难识别汽车!

训练数据

模型的好坏取决于训练它们的数据。大多数标准物体检测数据集和基准都集中在中到大型物体上,这意味着大多数现成的物体检测模型并未针对小物体检测进行优化。

固定输入大小

物体检测模型通常采用固定大小的输入。例如,YOLOv8 在最大边长为 640 像素的图像上进行训练。这意味着当我们输入一张 1920x1080 大小的图像时,模型会在进行预测之前将图像下采样到 640x360,降低分辨率并丢弃小物体的重要信息。

SAHI 的工作原理

picture.image

理论上,你可以在较大的图像上训练模型以改进对小物体的检测。但实际上,这需要更多的内存、更多的计算能力和更耗费人力的数据集。

另一种方法是利用现有的物体检测,将模型应用于图像中固定大小的块或切片,然后将结果拼接在一起。这就是切片辅助超推理(Slicing-Aided Hyper Inference)背后的想法!

SAHI 的工作原理是将图像分成完全覆盖它的切片,并使用指定的检测模型对每个切片进行推理。然后将所有这些切片的预测合并在一起,以生成整个图像的一个检测列表。SAHI 中的“超”来自这样一个事实,即 SAHI 的输出不是模型推理的结果,而是涉及多个模型推理的计算的结果。

💡SAHI 切片可以重叠(如上图 GIF 所示),这有助于确保至少一个切片中有足够的物体可供检测。

使用 SAHI 的主要优势在于它与模型无关。SAHI 可以利用当今的 SOTA 对象检测模型以及明天的 SOTA 模型!

当然,天下没有免费的午餐。作为“超推理”的交换,除了将结果拼接在一起所需的处理之外,你还要运行检测模型的多次前向传递。

设置 SAHI

为了说明如何应用 SAHI 来检测小物体,我们将使用天津大学机器学习与数据挖掘实验室的 AISKYEYE 团队的 VisDrone 检测数据集。该数据集包含 8629 张图像,边长从 360 像素到 2000 像素不等。Ultralytics 的 YOLOv8l 将作为我们的基础物体检测模型。

我们将使用以下库:

  • fiftyone 用于数据集管理和可视化

  • huggingface_hub 用于从 Hugging Face Hub 加载 VisDrone 数据集

  • ultralytics 用于使用 YOLOv8 运行推理,以及

  • sahi 用于在图像切片上运行推理

如果你还没有安装这些库的最新版本,请安装。你需要 fiftyone>=0.23.8 才能从 Hugging Face Hub 加载 VisDrone:


          
              

            pip install -U fiftyone sahi ultralytics huggingface\_hub --quiet
          
        

然后只需要导入相关的模块:


            
import fiftyone as fo
            
import fiftyone.zoo as foz
            
import fiftyone.utils.huggingface as fouh
            
from fiftyone import ViewField as F
        

就这样,我们就可以加载数据了!我们将使用 FiftyOne 的 Hugging Face 实用程序中的 load_from_hub() 函数,通过其 repo_id 直接从 Hugging Face Hub 加载 VisDrone 数据集的一部分。

为了演示并尽可能快地执行代码,我们将仅从数据集中获取前 100 张图像。我们还将为这个正在创建的新数据集命名为“sahi-test”:


            
dataset = fouh.load_from_hub(
            
    "Voxel51/VisDrone2019-DET", 
            
    name="sahi-test", 
            
    max_samples=100
            
)
        

在添加任何预测之前,让我们先看看 FiftyOne 应用程序中的数据集:


          
              

            session = fo.launch\_app(dataset)
          
        

picture.image

与 YOLOv8 的标准接口

在下一节中,我们将使用 SAHI 对我们的数据进行超推理。在引入 SAHI 之前,让我们使用 Ultralytics 的 YOLOv8 模型的大型变体对我们的数据进行标准对象检测推理。

首先,我们创建一个 ultralytics.YOLO 模型实例,并在必要时下载模型检查点。然后,我们将此模型应用于我们的数据集,并将结果存储在样本的字段“base_model”中:


            
from ultralytics import YOLO
            

            
ckpt_path = "yolov8l.pt"
            
model = YOLO(ckpt_path)
            

            
dataset.apply_model(model, label_field="base_model")
            
session.view = dataset.view()
        

picture.image

通过查看模型的预测和真实值标签,我们可以看到一些事情。首先,我们的 YOLOv8l 模型检测到的类别与 VisDrone 数据集中的真实值类别不同。我们的 YOLO 模型是在 COCO 数据集上训练的,该数据集有 80 个类别,而 VisDrone 数据集有 12 个类别,包括一个 ignore_regions 类。

为了简化比较,我们将只关注数据集中最常见的几个类别,并将 VisDrone 类映射到 COCO 类,如下所示:


            
mapping = {"pedestrians": "person", "people": "person", "van": "car"}
            
mapped_view = dataset.map_labels("ground_truth", mapping)
        

然后过滤标签,仅包含我们感兴趣的类别:


            
def get_label_fields(sample_collection):
            
    """Get the (detection) label fields of a Dataset or DatasetView."""
            
    label_fields = list(
            
        sample_collection.get_field_schema(embedded_doc_type=fo.Detections).keys()
            
    )
            
    return label_fields
            

            
def filter_all_labels(sample_collection):
            
    label_fields = get_label_fields(sample_collection)
            

            
    filtered_view = sample_collection
            

            
    for lf in label_fields:
            
        filtered_view = filtered_view.filter_labels(
            
            lf, F("label").is_in(["person", "car", "truck"]), only_matches=False
            
        )
            
    return filtered_view
            

            
filtered_view = filter_all_labels(mapped_view)
            
session.view = filtered_view.view()
        

picture.image

现在我们有了基本模型预测,让我们使用 SAHI 来对图像进行切片和切块💪。

使用 SAHI 进行超推理

SAHI 技术在我们之前安装的 sahi Python 包中实现。SAHI 是一个与许多对象检测模型兼容的框架,包括 YOLOv8。我们可以选择想要使用的检测模型,并创建任何 sahi.models.DetectionModel 子类的实例,包括 YOLOv8、YOLOv5 甚至 Hugging Face Transformers 模型。

我们将使用 SAHI 的 AutoDetectionModel 类创建模型对象,并指定模型类型和检查点文件的路径:


            
from sahi import AutoDetectionModel
            
from sahi.predict import get_prediction, get_sliced_prediction
            

            
detection_model = AutoDetectionModel.from_pretrained(
            
    model_type='yolov8',
            
    model_path=ckpt_path,
            
    confidence_threshold=0.25, ## same as the default value for our base model
            
    image_size=640,
            
    device="cpu", # or 'cuda' if you have access to GPU
            
)
        

在生成切片预测之前,让我们使用 SAHI 的 get_prediction() 函数检查模型在试验图像上的预测:


            
result = get_prediction(dataset.first().filepath, detection_model)
            
print(result)
        

输出是:


          
              

            <sahi.prediction.PredictionResult object at 0x2b0e9c250>
          
        

幸运的是,SAHI 结果对象有一个 to_fiftyone_detections() 方法,它将结果转换为 FiftyOne Detection 对象列表:


          
              

            print(result.to\_fiftyone\_detections())
          
        

picture.image

这样我们就可以专注于数据,而不是繁琐的格式转换细节。SAHI 的 get_sliced_prediction() 函数的工作方式与 get_prediction() 相同,但增加了一些超参数,让我们可以配置图像的切片方式。具体来说,我们可以指定切片高度和宽度,以及切片之间的重叠。以下是示例:


            
sliced_result = get_sliced_prediction(
            
    dataset.skip(40).first().filepath,
            
    detection_model,
            
    slice_height = 320,
            
    slice_width = 320,
            
    overlap_height_ratio = 0.2,
            
    overlap_width_ratio = 0.2,
            
)
        

作为初步检查,我们可以将切片预测中的检测数量与原始预测中的检测数量进行比较:


            
num_sliced_dets = len(sliced_result.to_fiftyone_detections())
            
num_orig_dets = len(result.to_fiftyone_detections())
            

            
print(f"Detections predicted without slicing: {num_orig_dets}")
            
print(f"Detections predicted with slicing: {num_sliced_dets}")
            

            
Detections predicted without slicing: 17
            
Detections predicted with slicing: 73
        

我们可以看到预测数量大幅增加!我们尚未确定这些额外的预测是否有效,或者我们是否只是有更多的误报。我们将很快使用 FiftyOne 的评估 API 来做到这一点。我们还想为我们的切片找到一组好的超参数。我们需要将 SAHI 应用于整个数据集来完成所有这些事情。现在就开始吧!

为了简化流程,我们将定义一个函数,将预测添加到指定标签字段中的样本,然后我们将遍历数据集,将该函数应用于每个样本。此函数将样本的文件路径和切片超参数传递给 get_sliced_prediction(),然后将预测添加到指定标签字段中的样本:


            
def predict_with_slicing(sample, label_field, **kwargs):
            
    result = get_sliced_prediction(
            
        sample.filepath, detection_model, verbose=0, **kwargs
            
    )
            
    sample[label_field] = fo.Detections(detections=result.to_fiftyone_detections())
        

我们将切片重叠固定为 0.2,并观察切片高度和宽度如何影响预测的质量:


            
kwargs = {"overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
            

            
for sample in dataset.iter_samples(progress=True, autosave=True):
            
    predict_with_slicing(sample, label_field="small_slices", slice_height=320, slice_width=320, **kwargs)
            
    predict_with_slicing(sample, label_field="large_slices", slice_height=480, slice_width=480, **kwargs)
        

请注意,这些推理时间比原始推理时间长得多。这是因为我们在每个图像的多个切片上运行模型,这增加了模型必须进行的正向传递次数。我们正在做出权衡以改善对小物体的检测。

现在让我们再次过滤标签,以挑选出我们感兴趣的类别,并在 FiftyOne 应用程序中可视化结果:


            
filtered_view = filter_all_labels(mapped_view)
            
session = fo.launch_app(filtered_view, auto=False)
        

picture.image

结果看起来确实很有希望!从几个视觉示例来看,切片似乎可以提高地面实况检测的覆盖率,尤其是较小的切片似乎可以捕获更多的人像检测。但我们如何才能确定呢?让我们运行一个评估程序,将检测标记为真阳性、假阳性或假阴性,以将切片预测与真实值进行比较。我们将使用过滤视图的evaluate_detections()方法。

评估 SAHI 的预测

继续使用数据集的过滤视图,让我们运行一个评估程序,将每个预测标签字段的预测与真实标签进行比较。在这里,我们使用默认的 IoU 阈值 0.5,但你可以根据需要进行调整:


            
base_results = filtered_view.evaluate_detections("base_model", gt_field="ground_truth", eval_key="eval_base_model")
            
large_slice_results = filtered_view.evaluate_detections("large_slices", gt_field="ground_truth", eval_key="eval_large_slices")
            
small_slice_results = filtered_view.evaluate_detections("small_slices", gt_field="ground_truth", eval_key="eval_small_slices")
        

逐个打印出来看:


            
print("Base model results:")
            
base_results.print_report()
            

            
print("-" * 50)
            
print("Large slice results:")
            
large_slice_results.print_report()
            

            
print("-" * 50)
            
print("Small slice results:")
            
small_slice_results.print_report()
        

picture.image

我们可以看到,随着我们引入更多切片,误报的数量会增加,而误报的数量会减少。这是意料之中的,因为模型能够使用更多切片检测更多物体,但也会犯更多错误!您可以应用更积极的置信度阈值来对抗误报的增加,但即使不这样做,F1 分数也会显著提高。

让我们更深入地研究这些结果。我们之前提到,该模型在处理小物体时会遇到困难,所以让我们看看这三种方法在小于 32x32 像素的物体上的表现如何。我们可以使用 FiftyOne 的 ViewField 执行此过滤:


            
## Filtering for only small boxes
            

            
box_width, box_height = F("bounding_box")[2], F("bounding_box")[3]
            
rel_bbox_area = box_width * box_height
            

            
im_width, im_height = F("$metadata.width"), F("$metadata.height")
            
abs_area = rel_bbox_area * im_width * im_height
            

            
small_boxes_view = filtered_view
            
for lf in get_label_fields(filtered_view):
            
    small_boxes_view = small_boxes_view.filter_labels(lf, abs_area < 32**2, only_matches=False)
            

            
session.view = small_boxes_view.view()
        

picture.image

如果我们根据这些视图评估我们的模型并像以前一样打印报告,我们可以清楚地看到 SAHI 提供的价值!使用 SAHI 时,对于小物体的召回率要高得多,而精度没有明显下降,从而提高了 F1 分数。这对于人体检测尤其明显,其中 F1 分数增加了三倍!


            
## Evaluating on only small boxes
            
small_boxes_base_results = small_boxes_view.evaluate_detections("base_model", gt_field="ground_truth", eval_key="eval_small_boxes_base_model")
            
small_boxes_large_slice_results = small_boxes_view.evaluate_detections("large_slices", gt_field="ground_truth", eval_key="eval_small_boxes_large_slices")
            
small_boxes_small_slice_results = small_boxes_view.evaluate_detections("small_slices", gt_field="ground_truth", eval_key="eval_small_boxes_small_slices")
            

            
## Printing reports
            
print("Small Box — Base model results:")
            
small_boxes_base_results.print_report()
            

            
print("-" * 50)
            
print("Small Box — Large slice results:")
            
small_boxes_large_slice_results.print_report()
            

            
print("-" * 50)
            
print("Small Box — Small slice results:")
            
small_boxes_small_slice_results.print_report()
        

picture.image

机器学习算法AI大数据技术

搜索公众号添加: datanlp

picture.image

长按图片,识别二维码

阅读过本文的人还看了以下文章:

实时语义分割ENet算法,提取书本/票据边缘

整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主

《大语言模型》PDF下载

动手学深度学习-(李沐)PyTorch版本

YOLOv9电动车头盔佩戴检测,详细讲解模型训练

TensorFlow 2.0深度学习案例实战

基于40万表格数据集TableBank,用MaskRCNN做表格检测

《基于深度学习的自然语言处理》中/英PDF

Deep Learning 中文版初版-周志华团队

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

《深度学习:基于Keras的Python实践》PDF和代码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

搜索公众号添加: datayx

picture.image

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论