#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 11:43:04 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/03_plot_umap_space.py
"""Creates UMAP visualization of MNIST dataset"""
# Imports
import scitex as stx
import numpy as np
import umap
# Functions and Classes
def create_umap_embedding(data: np.ndarray, CONFIG) -> np.ndarray:
    """Fit UMAP on ``data`` and return the resulting embedding.

    Parameters
    ----------
    data : np.ndarray
        Samples to embed. Presumably shaped (n_samples, n_features) -- the
        caller passes flattened MNIST training images; confirm upstream.
    CONFIG
        Project config; ``CONFIG.MNIST.UMAP_RANDOM_STATE`` seeds UMAP.

    Returns
    -------
    np.ndarray
        Output of ``umap.UMAP.fit_transform(data)``.
    """
    random_state = CONFIG.MNIST.UMAP_RANDOM_STATE
    # umap-learn forces single-threaded execution whenever random_state is
    # set (and warns that n_jobs=-1 was overridden). Only ask for all cores
    # when no seed is given; the effective behavior is unchanged either way.
    n_jobs = -1 if random_state is None else 1
    reducer = umap.UMAP(random_state=random_state, n_jobs=n_jobs)
    return reducer.fit_transform(data)
def plot_umap(embedding: np.ndarray, labels: np.ndarray, CONFIG, plt) -> object:
    """Scatter-plot an embedding, coloring each point by its label.

    Parameters
    ----------
    embedding : np.ndarray
        Embedding whose first two columns are used as x and y coordinates.
    labels : np.ndarray
        Per-point values mapped through the ``tab10`` colormap.
    CONFIG
        Unused in this function.
    plt
        Plotting module injected by ``stx.session`` (presumably
        ``matplotlib.pyplot`` or a wrapper of it -- confirm); used only to
        attach a colorbar to the scatter.

    Returns
    -------
    object
        The figure created by ``stx.plt.subplots`` (not ``None``).
    """
    fig, ax = stx.plt.subplots(figsize=(12, 8))
    scatter = ax.scatter(
        embedding[:, 0], embedding[:, 1], c=labels, cmap="tab10", alpha=0.5
    )
    plt.colorbar(scatter)
    # set_xyt is a scitex axis helper: x-label, y-label, title.
    ax.set_xyt("UMAP 1", "UMAP 2", "UMAP Projection of MNIST Digits")
    return fig
@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Create UMAP visualization of MNIST

    Loads the flattened MNIST training images and their labels, embeds the
    images with UMAP, plots the embedding colored by label, and saves the
    figure as ``umap.jpg`` inside ``CONFIG.PATH.MNIST.FIGURES``.

    All parameters are supplied by ``@stx.session``. ``COLORS``,
    ``rng_manager`` and ``logger`` are accepted but not used here.

    Returns
    -------
    int
        0 on success.
    """
    import os

    train_data = stx.io.load(CONFIG.PATH.MNIST.FLATTENED.TRAIN)
    train_labels = stx.io.load(CONFIG.PATH.MNIST.LABELS.TRAIN)
    embedding = create_umap_embedding(train_data, CONFIG)
    fig = plot_umap(embedding, train_labels, CONFIG, plt)
    # os.path.join works whether or not FIGURES ends with a separator; plain
    # "+" would silently produce ".../figuresumap.jpg" when it does not.
    stx.io.save(
        fig,
        os.path.join(CONFIG.PATH.MNIST.FIGURES, "umap.jpg"),
        symlink_to="./data/mnist",
    )
    return 0
if __name__ == "__main__":
    # Script entry point. main() is wrapped by @stx.session, which supplies
    # the stx.INJECTED parameters -- presumably it also handles CLI args and
    # process exit; confirm against the scitex session docs.
    main()
# EOF