# -*- coding: utf-8 -*-
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
from typing import Dict, List, Tuple

import cv2
import numpy as np

from .meta_vision import VisionDataset
from .utils import is_img


class ImageFolder(VisionDataset):
    r"""ImageFolder is a class for loading image data and labels from an organized folder.

    The folder is expected to be organized as follows: root/cls/xxx.img_ext

    Labels are indices of sorted classes in the root directory.

    Images are read with ``cv2.imread(path, cv2.IMREAD_COLOR)``, i.e. in
    OpenCV's 3-channel (BGR) color mode.

    Args:
        root: root directory of an image folder. A leading ``~`` is expanded
            to the user's home directory.
        check_valid_func: a function used to check if files in folder are
            expected image files. It is called with each file's full path;
            files for which it returns a falsy value are skipped. If ``None``,
            the default ``is_img`` check is used.
        class_name: if ``True``, return class name instead of class index.
    """

    def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
        super().__init__(root, order=("image", "image_category"))
        self.root = root
        if check_valid_func is not None:
            self.check_valid = check_valid_func
        else:
            self.check_valid = is_img
        self.class_name = class_name
        # class_dict must be built before samples: collect_samples iterates it.
        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def collect_samples(self) -> List:
        """Collect ``(path, label)`` pairs for every valid file under each class.

        Class directories are visited in sorted name order, and files inside
        each one in sorted ``os.walk`` order, so the sample order is
        deterministic for a given directory tree. The label is the class name
        when ``class_name`` is true, otherwise its ``np.int32`` index from
        ``class_dict``.

        Returns:
            list of ``(path, label)`` tuples.
        """
        samples = []
        directory = os.path.expanduser(self.root)
        for key in sorted(self.class_dict):
            class_dir = os.path.join(directory, key)
            if not os.path.isdir(class_dir):
                continue
            # The label is the same for every file of this class; compute once.
            label = key if self.class_name else self.class_dict[key]
            # NOTE: followlinks=True lets symlinked sub-directories be included,
            # but os.walk does not detect symlink cycles.
            for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
                for name in sorted(filenames):
                    path = os.path.join(dirpath, name)
                    if self.check_valid(path):
                        samples.append((path, label))
        return samples

    def collect_class(self) -> Dict:
        """Map each immediate sub-directory name of ``root`` to its sorted index.

        Returns:
            dict from class (directory) name to ``np.int32`` index, where
            indices follow the sorted order of the names.
        """
        # Expand "~" here too, matching collect_samples; otherwise a root such
        # as "~/data" would fail in scandir while the rest of the class accepts it.
        directory = os.path.expanduser(self.root)
        with os.scandir(directory) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        return {name: np.int32(index) for index, name in enumerate(classes)}

    def __getitem__(self, index: int) -> Tuple:
        """Load the ``index``-th sample.

        Returns:
            ``(img, label)`` where ``img`` is the array returned by
            ``cv2.imread(path, cv2.IMREAD_COLOR)`` and ``label`` is the class
            index or class name (see ``class_name``).

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path, label = self.samples[index]
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising;
            # surface it here instead of passing None to downstream transforms.
            raise OSError(f"failed to read image: {path}")
        return img, label

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)