# COCO数据集DIY训练

下面是对COCO数据集的简要介绍：

```

COCO（Common Objects in Context）数据集是一个广泛使用的计算机视觉数据集，主要用于物体检测、分割、关键点检测等任务。由微软研究院开发，COCO数据集的目标是提供一个具有高质量标注的丰富数据集，以推动计算机视觉技术的发展。

COCO 数据集的主要特点：
数据内容：

图像数量：包含超过 33 万张图像。
类别：涵盖 80 个物体类别，例如人、车、动物、家具等。
标注：每张图像中都有物体的标注，支持物体检测、分割（实例分割）和关键点检测（如人体姿态估计）等任务。
数据类型：

物体检测：标注了物体的边界框（bounding box）。
实例分割：每个物体实例的像素级分割掩码（mask）。
关键点检测：人体关键点的位置标注。
场景描述：每张图像还包含描述图像内容的文本信息。
数据分布：

训练集：包含约 11.8 万张图像。
验证集：包含约 5,000 张图像。
测试集：包含约 20,000 张图像（用于竞赛和评估）。
数据集结构：

图像数据：存储在图像文件夹中。
标注数据：存储在 JSON 格式的文件中，包括物体边界框、分割掩码、关键点等信息。
使用场景：
物体检测：识别图像中的物体，并确定其位置。
实例分割：对图像中的每个物体进行像素级别的分割。
人体姿态估计：检测图像中人体的关键点位置。
场景理解：分析图像内容并生成描述。
COCO 数据集在计算机视觉领域被广泛使用，因为它提供了大量且多样化的标注数据，有助于训练和评估不同的视觉模型。
```

COCO数据集总体约20G，整个数据集数据量大，标签多，能够做到对多达80个标签进行检测，但是相应的，检测的准确度有所下降，为了满足用户在不同场景下对模型识别种类和识别精度的取舍权衡，Petoi实现了对COCO数据集中任意数量的标签进行数据集重新制作。

下面将讲解如何制作属于您自己的COCO数据集。

## COCO数据集下载

将下列代码复制为复制为coco\_download.py，脚本位于COCO数据集的目标位父目录。

```python
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm
import argparse

def download_url(url, dir):
    if not os.path.exists(dir):
        os.makedirs(dir)
    response = requests.get(url, stream=True)
    file_size = int(response.headers.get('content-length', 0))
    filename = os.path.join(dir, url.split('/')[-1])
    
    with open(filename, 'wb') as f, tqdm(
        desc=filename,
        total=file_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                bar.update(len(chunk))

def extract_zip(file_path, extract_to):
    with ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

def main(dataset_path):
    images_path = os.path.join(dataset_path, 'images')
    labels_url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels.zip'
    data_urls = [
        'http://images.cocodataset.org/zips/train2017.zip',
        'http://images.cocodataset.org/zips/val2017.zip',
        'http://images.cocodataset.org/zips/test2017.zip'
    ]

    # Download and extract labels
    download_url(labels_url, dataset_path)
    extract_zip(os.path.join(dataset_path, 'coco2017labels.zip'), dataset_path)
    
    # Download and extract images
    if not os.path.exists(images_path):
        os.makedirs(images_path)

    for url in data_urls:
        zip_name = url.split('/')[-1]
        zip_path = os.path.join(images_path, zip_name)
        download_url(url, images_path)
        extract_zip(zip_path, images_path)
        os.remove(zip_path)  # Clean up zip file after extraction

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Download and extract COCO dataset.')

    main(".")
```

执行：

`python .\coco_download.py`

如果您在下载的过程中发现由于网络不稳定或者其他原因没有办法完整下载COCO数据集，那么请手动下载COCO数据集并将压缩包解压到相应位置。

<figure><img src="https://201656985-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MQ6a951Q6Jn1Zzt5Ajr-3369173170%2Fuploads%2Fv6arWdH2cNHiW82k8P1H%2Fc34609d980a59aae1fd8eb49754d3099.png?alt=media&#x26;token=954a2760-10d7-4efe-bee9-45d79d6dfbc8" alt=""><figcaption><p>直接在浏览器输入url网址即可下载压缩包</p></figcaption></figure>

保证最后的数据集格式为：

<figure><img src="https://201656985-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MQ6a951Q6Jn1Zzt5Ajr-3369173170%2Fuploads%2FM1Ssc2n1F59ArZIE57OI%2Fimage.png?alt=media&#x26;token=31674e90-350f-4fa5-9657-3c7a42669813" alt=""><figcaption></figcaption></figure>

使用下面的脚本对COCO数据集进行提取，可以命名为coco\_remake.py

脚本运行前请将src\_coco\_path修改为COCO2017的路径，将dst\_coco\_path修改为DIY\_COCO数据集的路径。将src\_yaml\_file修改为COCO2017官方YAML文件的路径，将dst\_yaml\_file修改为DIY\_COCO的YAML文件路径。如果您的数据集是手动下载的，您还需要手动下载COCO.yaml。下载链接如下：

<https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml>

```python
import os
import shutil
import yaml

src_coco_path = "E:\Project\yolov5\datasets\coco"
dst_coco_path = "E:\Project\yolov8_coco_simple\coco_04"
src_yaml_file ="E:\Project\yolov5\data\coco.yaml"
dst_yaml_file ="E:\Project\yolov8_coco_simple\coco_03_02_03.yaml"


def get_key_from_value(dictionary, value):
    for key, val in dictionary.items():
        if(val-int(value))==0:
            return key
           

def get_value_from_key(dictionary, key_item):
    for key, val in dictionary.items():
        if key == key_item:
            return val
    return None  # 如果没有找到匹配的值

def load_coco_labels(yaml_file):
    """
    Load COCO labels from a YAML file.
    :param yaml_file: Path to the YAML file
    :return: Dictionary mapping class names to their indices
    """
    with open(yaml_file, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)
    return {v: k for k, v in data['names'].items()}

def get_class_indices(classes, labels_mapping):
    """
    Get indices for specific class names based on the labels mapping.
    :param classes: Set of class names to find indices for
    :param labels_mapping: Dictionary mapping class names to their indices
    :return: List of indices corresponding to the class names
    """
    return [labels_mapping[cls] for cls in classes if cls in labels_mapping]

def load_value_mapping(yaml_file):
    """
    从 YAML 文件中加载标签映射
    :param yaml_file: YAML 文件路径
    :return: 标签 ID 集合
    """
    with open(yaml_file, 'r') as file:
        data = yaml.safe_load(file)
    # 提取标签 ID
    labels = set(data['names'].values())
    return labels

def modify_list(lst):
    #print(lst)
    if not lst:
        return lst  # 如果列表为空，直接返回空列表
    
    first_element = lst[0]  # 提取第一个元素
    #print(first_element)
    dst_first_element_label = get_key_from_value(src_labels_values,first_element)
    #print(dst_first_element_label)
    dst_first_element       = get_value_from_key(dst_labels_values,dst_first_element_label)
    #print(dst_first_element)
    remaining_elements = lst[1:]  # 剩下的元素
    modified_list = []
    #print(">>>>>>>>>>")
    #print(len(remaining_elements))

    for i in range(0, (len(remaining_elements)//4)*4, 4):

        #print(remaining_elements[i])
        modified_list.append(dst_first_element)  # 添加第一个元素
        modified_list.append(' ')
        modified_list.append(remaining_elements[i])
        modified_list.append(' ') 
        modified_list.append(remaining_elements[i+1])  
        modified_list.append(' ')
        modified_list.append(remaining_elements[i+2])  
        modified_list.append(' ')
        modified_list.append(remaining_elements[i+3]) 
        modified_list.append('\n')  # 添加换行符
    
    return modified_list


if __name__ == '__main__':
    values = load_value_mapping(dst_yaml_file)
    #print(values)
    src_labels_values = load_coco_labels(src_yaml_file)
    #print(src_labels_values)
    dst_labels_values = load_coco_labels(dst_yaml_file)
    #print(dst_labels_values)

    gt_labels = get_class_indices(values,src_labels_values)

    #print(gt_labels)
    src_images_dir = os.path.join(src_coco_path, "images")
    src_labels_dir = os.path.join(src_coco_path, "labels")
    src_images_train_dir = os.path.join(src_images_dir, "train2017")
    src_images_val_dir = os.path.join(src_images_dir, "val2017")
    src_labels_train_dir = os.path.join(src_labels_dir, "train2017")
    src_labels_val_dir = os.path.join(src_labels_dir, "val2017")

    dst_images_dir = os.path.join(dst_coco_path, "images")
    dst_labels_dir = os.path.join(dst_coco_path, "labels")
    dst_images_train_dir = os.path.join(dst_images_dir, "train")
    dst_images_val_dir = os.path.join(dst_images_dir, "val")
    dst_labels_train_dir = os.path.join(dst_labels_dir, "train")
    dst_labels_val_dir = os.path.join(dst_labels_dir, "val")
    
    os.makedirs(dst_images_train_dir, exist_ok=True)
    os.makedirs(dst_images_val_dir, exist_ok=True)
    os.makedirs(dst_labels_train_dir, exist_ok=True)
    os.makedirs(dst_labels_val_dir, exist_ok=True)


    #print(src_labels_train_dir)
    for txt_file in os.listdir(src_labels_train_dir):
        if txt_file.endswith(".txt"):
            src_labels_train_file_path = os.path.join(src_labels_train_dir, txt_file)
            src_images_train_file_path = os.path.join(src_images_train_dir, txt_file.replace(".txt", ".jpg"))
            with open(src_labels_train_file_path, 'r') as f:
                    print(src_labels_train_file_path)
                    lines = f.readlines()
                    temp_lines=[]
                    temp_line=[]
                    for line in lines:
                        label_id = int(line.strip().split()[0])
                        if label_id in gt_labels:
                            temp_line=modify_list(line.strip().split())
                            temp_lines+=temp_line
                    #print(temp_lines)

                    if temp_lines:
                        print(temp_lines)
                        dst_labels_train_file_path = os.path.join(dst_labels_train_dir, txt_file)
                        dst_images_train_file_path = os.path.join(dst_images_train_dir, txt_file.replace(".txt", ".jpg"))
                        with open(dst_labels_train_file_path, 'w') as f_2:
                            for item in temp_lines:
                                f_2.write(f"{item}")
                        shutil.copy(src_images_train_file_path, dst_images_train_file_path)
        
    #print(src_labels_val_dir)
    for txt_file in os.listdir(src_labels_val_dir):
        if txt_file.endswith(".txt"):
            src_labels_val_file_path = os.path.join(src_labels_val_dir, txt_file)
            src_images_val_file_path = os.path.join(src_images_val_dir, txt_file.replace(".txt", ".jpg"))
            with open(src_labels_val_file_path, 'r') as f:
                    print(src_labels_val_file_path)
                    lines = f.readlines()
                    temp_lines=[]
                    temp_line=[]
                    for line in lines:
                        label_id = int(line.strip().split()[0])
                        if label_id in gt_labels:
                            temp_line=modify_list(line.strip().split())
                            temp_lines+=temp_line
                    #print(temp_lines)

                    if temp_lines:
                        print(temp_lines)
                        dst_labels_val_file_path = os.path.join(dst_labels_val_dir, txt_file)
                        dst_images_val_file_path = os.path.join(dst_images_val_dir, txt_file.replace(".txt", ".jpg"))
                        with open(dst_labels_val_file_path, 'w') as f_2:
                            for item in temp_lines:
                                f_2.write(f"{item}")
                        shutil.copy(src_images_val_file_path, dst_images_val_file_path)

```

您可以将DIY\_COCO的YAML文件写成如下样式，其中标签需要选取COCO数据集里本来就存在的标签。但是标签的顺序不必与COCO数据集相同，只需要从0开始计数即可。

```yaml

train: E:\Project\yolov8_coco_simple\coco_04\images\train
val: E:\Project\yolov8_coco_simple\coco_04\images\val

# Classes
names:
  0: bicycle
```

该YAML文件包括了train和val两个目录的路径以及标签顺序和对应标签。对于本地训练，请使用绝对路径。对于云端训练，需要将该YAML文件修改为如下格式：

<figure><img src="https://201656985-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MQ6a951Q6Jn1Zzt5Ajr-3369173170%2Fuploads%2FU3wEHRZGmCFoKly2l6yi%2Fimage.png?alt=media&#x26;token=37756066-2c9e-461a-8a73-972b70167f34" alt=""><figcaption></figcaption></figure>

并且，在云端训练时，一定要遵从下面这种数据集的组织形式：

<figure><img src="https://201656985-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MQ6a951Q6Jn1Zzt5Ajr-3369173170%2Fuploads%2FEPljbfC4ernqtc3u4KrT%2Fimage.png?alt=media&#x26;token=02443f34-9736-42f7-bc85-8e70b1286a1e" alt=""><figcaption></figcaption></figure>

再结合模型训练部分的步骤，您就可以训练出属于您自己的基于COCO部分数据集的模型啦！
