#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 11:38:35 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/02_plot_digits.py
"""Visualizes MNIST dataset samples"""
# Imports
import scitex as stx
from torch.utils.data import DataLoader
# Functions and Classes
def plot_samples(loader: DataLoader, CONFIG, plt, n_samples: int = 25):
    """Plot a 5x5 grid of images taken from the first batch of ``loader``.

    Parameters
    ----------
    loader : DataLoader
        Loader whose batches unpack as ``(images, labels)``.
    CONFIG
        Session configuration; accepted for a uniform plotting signature,
        not used here.
    plt
        Injected plotting module; ``plt.tight_layout()`` is applied after
        drawing.
    n_samples : int, default 25
        Maximum number of images to draw. The grid is fixed at 5x5, so at
        most 25 are shown, and never more than the first batch contains.

    Returns
    -------
    The figure object returned by ``stx.plt.subplots``.

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    try:
        images, labels = next(iter(loader))
    except StopIteration:
        # A bare StopIteration here is confusing; report the real cause.
        raise ValueError("loader yielded no batches") from None

    fig, axes = stx.plt.subplots(5, 5, figsize=(10, 10))
    # Clamp to the batch size so a small batch does not raise IndexError.
    n_shown = min(n_samples, len(images))
    for idx, ax in enumerate(axes.flat):
        if idx < n_shown:
            ax.plot_imshow(images[idx].squeeze(), cmap="gray")
            ax.set_title(f"Label: {labels[idx]}")
        # Hide ticks/frame on every cell, including any left empty.
        ax.axis("off")
    plt.tight_layout()
    return fig
def plot_label_examples(loader: DataLoader, CONFIG, plt):
    """Plot one example image per distinct label in a 2x5 grid.

    Batches from ``loader`` are scanned in order until 10 distinct labels
    have been collected or the loader is exhausted; the first image seen for
    each label is kept. Examples are drawn in ascending label order, filling
    the grid row by row, and any unused cells are left blank.

    Parameters
    ----------
    loader : DataLoader
        Loader whose batches unpack as ``(images, labels)``. Labels are
        expected to support ``.item()`` (e.g. tensor elements).
    CONFIG
        Session configuration; accepted for a uniform plotting signature,
        not used here.
    plt
        Injected plotting module; ``plt.tight_layout()`` is applied after
        drawing.

    Returns
    -------
    The figure object returned by ``stx.plt.subplots``.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples at all.
    """
    n_cells = 10  # 2 rows x 5 columns

    # Collect the first example per label, looking past the first batch so
    # digits that happen to be missing from it are still found.
    label_examples = {}
    for images, labels in loader:
        for img, label in zip(images, labels):
            label_examples.setdefault(label.item(), img)
            if len(label_examples) == n_cells:
                break
        if len(label_examples) == n_cells:
            break
    if not label_examples:
        raise ValueError("loader yielded no samples")

    fig, axes = stx.plt.subplots(2, 5, figsize=(15, 6))
    # Turn every cell off up front so cells without an example stay blank.
    for ax in axes.flat:
        ax.axis("off")
    for idx, (label, img) in enumerate(sorted(label_examples.items())):
        ax = axes[idx // 5, idx % 5]
        ax.plot_imshow(img.squeeze(), cmap="gray")
        ax.set_title(f"Digit: {label}")
    plt.tight_layout()
    return fig
@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Visualize MNIST samples"""
    # Render each figure from the training loader and save it, symlinked into
    # ./data/mnist. Order: sample grid first, then per-digit examples.
    # COLORS, rng_manager and logger are injected but not used here.
    loader = stx.io.load(CONFIG.PATH.MNIST.LOADER.TRAIN)
    figure_jobs = (
        (plot_samples, "mnist_samples.jpg"),
        (plot_label_examples, "mnist_digits.jpg"),
    )
    for make_figure, filename in figure_jobs:
        figure = make_figure(loader, CONFIG, plt)
        stx.io.save(
            figure,
            filename,
            symlink_to="./data/mnist",
        )
    return 0
if __name__ == "__main__":
    # Script entry point: main() is called with no arguments; its
    # stx.INJECTED defaults are presumably filled in by @stx.session.
    main()
# EOF