Skip to content

MMdetection 多光谱数据集加载

使用 MMdetection 3.x 框架加载双光谱数据集的配置

注意:下文的部分配置写法不够合理,建议直接参考仓库中的代码:https://github.com/echoniuniu/DSOD

这里使用MMdetection 3.x (3.10)

1. 配置文件数据集定义

dataset_type = 'MultispectralDataset' 
# Using the train split as an example: switch the dataset type, discard the
# inherited data_prefix (``_delete_=True`` replaces it instead of merging),
# and give one path prefix per modality.
train_dataloader = dict(
    dataset=dict(
        type=dataset_type,
        # NOTE(review): data_root is assumed to be defined earlier in the config.
        data_root=data_root,
        ann_file='Annotations/rgb_train.json',
        data_prefix=dict(
            _delete_=True,
            img1='train/rgb/',
            img2='train/ir/',
            # ...  The key names carry no meaning, but their order does:
            # modalities are collected (and later stacked) in this order.
        ),
        # NOTE(review): train_pipeline is assumed to be defined earlier.
        pipeline=train_pipeline,
    )
)

2. 添加多光谱数据集

添加自定义数据集

在 mmdet/datasets 新建multispectral_detection.py 文件内容如下

这里仅对父类 CocoDataset 的 parse_data_info() 做了少量改动:向 data_info 中加入 img_path_m 字段,用于记录各个模态的图像路径;同时把第一个模态(img1)的路径作为原始的 img_path,以兼容 LoadImageFromFile。

如果需要继续添加具体的自定义数据集,继承 MultispectralDataset 并修改或补充 METAINFO 即可(参见下方的 YouSelfMldata 示例)。

import os.path as osp
from typing import List, Union

from mmdet.registry import DATASETS
from .coco import CocoDataset


@DATASETS.register_module()
class MultispectralDataset(CocoDataset):
    """COCO-format detection dataset whose samples have several spectral images.

    Every ``data_prefix`` entry except ``'seg'`` is treated as one image
    modality, e.g. ``data_prefix=dict(img1='train/rgb/', img2='train/ir/')``.
    All modalities are looked up with the same ``file_name`` taken from the
    annotation file. The key names are arbitrary; their order is preserved in
    ``data_info['img_path_m']``.
    """

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Modified from ``mmdet.datasets.coco.CocoDataset.parse_data_info``.
        In addition to the usual fields it stores ``img_path_m``, a dict that
        maps each modality key to its image path. It also sets ``img_path``
        to the first modality's path, so single-image loaders such as
        ``LoadImageFromFile`` remain usable.

        Args:
            raw_data_info (dict): Raw data information load from ``ann_file``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.

        Raises:
            ValueError: If ``data_prefix`` contains no image modality.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']
        file_name = img_info['file_name']

        # One path per modality. The 'seg' prefix points at segmentation maps,
        # not at a spectral image, so it must not be stacked as channels.
        img_path_m: dict = {
            key: osp.join(prefix, file_name)
            for key, prefix in self.data_prefix.items() if key != 'seg'
        }
        if not img_path_m:
            raise ValueError(
                'MultispectralDataset needs at least one image modality in '
                f'`data_prefix`, but got keys: {list(self.data_prefix)}')
        # The first modality doubles as the single-image path.
        img_path = next(iter(img_path_m.values()))

        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                file_name.rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None

        data_info = {
            'img_path_m': img_path_m,  # paths of all modalities
            'img_path': img_path,  # kept for single-modality loaders
            'img_id': img_info['img_id'],
            'seg_map_path': seg_map_path,
            'height': img_info['height'],
            'width': img_info['width'],
        }

        if self.return_classes:
            data_info['text'] = self.metainfo['classes']
            data_info['custom_entities'] = True

        instances = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Skip boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            # Skip degenerate boxes and categories not used by this dataset.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue

            instance = {
                # Crowd annotations are kept but flagged as ignored.
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                'bbox': [x1, y1, x1 + w, y1 + h],  # xywh -> xyxy
                'bbox_label': self.cat2label[ann['category_id']],
            }
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            instances.append(instance)
        data_info['instances'] = instances
        return data_info

@DATASETS.register_module()
class YouSelfMldata(MultispectralDataset):
    """Example multispectral dataset with its own class names and palette.

    MMDetection 3.x dropped the ``classes=()`` option for declaring a
    dataset's categories; declare them through ``METAINFO`` instead (see
    ``CocoDataset.METAINFO`` for the format). The class is registered so that
    ``type='YouSelfMldata'`` can be used in configs.
    """

    METAINFO = {
        'classes': ('person',),
        'palette': [(220, 20, 60)],
    }

注册数据集类型

MMDetection 3.x 支持使用 Python 语法定义配置,可以不必修改注册文件(下面的注册步骤同理,熟悉的话可以不采用这种注册方式)。

修改 mmdet/datasets/__init__.py 如下,导入自定义的多光谱数据集:

# ...  earlier content omitted; append the lines below.

# Custom multispectral datasets. Every custom name listed in ``__all__`` must
# also be imported here, otherwise ``from mmdet.datasets import *`` raises
# AttributeError.
from .multispectral_detection import MultispectralDataset, YouSelfMldata

__all__ = [
    'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 'VOCDataset',
    'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 'LVISV1Dataset',
    'WIDERFaceDataset', 'get_loading_pipeline', 'CocoPanopticDataset',
    'MultiImageMixDataset', 'OpenImagesDataset', 'OpenImagesChallengeDataset',
    'AspectRatioBatchSampler', 'ClassAwareSampler', 'MultiSourceSampler',
    'GroupMultiSourceSampler', 'BaseDetDataset', 'CrowdHumanDataset',
    'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset',
    'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler',
    'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
    'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset',
    'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset',
    'ADE20KInstanceDataset', 'iSAIDDataset',
    # Custom datasets
    'MultispectralDataset', 'YouSelfMldata',
]

3. 多光谱数据集加载器

在 mmdet/datasets/transforms/loading.py 追加一个多光谱数据加载器的实现

# 自定义数据集加载器
import mmengine.fileio as fileio

@TRANSFORMS.register_module()
class LoadMultispectralImageFile(LoadImageFromFile):
    """Load every modality of a multispectral sample and stack them.

    Each path in ``results['img_path_m']`` is read and decoded with the same
    settings as :class:`LoadImageFromFile` (``file_client_args`` /
    ``backend_args``, ``color_type``, ``imdecode_backend``, ``to_float32``,
    ``ignore_empty``). Images decoded as 2-D ``(H, W)`` arrays get a trailing
    channel axis. All modalities are then concatenated along the channel
    axis, in the iteration order of ``img_path_m``.
    """

    def transform(self, results: dict):
        """Load and channel-concatenate every modality of one sample.

        Args:
            results (dict): Must contain ``img_path_m``, a non-empty dict
                mapping a modality name to an image path, e.g.
                ``{'rgb': 'xxx', 'ir': 'xxx'}``.

        Returns:
            dict | None: ``results`` with ``img`` (``(H, W, C_total)``),
            ``img_shape`` and ``ori_shape`` set. Returns ``None`` when a file
            cannot be read and ``ignore_empty`` is enabled.

        Raises:
            TypeError: If ``img_path_m`` is not a dict.
            ValueError: If ``img_path_m`` is empty or the modalities differ
                in height/width.
        """
        img_paths = results['img_path_m']
        # Only multispectral samples (carrying an ``img_path_m`` dict) are
        # supported by this loader.
        if not isinstance(img_paths, dict):
            raise TypeError('`results["img_path_m"]` should be a dict, but '
                            f'got {type(img_paths)}')
        if not img_paths:
            raise ValueError('`results["img_path_m"]` is empty')

        imgs = []
        for filename in img_paths.values():
            try:
                if self.file_client_args is not None:
                    file_client = fileio.FileClient.infer_client(
                        self.file_client_args, filename)
                    img_bytes = file_client.get(filename)
                else:
                    img_bytes = fileio.get(
                        filename, backend_args=self.backend_args)
                img = mmcv.imfrombytes(
                    img_bytes, flag=self.color_type,
                    backend=self.imdecode_backend)
            except Exception:
                if self.ignore_empty:
                    return None
                raise
            # in some cases, images are not read successfully, the img would be
            # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
            assert img is not None, f'failed to load image: {filename}'
            if img.ndim == 2:
                # Single-channel images (e.g. a grayscale IR band) decode to
                # (H, W); add a channel axis so they can be concatenated.
                img = img[..., np.newaxis]
            if self.to_float32:
                img = img.astype(np.float32)
            imgs.append(img)

        if len({tuple(im.shape[:2]) for im in imgs}) != 1:
            raise ValueError(
                'All modalities must share the same height and width, but '
                f'got {dict(zip(img_paths, (im.shape for im in imgs)))}')
        img = np.concatenate(imgs, axis=2)  # stack modalities channel-wise
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

注册数据集加载器

修改 mmdet/datasets/transforms/__init__.py 如下,导入自定义的多光谱数据加载器:

# ...  earlier content omitted; append the lines below.

# Custom multispectral loading transform.
from .loading import LoadMultispectralImageFile

__all__ = [
    'PackDetInputs', 'ToTensor', 'ImageToTensor',
    'Transpose', 'LoadImageFromNDArray', 'LoadAnnotations',
    'LoadPanopticAnnotations', 'LoadMultiChannelImageFromFiles', 'LoadProposals',
    'Resize', 'RandomFlip', 'RandomCrop',
    'SegRescale', 'MinIoURandomCrop', 'Expand',
    'PhotoMetricDistortion', 'Albu', 'InstaBoost',
    'RandomCenterCropPad', 'AutoAugment', 'CutOut',
    'ShearX', 'ShearY', 'Rotate',
    'Color', 'Equalize', 'Brightness',
    'Contrast', 'TranslateX', 'TranslateY',
    'RandomShift', 'Mosaic', 'MixUp',
    'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
    'FilterAnnotations', 'Pad', 'GeomTransform',
    'ColorTransform', 'RandAugment', 'Sharpness',
    'Solarize', 'SolarizeAdd', 'Posterize',
    'AutoContrast', 'Invert', 'MultiBranch',
    'RandomErasing', 'LoadEmptyAnnotations', 'RandomOrder',
    'CachedMosaic', 'CachedMixUp', 'FixShapeResize',
    'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations',
    'BaseFrameSample', 'UniformRefFrameSample', 'PackTrackInputs',
    'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge',
    # Custom loading transform
    'LoadMultispectralImageFile',
]

4. 添加多光谱数据预处理器

添加一个多光谱的 DataPreprocessor。
DataPreprocessor 是 3.x 引入的新特性:这部分处理原本位于 pipeline 中,现在被单独抽取了出来;当然,原则上也可以继续把它定义在 pipeline 中。

这里其实只修改了一处:默认实现把 len(mean) 限制为 1 或 3,这里放宽为只要求 mean 与 std 的长度一致:

 assert len(mean) == len(std),

在 mmdet/models/data_preprocessors 添加 ml_data_preprocessor.py 文件,写入以下内容

import math  
from numbers import Number  
from typing import List, Optional, Sequence, Union  

import numpy as np  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from mmengine.model.base_model import BaseDataPreprocessor  
from mmengine.model.utils import stack_batch  
from mmengine.registry import MODELS  
from mmengine.structures import BaseDataElement  
from mmengine.structures import PixelData  
from mmengine.utils import is_seq_of  

from mmdet.models.utils.misc import samplelist_boxtype2tensor  
from mmdet.registry import MODELS  
from mmdet.structures import DetDataSample  

CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, None]


@MODELS.register_module()
class MlImgDataPreprocessor(BaseDataPreprocessor):
    """Image data preprocessor that accepts any number of input channels.

    Modified from ``ImgDataPreprocessor`` in
    ``mmdet/models/data_preprocessors/data_preprocessor.py``. The upstream
    check only allows ``mean``/``std`` with 1 or 3 values. Here any length is
    accepted as long as ``mean`` and ``std`` agree, so channel-stacked
    multispectral images can be normalized per channel.

    Args:
        mean: Per-channel mean. A single value is broadcast to every channel;
            otherwise its length must equal the number of input channels.
            ``None`` disables normalization (``std`` must then be ``None``).
        std: Per-channel std with the same length as ``mean``.
        pad_size_divisor: Inputs are padded so H and W are multiples of it.
        pad_value: Value used for padding.
        bgr_to_rgb: Reverse the channel order; requires 3-channel inputs.
        rgb_to_bgr: Same as ``bgr_to_rgb``; the two are mutually exclusive.
        non_blocking: Forwarded to ``BaseDataPreprocessor``.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 non_blocking: Optional[bool] = False):
        super().__init__(non_blocking)

        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
        assert (mean is None) == (std is None), (
            'mean and std should be both None or tuple')
        if mean is not None:
            # Relaxed from upstream (which only allows 1 or 3 values): any
            # channel count is fine as long as mean and std agree.
            assert len(mean) == len(std), (
                '`mean` and `std` should have same length, but got '
                f'{len(mean)} and {len(std)}')
            self._enable_normalize = True
            self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1),
                                 False)
            self.register_buffer('std', torch.tensor(std).view(-1, 1, 1),
                                 False)
        else:
            self._enable_normalize = False
        self._channel_conversion = rgb_to_bgr or bgr_to_rgb
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value

    def _check_channels(self, inputs: torch.Tensor, channel_dim: int) -> None:
        """Assert that ``inputs`` has as many channels as ``mean``/``std``.

        Args:
            inputs: A CHW tensor (``channel_dim=0``) or an NCHW tensor
                (``channel_dim=1``).
            channel_dim: Index of the channel axis in ``inputs``.
        """
        num_stats = self.mean.shape[0]
        if num_stats == 1:
            # A single mean/std value broadcasts over any number of channels.
            return
        # CHW has 3 dims and NCHW has 4, i.e. channel_dim + 3.
        assert (inputs.dim() == channel_dim + 3
                and inputs.shape[channel_dim] == num_stats), (
            f'If the mean has {num_stats} values, the input tensor should '
            f'have {num_stats} channels at dim {channel_dim}, but got the '
            f'tensor with shape {tuple(inputs.shape)}')

    def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
        """Perform normalization, padding and bgr2rgb conversion.

        Args:
            data (dict): Data sampled from dataset. If the collate
                function of DataLoader is :obj:`pseudo_collate`, data will be a
                list of dict. If collate function is :obj:`default_collate`,
                data will be a tuple with batch input tensor and list of data
                samples.
            training (bool): Whether to enable training time augmentation. If
                subclasses override this method, they can perform different
                preprocessing strategies for training and testing based on the
                value of ``training``.

        Returns:
            dict or list: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore
        _batch_inputs = data['inputs']
        # Process data with `pseudo_collate`: a list of CHW tensors.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_inputs = []
            for _batch_input in _batch_inputs:
                if self._channel_conversion:
                    # Swapping channel order is only defined for 3 channels.
                    assert _batch_input.shape[0] == 3, (
                        'bgr_to_rgb/rgb_to_bgr requires 3-channel inputs, but '
                        f'got the tensor with shape {tuple(_batch_input.shape)}')
                    _batch_input = _batch_input[[2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_input = _batch_input.float()
                if self._enable_normalize:
                    self._check_channels(_batch_input, channel_dim=0)
                    _batch_input = (_batch_input - self.mean) / self.std
                batch_inputs.append(_batch_input)
            # Pad and stack Tensor.
            batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
                                       self.pad_value)
        # Process data with `default_collate`: a single NCHW tensor.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            if self._channel_conversion:
                # Without this check, extra spectral channels would be
                # silently dropped by the [2, 1, 0] indexing.
                assert _batch_inputs.shape[1] == 3, (
                    'bgr_to_rgb/rgb_to_bgr requires 3-channel inputs, but '
                    f'got the tensor with shape {tuple(_batch_inputs.shape)}')
                _batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
            # Convert to float after channel conversion to ensure
            # efficiency
            _batch_inputs = _batch_inputs.float()
            if self._enable_normalize:
                self._check_channels(_batch_inputs, channel_dim=1)
                _batch_inputs = (_batch_inputs - self.mean) / self.std
            h, w = _batch_inputs.shape[2:]
            target_h = math.ceil(
                h / self.pad_size_divisor) * self.pad_size_divisor
            target_w = math.ceil(
                w / self.pad_size_divisor) * self.pad_size_divisor
            pad_h = target_h - h
            pad_w = target_w - w
            batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h),
                                 'constant', self.pad_value)
        else:
            raise TypeError('Output of `cast_data` should be a dict of '
                            'list/tuple with inputs and data_samples, '
                            f'but got {type(data)}: {data}')
        data['inputs'] = batch_inputs
        data.setdefault('data_samples', None)
        return data


@MODELS.register_module()
class MlDetDataPreprocessor(MlImgDataPreprocessor):
    """Detection data preprocessor for channel-stacked multispectral inputs.

    Follows mmdet's ``DetDataPreprocessor`` but is built on
    :class:`MlImgDataPreprocessor`, so ``mean``/``std`` may have any number
    of channels. It is re-declared here rather than re-based, because
    changing ``DetDataPreprocessor``'s parent class would require editing
    mmdet's source.

    On top of the parent's normalization and padding, it:

    - records ``batch_input_shape`` and ``pad_shape`` in every data sample;
    - optionally converts box types to tensors;
    - pads ground-truth masks and semantic maps during training when enabled;
    - applies ``batch_augments`` during training.

    ``_get_pad_shape`` reads H and W from dims 2 and 3 when the inputs are a
    single batched NCHW tensor.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 non_blocking: Optional[bool] = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            non_blocking=non_blocking)
        if batch_augments is not None:
            self.batch_augments = nn.ModuleList(
                [MODELS.build(aug) for aug in batch_augments])
        else:
            self.batch_augments = None
        self.pad_mask = pad_mask
        self.mask_pad_value = mask_pad_value
        self.pad_seg = pad_seg
        self.seg_pad_value = seg_pad_value
        self.boxtype2tensor = boxtype2tensor

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        # Pad shapes must be computed from the un-padded inputs.
        batch_pad_shape = self._get_pad_shape(data)

        data = super().forward(data=data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']

        if data_samples is not None:
            # NOTE the batched image size information may be useful, e.g.
            # in DETR, this is needed for the construction of masks, which is
            # then used for the transformer_head.
            batch_input_shape = tuple(inputs[0].size()[-2:])
            for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
                data_sample.set_metainfo({
                    'batch_input_shape': batch_input_shape,
                    'pad_shape': pad_shape
                })

            if self.boxtype2tensor:
                samplelist_boxtype2tensor(data_samples)

            if self.pad_mask and training:
                self.pad_gt_masks(data_samples)

            if self.pad_seg and training:
                self.pad_gt_sem_seg(data_samples)

        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        return {'inputs': inputs, 'data_samples': data_samples}

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the (H, W) pad_shape of each image from its un-padded size.

        Both sides are rounded up to a multiple of ``pad_size_divisor``.
        """
        _batch_inputs = data['inputs']
        divisor = self.pad_size_divisor

        def _round_up(size: int) -> int:
            # Smallest multiple of ``divisor`` that is >= ``size``.
            return int(np.ceil(size / divisor)) * divisor

        # Process data with `pseudo_collate`: a list of CHW tensors.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = [
                (_round_up(ori_input.shape[1]), _round_up(ori_input.shape[2]))
                for ori_input in _batch_inputs
            ]
        # Process data with `default_collate`: one NCHW tensor.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # H and W are dims 2 and 3 of NCHW (dim 1 is channels).
            pad_shape = (_round_up(_batch_inputs.shape[2]),
                         _round_up(_batch_inputs.shape[3]))
            batch_pad_shape = [pad_shape] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape

    def pad_gt_masks(self,
                     batch_data_samples: Sequence[DetDataSample]) -> None:
        """Pad gt_masks to shape of batch_input_shape."""
        if 'masks' in batch_data_samples[0].gt_instances:
            for data_samples in batch_data_samples:
                masks = data_samples.gt_instances.masks
                data_samples.gt_instances.masks = masks.pad(
                    data_samples.batch_input_shape,
                    pad_val=self.mask_pad_value)

    def pad_gt_sem_seg(self,
                       batch_data_samples: Sequence[DetDataSample]) -> None:
        """Pad gt_sem_seg to shape of batch_input_shape (bottom/right only)."""
        if 'gt_sem_seg' in batch_data_samples[0]:
            for data_samples in batch_data_samples:
                gt_sem_seg = data_samples.gt_sem_seg.sem_seg
                h, w = gt_sem_seg.shape[-2:]
                pad_h, pad_w = data_samples.batch_input_shape
                gt_sem_seg = F.pad(
                    gt_sem_seg,
                    pad=(0, max(pad_w - w, 0), 0, max(pad_h - h, 0)),
                    mode='constant',
                    value=self.seg_pad_value)
                data_samples.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)

在数据预处理器中注册

# Edit mmdet/models/data_preprocessors/__init__.py

# ... earlier content omitted
# Custom multispectral data preprocessors.
from .ml_data_preprocessor import (
    MlDetDataPreprocessor,
    MlImgDataPreprocessor,
)

__all__ = [
    'DetDataPreprocessor',
    'BatchSyncRandomResize',
    'BatchFixedSizePad',
    'MultiBranchDataPreprocessor',
    'BatchResize',
    'BoxInstDataPreprocessor',
    'TrackDataPreprocessor',
    'ReIDDataPreprocessor',
    # Custom multispectral data preprocessors
    'MlDetDataPreprocessor',
    'MlImgDataPreprocessor',
]

最后

至此,多光谱数据集的加载流程已经实现;在模型中按通道索引把拼接后的输入拆分回各个模态即可。