#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 10:47:55 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/05_plot_conf_mat.py
"""Plots confusion matrix from saved predictions and labels"""
# Imports
import scitex as stx
import numpy as np
from sklearn.metrics import confusion_matrix
# Functions and Classes
def plot_confusion_matrix(labels: np.ndarray, predictions: np.ndarray, CONFIG) -> None:
cm = confusion_matrix(labels, predictions)
fig, ax = stx.plt.subplots(figsize=(10, 8))
ax.imshow2d(cm)
ax.set_xyt("Predicted", "True", "Confusion Matrix")
return fig
@stx.session
def main(
CONFIG=stx.INJECTED,
plt=stx.INJECTED,
COLORS=stx.INJECTED,
rng_manager=stx.INJECTED,
logger=stx.INJECTED,
):
"""Plot confusion matrix"""
predictions = stx.io.load("./scripts/mnist/04_clf_svm_out/predictions.npy")
labels = stx.io.load("./scripts/mnist/04_clf_svm_out/labels.npy")
fig = plot_confusion_matrix(labels, predictions, CONFIG)
stx.io.save(
fig,
CONFIG.PATH.MNIST.FIGURES + "confusion_matrix.jpg",
symlink_to="./data/mnist",
)
return 0
if __name__ == "__main__":
main()
# EOF