Loading...
No commits yet
Not committed History
Blame
03_plot_umap_space.py • 1.5 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 11:43:04 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/03_plot_umap_space.py


"""Creates UMAP visualization of MNIST dataset"""

# Imports
import scitex as stx
import numpy as np
import umap


# Functions and Classes
def create_umap_embedding(data: np.ndarray, CONFIG) -> np.ndarray:
    """Fit a UMAP reducer on ``data`` and return the resulting embedding.

    Args:
        data: Samples to embed; passed unchanged to ``UMAP.fit_transform``.
        CONFIG: Configuration object; ``CONFIG.MNIST.UMAP_RANDOM_STATE`` is
            used as the reducer's ``random_state``.

    Returns:
        The array produced by ``reducer.fit_transform(data)``.
    """
    seed = CONFIG.MNIST.UMAP_RANDOM_STATE
    # umap-learn runs single-threaded whenever a seed is supplied and warns
    # if a different n_jobs was requested, so only ask for all cores when the
    # run is unseeded. This keeps seeded runs warning-free without changing
    # the effective behaviour.
    n_jobs = -1 if seed is None else 1
    reducer = umap.UMAP(random_state=seed, n_jobs=n_jobs)
    return reducer.fit_transform(data)


def plot_umap(embedding: np.ndarray, labels: np.ndarray, CONFIG, plt) -> object:
    """Scatter-plot the first two embedding columns, colored by label.

    Args:
        embedding: Array whose columns 0 and 1 are used as x and y.
        labels: Values passed as ``c`` to the scatter call and mapped through
            the "tab10" colormap.
        CONFIG: Not used in this function; kept so existing callers that pass
            it keep working.
        plt: Object providing ``colorbar`` (supplied by the caller).

    Returns:
        The figure created by ``stx.plt.subplots``. The function has always
        returned it; the annotation previously (incorrectly) said ``None``.
    """
    fig, ax = stx.plt.subplots(figsize=(12, 8))
    scatter = ax.scatter(
        embedding[:, 0], embedding[:, 1], c=labels, cmap="tab10", alpha=0.5
    )

    # The colorbar is keyed to the scatter artist so it reflects the label colors.
    plt.colorbar(scatter)
    # set_xyt sets x-label, y-label and title in one call.
    ax.set_xyt("UMAP 1", "UMAP 2", "UMAP Projection of MNIST Digits")

    return fig


@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Load MNIST training data, embed it with UMAP, and save the figure.

    All parameters default to ``stx.INJECTED``; only ``CONFIG`` and ``plt``
    are used in the body. Returns 0.
    """
    mnist_paths = CONFIG.PATH.MNIST

    # Inputs: flattened training samples and their labels.
    features = stx.io.load(mnist_paths.FLATTENED.TRAIN)
    targets = stx.io.load(mnist_paths.LABELS.TRAIN)

    # Embed, then draw the labelled scatter plot.
    projected = create_umap_embedding(features, CONFIG)
    figure = plot_umap(projected, targets, CONFIG, plt)

    # Output: figure file under the configured figures path, plus a symlink.
    figure_path = mnist_paths.FIGURES + "umap.jpg"
    stx.io.save(figure, figure_path, symlink_to="./data/mnist")

    return 0


# Run the session-wrapped entry point only when this file is executed as a
# script, so importing the module does not trigger the pipeline.
if __name__ == "__main__":
    main()

# EOF