# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gzip
import os
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r""":class:`~.Dataset` for MNIST meta data."""

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    Url prefix for downloading raw file.
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    Raw file names of both training set and test set (10k).
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    Md5 for checking raw files.
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: path for mnist dataset downloading or loading, if ``None``,
            set ``root`` to the ``_default_root``.
        :param train: if ``True``, load the training set, else load the test set.
        :param download: if the raw files do not exist and ``download`` is ``True``,
            download the raw files and process them, otherwise raise ValueError,
            default is True.
        :param timeout: stored on the instance as ``self.timeout``; nothing in this
            module forwards it to the download helper.
        """
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            # the default location is always created on demand
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                # a user-supplied directory is only created when we are allowed
                # to download into it
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        """Return the ``index``-th entry of every array in ``self.arrays`` as a tuple."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the length of the first array in ``self.arrays``."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default dataset directory: the default dataset root joined with the class name."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Header meta data of the image and label files, as filled in by :meth:`process`."""
        return self._meta_data

    def _check_raw_files(self):
        """Return ``True`` if every file in ``raw_file_name`` exists under ``self.root``.

        Only existence is checked here; no md5 is computed.
        """
        return all(
            os.path.exists(os.path.join(self.root, path))
            for path in self.raw_file_name
        )

    def download(self):
        """Fetch every raw file into ``self.root`` via ``load_raw_data_from_url``.

        Each call receives the file's url, name and expected md5.
        """
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)

    def process(self, train):
        """Load the raw files and fill ``self._meta_data`` and ``self.arrays``.

        :param train: if truthy, parse the training files (entries 0 and 1 of
            ``raw_file_name``), otherwise the test files (entries 2 and 3).
        """
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")
        image_idx, label_idx = (0, 1) if train else (2, 3)
        meta_data_images, images = parse_idx3(
            os.path.join(self.root, self.raw_file_name[image_idx])
        )
        meta_data_labels, labels = parse_idx1(
            os.path.join(self.root, self.raw_file_name[label_idx])
        )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))


def parse_idx3(idx3_file):
    """Parse a gzip-compressed idx3 file into meta data and images.

    :param idx3_file: path of the file; must end with ``.gz``.
    :return: ``(meta_data, images)`` where ``meta_data`` holds the header fields
        ``magic``, ``imgs``, ``height`` and ``width``, and ``images`` is a list of
        ``uint8`` numpy arrays of shape ``(height, width, 1)``.
    """
    # parse idx3 file to meta data and data in numpy array (images)
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data: four big-endian 32-bit integers
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images: each one is height * width unsigned bytes
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"

    # memoryview avoids copying the whole payload just to skip the header
    payload = memoryview(bin_data)[offset:]
    images = []
    # the context manager closes the progress bar even if unpacking raises
    with tqdm(total=meta_data["imgs"], ncols=80) as bar:
        for image in struct.iter_unpack(fmt_image, payload):
            images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1)))
            bar.update()
    return meta_data, images


def parse_idx1(idx1_file):
    """Parse a gzip-compressed idx1 file into meta data and labels.

    :param idx1_file: path of the file; must end with ``.gz``.
    :return: ``(meta_data, labels)`` where ``meta_data`` holds the header fields
        ``magic`` and ``imgs``, and ``labels`` is a 1-d integer numpy array with
        ``imgs`` entries.
    :raises ValueError: if the number of label bytes after the header differs from
        the count announced in the header.
    """
    # parse idx1 file to meta data and data in numpy array (labels)
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data: two big-endian 32-bit integers
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels: one unsigned byte each
    offset += struct.calcsize(fmt_header)
    fmt_image = ">B"

    # memoryview avoids copying the whole payload just to skip the header
    payload = memoryview(bin_data)[offset:]
    # np.empty leaves memory uninitialized, so a short payload would otherwise
    # yield garbage labels; reject any mismatch with the header up front
    if len(payload) != imgs:
        raise ValueError(
            "idx1 file %s declares %d labels but contains %d"
            % (idx1_file, imgs, len(payload))
        )

    labels = np.empty(imgs, dtype=int)
    # the context manager closes the progress bar even if unpacking raises
    with tqdm(total=meta_data["imgs"], ncols=80) as bar:
        for i, label in enumerate(struct.iter_unpack(fmt_image, payload)):
            labels[i] = label[0]
            bar.update()
    return meta_data, labels