Loading...
No commits yet
Not committed History
Blame
01_download.py • 3.0 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 10:47:57 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/01_download.py


"""Downloads MNIST dataset and saves preprocessed versions"""

# Imports
import scitex as stx
from typing import Dict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


# Functions and Classes
def _parse_normalize_stat(value):
    """Return a normalisation statistic, parsing it when it is given as a string.

    String values (e.g. ``"(0.1307,)"``) are parsed with ``ast.literal_eval``.
    It accepts only Python literals, so unlike ``eval`` it cannot execute code
    embedded in a configuration file. Non-string values are returned unchanged.

    Raises:
        ValueError / SyntaxError: if a string value is not a valid Python literal.
    """
    import ast

    if isinstance(value, str):
        return ast.literal_eval(value)
    return value


def download_mnist(CONFIG) -> Dict[str, torch.utils.data.Dataset]:
    """Download (if needed) and load the MNIST train and test splits.

    Both splits live under ``CONFIG.PATH.MNIST.RAW`` and share one transform:
    ``ToTensor`` followed by ``Normalize``. The mean and std for ``Normalize``
    come from ``CONFIG.MNIST.NORMALIZE.MEAN`` and ``CONFIG.MNIST.NORMALIZE.STD``.

    Args:
        CONFIG: Project configuration providing the path and statistics above.

    Returns:
        Dict with keys ``"train"`` and ``"test"`` mapping to the MNIST datasets.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                _parse_normalize_stat(CONFIG.MNIST.NORMALIZE.MEAN),
                _parse_normalize_stat(CONFIG.MNIST.NORMALIZE.STD),
            ),
        ]
    )
    root = CONFIG.PATH.MNIST.RAW
    # download=True on both splits: torchvision skips the download when the
    # files are already present, and the test split no longer depends on the
    # train call having fetched the archive first.
    train_dataset = datasets.MNIST(
        root, train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        root, train=False, download=True, transform=transform
    )
    return {"train": train_dataset, "test": test_dataset}


def create_loaders(
    datasets: Dict[str, torch.utils.data.Dataset], CONFIG
) -> Dict[str, DataLoader]:
    """Wrap the train/test datasets in DataLoaders.

    The train loader shuffles its samples and the test loader keeps dataset
    order. Batch sizes are read from ``CONFIG.MNIST.BATCH_SIZE.TRAIN`` and
    ``CONFIG.MNIST.BATCH_SIZE.TEST``.

    Args:
        datasets: Mapping with ``"train"`` and ``"test"`` datasets.
        CONFIG: Project configuration providing the batch sizes.

    Returns:
        Dict with keys ``"train"`` and ``"test"`` mapping to DataLoaders.
    """
    sizes = CONFIG.MNIST.BATCH_SIZE
    return {
        "train": DataLoader(
            datasets["train"], batch_size=sizes.TRAIN, shuffle=True
        ),
        "test": DataLoader(datasets["test"], batch_size=sizes.TEST),
    }


def prepare_flattened_data(
    datasets: Dict[str, torch.utils.data.Dataset],
) -> Dict[str, Dict[str, np.ndarray]]:
    """Flatten each split's raw images and collect its labels as NumPy arrays.

    This reads the untransformed ``dataset.data`` and ``dataset.targets``
    tensors, so any ``Normalize`` transform attached to the dataset is NOT
    applied. Pixels are only divided by 255.

    Args:
        datasets: Mapping of split name to a dataset exposing ``.data`` and
            ``.targets`` tensors (as torchvision's MNIST does).

    Returns:
        ``{"data": {split: 2-D array of shape (n_samples, n_features) scaled
        by 1/255}, "labels": {split: ``dataset.targets`` as a NumPy array}}``.
    """
    flattened_data = {}
    labels = {}

    for split, dataset in datasets.items():
        images = dataset.data.numpy()
        # Compute the feature count explicitly. reshape(n, -1) raises for an
        # empty split because -1 cannot be inferred from a size-0 array.
        n_features = int(np.prod(images.shape[1:]))
        flattened_data[split] = images.reshape(len(images), n_features) / 255.0
        labels[split] = dataset.targets.numpy()

    return {"data": flattened_data, "labels": labels}


@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Download MNIST, build loaders and flattened arrays, and save them all.

    Each artifact is written with ``stx.io.save`` to
    ``CONFIG.PATH.MNIST.<GROUP>.<SPLIT>``, using a ``symlink_to`` of
    ``./data/mnist``. Returns 0.
    """
    mnist_datasets = download_mnist(CONFIG)
    loaders = create_loaders(mnist_datasets, CONFIG)
    flat = prepare_flattened_data(mnist_datasets)

    symlink_dir = "./data/mnist"
    # (artifact, CONFIG.PATH.MNIST group, split). Saved in this order.
    save_plan = (
        (loaders["train"], "LOADER", "TRAIN"),
        (loaders["test"], "LOADER", "TEST"),
        (flat["data"]["train"], "FLATTENED", "TRAIN"),
        (flat["data"]["test"], "FLATTENED", "TEST"),
        (flat["labels"]["train"], "LABELS", "TRAIN"),
        (flat["labels"]["test"], "LABELS", "TEST"),
    )
    for artifact, group, split in save_plan:
        # Resolve the destination right before saving, as the explicit
        # attribute chains did, so config access order is unchanged.
        destination = getattr(getattr(CONFIG.PATH.MNIST, group), split)
        stx.io.save(artifact, destination, symlink_to=symlink_dir)
    return 0


if __name__ == "__main__":
    # Script entry point. main() is called without arguments; its
    # stx.INJECTED parameters are presumably filled in by @stx.session.
    main()

# EOF