Skip to content

使用 YOLOX 进行目标检测训练

开始

从这里下载代码:https://github.com/Megvii-BaseDetection/YOLOX

INFO

YOLOX 是一种目标检测模型,它是基于 YOLO 系列(You Only Look Once)的最新发展。YOLOX 模型具有快速、准确和高效的特点,在计算机视觉领域中被广泛应用。

相比于传统的目标检测模型,YOLOX 具有以下几个显著的特点:

  1. 高精度:YOLOX 采用了一种新的解耦策略,通过设计“头”和“骨干”网络结构的不同组合,可以实现高精度的目标检测。同时,YOLOX 还采用了一种自适应权重融合策略,可以平衡不同尺度的特征图之间的信息。
  2. 快速:YOLOX 具有高效的运行速度,可以实现实时目标检测。它通过优化网络结构和使用高效的推理算法,从而在不降低准确性的情况下提高了检测速度。
  3. 可扩展性:YOLOX 的设计具有很高的可扩展性,可以在不同的场景和任务中灵活应用。它可以适应不同尺度和大小的目标,同时还支持多种视觉任务,如目标检测、实例分割等。
  4. 开源和易用:YOLOX 是一个开源项目,提供了完整的代码和预训练模型,方便研究人员和开发者进行使用和扩展。

第一步:使用 labelme 进行图片标注

INFO

Labelme 是一个开源的图像标注工具,用于创建和编辑图像标注数据集。它是由麻省理工学院计算机科学与人工智能实验室(MIT CSAIL)开发的,旨在为计算机视觉研究人员和开发者提供一个简单易用的标注工具。

Labelme 的主要特点和功能包括:

  1. 图像标注:Labelme 允许用户在图像上绘制各种形状,如矩形、多边形、线条等,以标注出感兴趣的目标或区域。用户可以选择不同的标注工具和颜色,以便清晰可见地标注图像。
  2. 标注数据保存:Labelme 支持将标注数据保存为 JSON 文件格式,其中包含了标注的形状、位置和标签等信息。这些标注数据可以用于训练计算机视觉模型,如目标检测、图像分割等任务。
  3. 标注数据可视化:Labelme 可以加载和显示已标注的图像和标注数据,以便用户查看和验证标注结果。用户可以轻松地切换不同图像的显示,以及在图像上显示标注的结果。
  4. 标注数据编辑:Labelme 允许用户对已标注的数据进行编辑和修改。用户可以添加、删除或修改标注的形状和标签,以便进行纠正或完善标注结果。
  5. 标注数据集管理:Labelme 支持管理多个标注数据集,用户可以方便地加载和切换不同的数据集,以便进行比较和分析。

生成如下结构的文件:

labelme
- x.jpg
- x.json
- y.jpg
- y.json
- ...
labelme
- x.jpg
- x.json
- y.jpg
- y.json
- ...

第二步:数据集转换到 COCO 格式

INFO

COCO(Common Objects in Context)是一种通用的图像数据集格式,用于目标检测、图像分割和关键点检测等计算机视觉任务。COCO 数据集格式由微软研究院提出,并已成为计算机视觉领域中最常用的数据集之一。

COCO 数据集格式的主要特点和组成部分包括:

  1. 图像:COCO 数据集包含大量的图像,每个图像都有唯一的标识符(ID)。这些图像可以来自不同的来源,如网络、摄像头等。
  2. 标注:对于每个图像,COCO 数据集提供了对应的标注信息。标注信息包括目标的边界框(Bounding Box)位置、目标类别、图像分割掩码、关键点位置等。每个标注都有一个唯一的标识符(ID),并与对应的图像关联。
  3. 类别:COCO 数据集定义了一组常见的目标类别,如人、动物、车辆、家具等。每个类别都有一个唯一的标识符(ID)和一个易于理解的名称。
  4. 标注文件:COCO 数据集的标注信息以 JSON 文件的形式进行存储。标注文件包含了图像和标注的详细信息,可以通过解析 JSON 文件来获取图像和标注数据。

COCO 数据集格式的优点在于其丰富的标注信息和广泛的应用领域。通过 COCO 格式,研究人员和开发者可以方便地使用和共享大规模的图像数据集,从而促进计算机视觉算法的研究和发展。此外,COCO 数据集还提供了评估指标和基准结果,可以用于算法性能的比较和评估。

需要注意的是,COCO 数据集格式是一种通用的格式,并不限定于特定的数据集。因此,可以根据具体的任务和需求,使用 COCO 格式进行数据集的创建、标注和管理。

使用 tools/labelme2coco.pylabelme 格式转换成 coco 格式

py
#!/usr/bin/env python

import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid

import imgviz
import numpy as np

import labelme

from yolox.data import COCO_CLASSES

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("input_dir", help="input annotated directory")
    parser.add_argument("output_dir", help="output dataset directory")
    # parser.add_argument("--labels", help="labels file", required=True)
    parser.add_argument(
        "--noviz", help="no visualization", action="store_true"
    )
    args = parser.parse_args()

    if osp.exists(args.output_dir):
        print("Output directory already exists:", args.output_dir)
        sys.exit(1)
    os.makedirs(args.output_dir)
    os.makedirs(osp.join(args.output_dir, "JPEGImages"))
    if not args.noviz:
        os.makedirs(osp.join(args.output_dir, "Visualization"))
    print("Creating dataset:", args.output_dir)

    now = datetime.datetime.now()

    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[
            dict(
                url=None,
                id=0,
                name=None,
            )
        ],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    class_name_to_id = {}
    labels = list(COCO_CLASSES)
    labels.insert(0, "__ignore__")
    for i, line in enumerate(labels):
        class_id = i - 1  # starts with -1
        class_name = line.strip()
        if class_id == -1:
            assert class_name == "__ignore__"
            continue
        class_name_to_id[class_name] = class_id
        data["categories"].append(
            dict(
                supercategory=None,
                id=class_id,
                name=class_name,
            )
        )

    out_ann_file = osp.join(args.output_dir, "annotations.json")
    label_files = glob.glob(osp.join(args.input_dir, "*.json"))
    for image_id, filename in enumerate(label_files):
        print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]
        out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")

        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                height=img.shape[0],
                width=img.shape[1],
                date_captured=None,
                id=image_id,
            )
        )

        masks = {}  # for area
        segmentations = collections.defaultdict(list)  # for segmentation
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape_to_mask(
                img.shape[:2], points, shape_type
            )

            if group_id is None:
                group_id = uuid.uuid1()

            instance = (label, group_id)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            if shape_type == "rectangle":
                (x1, y1), (x2, y2) = points
                x1, x2 = sorted([x1, x2])
                y1, y2 = sorted([y1, y2])
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
            if shape_type == "circle":
                (x1, y1), (x2, y2) = points
                r = np.linalg.norm([x2 - x1, y2 - y1])
                # r(1-cos(a/2))<x, a=2*pi/N => N>pi/arccos(1-x/r)
                # x: tolerance of the gap between the arc and the line segment
                n_points_circle = max(int(np.pi / np.arccos(1 - 1 / r)), 12)
                i = np.arange(n_points_circle)
                x = x1 + r * np.sin(2 * np.pi / n_points_circle * i)
                y = y1 + r * np.cos(2 * np.pi / n_points_circle * i)
                points = np.stack((x, y), axis=1).flatten().tolist()
            else:
                points = np.asarray(points).flatten().tolist()

            segmentations[instance].append(points)
        segmentations = dict(segmentations)

        for instance, mask in masks.items():
            cls_name, group_id = instance
            if cls_name not in class_name_to_id:
                continue
            cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = pycocotools.mask.encode(mask)
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()

            data["annotations"].append(
                dict(
                    id=len(data["annotations"]),
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )

        if not args.noviz:
            viz = img
            if masks:
                labels, captions, masks = zip(
                    *[
                        (class_name_to_id[cnm], cnm, msk)
                        for (cnm, gid), msk in masks.items()
                        if cnm in class_name_to_id
                    ]
                )
                viz = imgviz.instances2rgb(
                    image=img,
                    labels=labels,
                    masks=masks,
                    captions=captions,
                    font_size=15,
                    line_width=2,
                )
            out_viz_file = osp.join(
                args.output_dir, "Visualization", base + ".jpg"
            )
            imgviz.io.imsave(out_viz_file, viz)

    with open(out_ann_file, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()
#!/usr/bin/env python

import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid

import imgviz
import numpy as np

import labelme

from yolox.data import COCO_CLASSES

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("input_dir", help="input annotated directory")
    parser.add_argument("output_dir", help="output dataset directory")
    # parser.add_argument("--labels", help="labels file", required=True)
    parser.add_argument(
        "--noviz", help="no visualization", action="store_true"
    )
    args = parser.parse_args()

    if osp.exists(args.output_dir):
        print("Output directory already exists:", args.output_dir)
        sys.exit(1)
    os.makedirs(args.output_dir)
    os.makedirs(osp.join(args.output_dir, "JPEGImages"))
    if not args.noviz:
        os.makedirs(osp.join(args.output_dir, "Visualization"))
    print("Creating dataset:", args.output_dir)

    now = datetime.datetime.now()

    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[
            dict(
                url=None,
                id=0,
                name=None,
            )
        ],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    class_name_to_id = {}
    labels = list(COCO_CLASSES)
    labels.insert(0, "__ignore__")
    for i, line in enumerate(labels):
        class_id = i - 1  # starts with -1
        class_name = line.strip()
        if class_id == -1:
            assert class_name == "__ignore__"
            continue
        class_name_to_id[class_name] = class_id
        data["categories"].append(
            dict(
                supercategory=None,
                id=class_id,
                name=class_name,
            )
        )

    out_ann_file = osp.join(args.output_dir, "annotations.json")
    label_files = glob.glob(osp.join(args.input_dir, "*.json"))
    for image_id, filename in enumerate(label_files):
        print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]
        out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")

        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                height=img.shape[0],
                width=img.shape[1],
                date_captured=None,
                id=image_id,
            )
        )

        masks = {}  # for area
        segmentations = collections.defaultdict(list)  # for segmentation
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape_to_mask(
                img.shape[:2], points, shape_type
            )

            if group_id is None:
                group_id = uuid.uuid1()

            instance = (label, group_id)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            if shape_type == "rectangle":
                (x1, y1), (x2, y2) = points
                x1, x2 = sorted([x1, x2])
                y1, y2 = sorted([y1, y2])
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
            if shape_type == "circle":
                (x1, y1), (x2, y2) = points
                r = np.linalg.norm([x2 - x1, y2 - y1])
                # r(1-cos(a/2))<x, a=2*pi/N => N>pi/arccos(1-x/r)
                # x: tolerance of the gap between the arc and the line segment
                n_points_circle = max(int(np.pi / np.arccos(1 - 1 / r)), 12)
                i = np.arange(n_points_circle)
                x = x1 + r * np.sin(2 * np.pi / n_points_circle * i)
                y = y1 + r * np.cos(2 * np.pi / n_points_circle * i)
                points = np.stack((x, y), axis=1).flatten().tolist()
            else:
                points = np.asarray(points).flatten().tolist()

            segmentations[instance].append(points)
        segmentations = dict(segmentations)

        for instance, mask in masks.items():
            cls_name, group_id = instance
            if cls_name not in class_name_to_id:
                continue
            cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = pycocotools.mask.encode(mask)
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()

            data["annotations"].append(
                dict(
                    id=len(data["annotations"]),
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )

        if not args.noviz:
            viz = img
            if masks:
                labels, captions, masks = zip(
                    *[
                        (class_name_to_id[cnm], cnm, msk)
                        for (cnm, gid), msk in masks.items()
                        if cnm in class_name_to_id
                    ]
                )
                viz = imgviz.instances2rgb(
                    image=img,
                    labels=labels,
                    masks=masks,
                    captions=captions,
                    font_size=15,
                    line_width=2,
                )
            out_viz_file = osp.join(
                args.output_dir, "Visualization", base + ".jpg"
            )
            imgviz.io.imsave(out_viz_file, viz)

    with open(out_ann_file, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()
coco_cata
- JPEGImages
  - x.jpg
  - y.jpg
- Visualization
  - x.jpg
  - y.jpg
- annotations.json
coco_cata
- JPEGImages
  - x.jpg
  - y.jpg
- Visualization
  - x.jpg
  - y.jpg
- annotations.json

最后整理成如下格式:

coco
- annotations
  - annotations.json
- train2017
  - JPEGImages
    - x.jpg
    - y.jpg
  - Visualization
    - x.jpg
    - y.jpg
- val2017
  - JPEGImages
    - x.jpg
    - y.jpg
  - Visualization
    - x.jpg
    - y.jpg
coco
- annotations
  - annotations.json
- train2017
  - JPEGImages
    - x.jpg
    - y.jpg
  - Visualization
    - x.jpg
    - y.jpg
- val2017
  - JPEGImages
    - x.jpg
    - y.jpg
  - Visualization
    - x.jpg
    - y.jpg

第三步:修改 coco classes

yolox/data/datasets/coco_classes.py

py
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

COCO_CLASSES = (
    'xxx',
    'yyy',
)
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

COCO_CLASSES = (
    'xxx',
    'yyy',
)

第四步:修改模型训练参数

exps/example/custom/yolox_s.py

py
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os

from yolox.data import COCO_CLASSES
from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.depth = 0.33
        self.width = 0.50
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # Define yourself dataset path
        self.data_dir = r"D:\Work\ai\YOLOX\datasets\coco"
        self.train_ann = "annotations.json"
        self.val_ann = "annotations.json"

        self.num_classes = len(COCO_CLASSES)

        self.max_epoch = 100
        self.data_num_workers = 4
        self.eval_interval = 10
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os

from yolox.data import COCO_CLASSES
from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.depth = 0.33
        self.width = 0.50
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # Define yourself dataset path
        self.data_dir = r"D:\Work\ai\YOLOX\datasets\coco"
        self.train_ann = "annotations.json"
        self.val_ann = "annotations.json"

        self.num_classes = len(COCO_CLASSES)

        self.max_epoch = 100
        self.data_num_workers = 4
        self.eval_interval = 10

第五步:训练模型

使用 tools/train.py 执行训练。

最后:In memory of Dr. Jian Sun

Without the guidance of Dr. Jian Sun, YOLOX would not have been released and open sourced to the community. The passing away of Dr. Jian is a huge loss to the Computer Vision field. We add this section here to express our remembrance and condolences to our captain Dr. Jian. It is hoped that every AI practitioner in the world will stick to the concept of "continuous innovation to expand cognitive boundaries, and extraordinary technology to achieve product value" and move forward all the way.

没有孙剑博士的指导,YOLOX 也不会问世并开源给社区使用。 孙剑博士的离去是 CV 领域的一大损失,我们在此特别添加了这个部分来表达对我们的“船长”孙老师的纪念和哀思。 希望世界上的每个 AI 从业者秉持着“持续创新拓展认知边界,非凡科技成就产品价值”的观念,一路向前。

最后编辑时间:

Version 4.0 (framework-1.0.0-rc.20)