#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 10:47:54 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/04_clf_svm.py

"""Trains and evaluates SVM classifier on MNIST dataset"""

# Imports
import scitex as stx
from typing import Dict

import numpy as np
from sklearn.metrics import classification_report
from sklearn.svm import SVC


# Functions and Classes
def train_svm(features: np.ndarray, labels: np.ndarray, CONFIG) -> SVC:
    """Fit an RBF-kernel SVC on the given features and labels.

    Parameters
    ----------
    features : np.ndarray
        Training inputs (presumably flattened images, one row per sample --
        confirm against the FLATTENED.TRAIN producer).
    labels : np.ndarray
        Training targets aligned with ``features``.
    CONFIG
        scitex config; ``CONFIG.MNIST.RANDOM_STATE`` is passed to ``SVC``.

    Returns
    -------
    SVC
        The fitted classifier.
    """
    # SVC's random_state only affects the internal shuffling used for
    # probability estimates; with the default probability=False it is
    # ignored, but passing it keeps the config wiring explicit.
    model = SVC(kernel="rbf", random_state=CONFIG.MNIST.RANDOM_STATE)
    model.fit(features, labels)
    return model


def evaluate(
    model: SVC,
    features: np.ndarray,
    labels: np.ndarray,
) -> Dict[str, float]:
    """Predict on ``features``, save artefacts, and return summary metrics.

    Side effects: saves the classification report, the predictions, and the
    true labels via ``stx.io.save``, each with ``symlink_to="./data/mnist"``.
    How the relative file names are resolved is up to ``stx.io.save``
    (presumably into the session output directory -- confirm).

    Parameters
    ----------
    model : SVC
        A fitted classifier.
    features : np.ndarray
        Evaluation inputs.
    labels : np.ndarray
        Ground-truth labels aligned with ``features``.

    Returns
    -------
    Dict[str, float]
        ``{"accuracy": ..., "macro_f1": ...}`` taken from sklearn's
        ``classification_report(output_dict=True)``.
    """
    predictions = model.predict(features)
    # output_dict=True gives a nested dict, so "accuracy" and
    # "macro avg" -> "f1-score" can be read directly below.
    report = classification_report(labels, predictions, output_dict=True)

    stx.io.save(report, "./classification_report.csv", symlink_to="./data/mnist")
    stx.io.save(predictions, "./predictions.npy", symlink_to="./data/mnist")
    stx.io.save(labels, "./labels.npy", symlink_to="./data/mnist")

    return {
        "accuracy": report["accuracy"],
        "macro_f1": report["macro avg"]["f1-score"],
    }


@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
) -> int:
    """Train SVM classifier on MNIST"""
    # plt, COLORS and rng_manager are unused here but stay in the signature
    # because stx.session injects them by name.
    train_data = stx.io.load(CONFIG.PATH.MNIST.FLATTENED.TRAIN)
    train_labels = stx.io.load(CONFIG.PATH.MNIST.LABELS.TRAIN)
    test_data = stx.io.load(CONFIG.PATH.MNIST.FLATTENED.TEST)
    test_labels = stx.io.load(CONFIG.PATH.MNIST.LABELS.TEST)

    model = train_svm(train_data, train_labels, CONFIG)
    metrics = evaluate(model, test_data, test_labels)

    logger.success(
        f"Test Accuracy: {metrics['accuracy']:.4f}, Macro F1: {metrics['macro_f1']:.4f}"
    )

    # NOTE(review): eval() executes the config value as Python code, presumably
    # so that an f-string template stored in the config gets expanded in this
    # scope. That is only safe while the config file is fully trusted; consider
    # a non-eval resolver (e.g. str.format with explicit fields) -- confirm.
    stx.io.save(model, eval(CONFIG.PATH.MNIST.MODEL_SVM), symlink_to="./data/mnist")
    return 0


if __name__ == "__main__":
    main()

# EOF