如何使用YOLOv5训练一个包含六个类别的可见光船舶目标检测数据集,并附上详细的训练代码。

数据集介绍

该数据集包含7000张图像,共有六个类别,每个图像已经以Pascal VOC XML格式进行了标注。具体信息如下:

 

  • 训练集:包含4900张图像(含XML标注)
  • 验证集:包含1400张图像(含XML标注)
  • 测试集:包含700张图像(含XML标注)

目标检测标签有六个类别:

 

  • 货船:Cargo Ship
  • 客船:Passenger Ship
  • 渔船:Fishing Boat
  • 油轮:Tanker
  • 军舰:Warship
  • 游艇:Yacht

数据集准备

假设你的数据集目录结构如下:

ship_detection/
├── train/
│   ├── images/
│   └── annotations/
├── valid/
│   ├── images/
│   └── annotations/
├── test/
│   ├── images/
│   └── annotations/
└── README.txt  # 数据说明

其中:

  • train/images/ 存放训练集的图像。
  • train/annotations/ 存放训练集的标注文件(.xml)。
  • valid/images/ 存放验证集的图像。
  • valid/annotations/ 存放验证集的标注文件(.xml)。
  • test/images/ 存放测试集的图像。
  • test/annotations/ 存放测试集的标注文件(.xml)。

将VOC格式转换为YOLO格式

 

由于数据集是以VOC XML格式标注的,你需要将其转换为YOLO TXT格式。可以使用一些工具库来完成这个转换,这里我们提供一个简单的Python脚本来完成转换:

import os
import xml.etree.ElementTree as ET
from pathlib import Path
import shutil

def convert_voc_to_yolo(voc_path, yolo_path):
    # 创建YOLO格式的目录结构
    Path(yolo_path).mkdir(parents=True, exist_ok=True)
    Path(os.path.join(yolo_path, "images")).mkdir(exist_ok=True)
    Path(os.path.join(yolo_path, "labels")).mkdir(exist_ok=True)

    # 复制图像文件到新的目录
    for img_file in os.listdir(os.path.join(voc_path, "images")):
        shutil.copy(os.path.join(voc_path, "images", img_file), os.path.join(yolo_path, "images"))

    # 转换XML标注文件为YOLO格式
    for xml_file in os.listdir(os.path.join(voc_path, "annotations")):
        tree = ET.parse(os.path.join(voc_path, "annotations", xml_file))
        root = tree.getroot()

        width = int(root.find("size").find("width").text)
        height = int(root.find("size").find("height").text)

        # 创建YOLO格式的标签文件
        label_file = os.path.join(yolo_path, "labels", xml_file.replace(".xml", ".txt"))
        
        with open(label_file, "w") as f:
            for obj in root.findall("object"):
                cls = obj.find("name").text.lower().strip()
                if cls not in ["cargo ship", "passenger ship", "fishing boat", "tanker", "warship", "yacht"]:
                    continue
                xmlbox = obj.find("bndbox")
                x_min = float(xmlbox.find("xmin").text)
                y_min = float(xmlbox.find("ymin").text)
                x_max = float(xmlbox.find("xmax").text)
                y_max = float(xmlbox.find("ymax").text)
                
                x_center = (x_min + x_max) / 2.0
                y_center = (y_min + y_max) / 2.0
                w = x_max - x_min
                h = y_max - y_min
                
                x_center /= width
                y_center /= height
                w /= width
                h /= height
                
                class_id = {"cargo ship": 0, "passenger ship": 1, "fishing boat": 2, "tanker": 3, "warship": 4, "yacht": 5}[cls]
                
                line = f"{class_id} {x_center} {y_center} {w} {h}\n"
                f.write(line)

if __name__ == "__main__":
    voc_train_path = "path/to/voc/train/"
    yolo_train_path = "path/to/yolo/train/"
    convert_voc_to_yolo(voc_train_path, yolo_train_path)
    
    voc_valid_path = "path/to/voc/valid/"
    yolo_valid_path = "path/to/yolo/valid/"
    convert_voc_to_yolo(voc_valid_path, yolo_valid_path)
    
    voc_test_path = "path/to/voc/test/"
    yolo_test_path = "path/to/yolo/test/"
    convert_voc_to_yolo(voc_test_path, yolo_test_path)

数据配置文件

创建一个data.yaml文件来指定数据集路径、类别数量等信息:

path: ../ship_detection/
train: yolo/train/images/
val: yolo/valid/images/
test: yolo/test/images/  # test set
nc: 6  # number of classes
names: ['cargo ship', 'passenger ship', 'fishing boat', 'tanker', 'warship', 'yacht']  # class names

模型训练

  1. 训练脚本
    • 使用YOLOv5提供的训练脚本进行训练。可以调整超参数以获得更好的性能。例如:
      python train.py --img 640 --batch 16 --epochs 300 --data ./data.yaml --cfg models/yolov5s.yaml --weights '' --name ship_detection_project --cache
    • 这里--epochs 300表示训练300个epoch。

训练脚本详细代码

下面是一个详细的训练脚本示例:

#!/bin/bash
# 训练YOLOv5模型
cd yolov5
python train.py \
    --img 640 \
    --batch 16 \
    --epochs 300 \
    --data ./data.yaml \
    --cfg models/yolov5s.yaml \
    --weights '' \
    --name ship_detection_project \
    --cache

模型评估

训练结束后,使用验证集评估模型性能:

python val.py --data ./data.yaml --weights runs/train/ship_detection_project/weights/best.pt --img 640 --batch 32

预测示例

下面是一个使用训练好的模型进行预测的Python脚本示例:

import torch
from PIL import Image
import numpy as np
import cv2

def load_model(weights_path):
    # 加载预训练模型
    model = torch.hub.load('ultralytics/yolov5', 'custom', path=weights_path, force_reload=True)
    return model

def detect_ships(model, image_path, save_dir='results'):
    # 加载图像
    img = Image.open(image_path)
    
    # 使用模型进行预测
    results = model(img)
    
    # 可视化结果
    results.show()
    
    # 保存结果
    results.save(save_dir=save_dir)

if __name__ == '__main__':
    weights_path = 'path/to/your/best.pt'  # 模型权重文件路径
    image_path = 'path/to/your/image.jpg'  # 测试图像路径
    
    # 加载模型
    model = load_model(weights_path)
    
    # 进行预测
    detect_ships(model, image_path)

完整的训练和预测流程

  1. 克隆YOLOv5仓库

    git clone https://github.com/ultralytics/yolov5.git
    cd yolov5
    pip install -r requirements.txt
  2. 转换VOC格式为YOLO格式

    python convert_voc_to_yolo.py
  3. 创建数据配置文件

    path: ../ship_detection/
    train: yolo/train/images/
    val: yolo/valid/images/
    test: yolo/test/images/  # test set
    nc: 6  # number of classes
    names: ['cargo ship', 'passenger ship', 'fishing boat', 'tanker', 'warship', 'yacht']  # class names
  4. 运行训练脚本

    bash train.sh
  5. 运行预测脚本

    python detect.py

注意事项

  • 数据集质量:确保数据集的质量,包括清晰度、标注准确性等。
  • 模型选择:可以选择更强大的模型版本(如YOLOv5l或YOLOv5x)以提高性能。
  • 超参数调整:根据实际情况调整超参数,如批量大小(--batch)、图像大小(--img)等。
  • 监控性能:训练过程中监控损失函数和mAP指标,确保模型收敛。

通过上述步骤,你可以使用YOLOv5来训练一个包含六个类别的可见光船舶目标检测数据集,并使用训练好的模型进行预测。

 

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐