#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 10:47:57 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/01_download.py
"""Downloads MNIST dataset and saves preprocessed versions"""
# Imports
import scitex as stx
from typing import Dict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
# Functions and Classes
def download_mnist(CONFIG) -> Dict[str, torch.utils.data.Dataset]:
    """Download MNIST and return the normalized train/test datasets.

    Parameters
    ----------
    CONFIG
        Configuration object. Reads ``CONFIG.MNIST.NORMALIZE.MEAN``,
        ``CONFIG.MNIST.NORMALIZE.STD`` and ``CONFIG.PATH.MNIST.RAW``.

    Returns
    -------
    dict
        ``{"train": <MNIST train split>, "test": <MNIST test split>}``, both
        using a ``ToTensor`` + ``Normalize`` transform.

    Raises
    ------
    ValueError, SyntaxError
        If a normalization entry is a string that is not a valid Python
        literal.
    """
    # Local import keeps this function self-contained; ``ast`` is stdlib.
    import ast

    def _parse_stat(value):
        # The stats are presumably stored as literal strings such as
        # "(0.1307,)" -- TODO confirm against the config file.
        # ``literal_eval`` parses literals only, so unlike ``eval`` it cannot
        # execute arbitrary code. Values that are already parsed (tuple,
        # list, float) pass through unchanged instead of raising TypeError.
        if isinstance(value, str):
            return ast.literal_eval(value)
        return value

    normalize_cfg = CONFIG.MNIST.NORMALIZE
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                _parse_stat(normalize_cfg.MEAN),
                _parse_stat(normalize_cfg.STD),
            ),
        ]
    )

    raw_dir = CONFIG.PATH.MNIST.RAW
    train_dataset = datasets.MNIST(
        raw_dir, train=True, download=True, transform=transform
    )
    # No ``download=True`` here: this presumably relies on the train call
    # above having already fetched the files into ``raw_dir``.
    test_dataset = datasets.MNIST(raw_dir, train=False, transform=transform)
    return {"train": train_dataset, "test": test_dataset}
def create_loaders(
    datasets: Dict[str, torch.utils.data.Dataset], CONFIG
) -> Dict[str, DataLoader]:
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Parameters
    ----------
    datasets
        Mapping with ``"train"`` and ``"test"`` datasets. Inside this
        function the name shadows the ``torchvision.datasets`` module.
    CONFIG
        Configuration object. Reads ``CONFIG.MNIST.BATCH_SIZE.TRAIN`` and
        ``CONFIG.MNIST.BATCH_SIZE.TEST``.

    Returns
    -------
    dict
        ``{"train": DataLoader, "test": DataLoader}``. Only the train loader
        is created with ``shuffle=True``.
    """
    batch_sizes = CONFIG.MNIST.BATCH_SIZE
    return {
        "train": DataLoader(
            datasets["train"],
            batch_size=batch_sizes.TRAIN,
            shuffle=True,
        ),
        "test": DataLoader(datasets["test"], batch_size=batch_sizes.TEST),
    }
def prepare_flattened_data(
    datasets: Dict[str, torch.utils.data.Dataset],
) -> Dict[str, np.ndarray]:
    """Convert each split to a flattened, 0-1 scaled array plus its labels.

    Parameters
    ----------
    datasets
        Mapping of split name to a dataset exposing ``.data`` and
        ``.targets`` tensors.

    Returns
    -------
    dict
        ``{"data": {split: array}, "labels": {split: array}}``. Each data
        array has one row per sample with the remaining axes flattened and
        the values divided by 255.0; labels are ``.targets`` as NumPy arrays.
    """

    def _flatten(subset):
        # Keep the sample axis and collapse everything else into one axis.
        images = subset.data.numpy()
        n_samples = len(images)
        return images.reshape(n_samples, -1) / 255.0

    pixels_by_split = {name: _flatten(subset) for name, subset in datasets.items()}
    targets_by_split = {
        name: subset.targets.numpy() for name, subset in datasets.items()
    }
    return {"data": pixels_by_split, "labels": targets_by_split}
@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Download MNIST, build loaders and flattened arrays, and save them all.

    Every parameter defaults to ``stx.INJECTED``; only ``CONFIG`` is used in
    the body. Returns 0 after the six artifacts have been saved.
    """
    mnist_splits = download_mnist(CONFIG)
    loaders = create_loaders(mnist_splits, CONFIG)
    flat = prepare_flattened_data(mnist_splits)

    out_paths = CONFIG.PATH.MNIST
    # (object, destination) pairs, saved in order: loaders, pixels, labels.
    artifacts = (
        (loaders["train"], out_paths.LOADER.TRAIN),
        (loaders["test"], out_paths.LOADER.TEST),
        (flat["data"]["train"], out_paths.FLATTENED.TRAIN),
        (flat["data"]["test"], out_paths.FLATTENED.TEST),
        (flat["labels"]["train"], out_paths.LABELS.TRAIN),
        (flat["labels"]["test"], out_paths.LABELS.TEST),
    )
    for obj, destination in artifacts:
        stx.io.save(obj, destination, symlink_to="./data/mnist")

    return 0
# Script entry point: runs only when executed directly, not when imported.
# main() is called without arguments, so each parameter keeps its
# stx.INJECTED default (presumably filled in by the @stx.session decorator).
if __name__ == "__main__":
    main()
# EOF