Loading...
No commits yet
Not committed History
Blame
04_clf_svm.py • 1.9 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-11-18 10:47:54 (ywatanabe)"
# File: /home/ywatanabe/proj/examples/scitex-research-template/scripts/mnist/04_clf_svm.py


"""Trains and evaluates SVM classifier on MNIST dataset"""

# Imports
import scitex as stx
from typing import Dict
import numpy as np
from sklearn.metrics import classification_report
from sklearn.svm import SVC


# Functions and Classes
def train_svm(features: np.ndarray, labels: np.ndarray, CONFIG) -> SVC:
    """Fit an RBF-kernel support vector classifier.

    Parameters
    ----------
    features : np.ndarray
        Training samples, forwarded to ``SVC.fit`` as the first argument.
    labels : np.ndarray
        Training targets, forwarded to ``SVC.fit`` as the second argument.
    CONFIG
        Configuration object; only ``CONFIG.MNIST.RANDOM_STATE`` is read,
        and it is used as the classifier's ``random_state``.

    Returns
    -------
    SVC
        The fitted classifier.
    """
    seed = CONFIG.MNIST.RANDOM_STATE
    classifier = SVC(kernel="rbf", random_state=seed)
    # ``fit`` returns the estimator itself, so the fitted model can be
    # returned directly.
    return classifier.fit(features, labels)


def evaluate(
    model: SVC,
    features: np.ndarray,
    labels: np.ndarray,
) -> Dict[str, float]:
    """Score a fitted classifier and persist the evaluation artifacts.

    Predicts on ``features``, builds a scikit-learn classification report
    (as a dict) against ``labels``, and saves the report, the predictions
    and the labels through ``stx.io.save``.

    Parameters
    ----------
    model : SVC
        Fitted classifier; only ``model.predict`` is called.
    features : np.ndarray
        Samples to predict on.
    labels : np.ndarray
        Ground-truth targets for ``features``.

    Returns
    -------
    Dict[str, float]
        ``"accuracy"`` and ``"macro_f1"`` taken from the report.
    """
    predicted = model.predict(features)
    summary = classification_report(labels, predicted, output_dict=True)

    # Each artifact is saved with the same ``symlink_to`` target; the order
    # (report, predictions, labels) is kept as-is.
    artifacts = (
        (summary, "./classification_report.csv"),
        (predicted, "./predictions.npy"),
        (labels, "./labels.npy"),
    )
    for payload, destination in artifacts:
        stx.io.save(payload, destination, symlink_to="./data/mnist")

    accuracy = summary["accuracy"]
    macro_f1 = summary["macro avg"]["f1-score"]
    return {"accuracy": accuracy, "macro_f1": macro_f1}


@stx.session
def main(
    CONFIG=stx.INJECTED,
    plt=stx.INJECTED,
    COLORS=stx.INJECTED,
    rng_manager=stx.INJECTED,
    logger=stx.INJECTED,
):
    """Train SVM classifier on MNIST

    Loads the train/test arrays and labels from the paths under
    ``CONFIG.PATH.MNIST``, fits a classifier with ``train_svm``, scores it
    on the test split with ``evaluate``, logs accuracy and macro F1, and
    saves the fitted model.

    Every parameter defaults to ``stx.INJECTED``; presumably the
    ``@stx.session`` decorator supplies the real objects at call time
    (confirm against the scitex documentation). Only ``CONFIG`` and
    ``logger`` are used in this body.

    Returns
    -------
    int
        Always 0.
    """
    # Inputs: loaded directly from the configured paths.
    train_data = stx.io.load(CONFIG.PATH.MNIST.FLATTENED.TRAIN)
    train_labels = stx.io.load(CONFIG.PATH.MNIST.LABELS.TRAIN)
    test_data = stx.io.load(CONFIG.PATH.MNIST.FLATTENED.TEST)
    test_labels = stx.io.load(CONFIG.PATH.MNIST.LABELS.TEST)

    # Fit on the training split, then score on the held-out test split.
    # ``evaluate`` also writes the report, predictions and labels to disk.
    model = train_svm(train_data, train_labels, CONFIG)
    metrics = evaluate(model, test_data, test_labels)

    logger.success(
        f"Test Accuracy: {metrics['accuracy']:.4f}, Macro F1: {metrics['macro_f1']:.4f}"
    )

    # NOTE(review): unlike the other CONFIG paths, MODEL_SVM is passed through
    # eval(); presumably it is stored as a Python expression (e.g. an
    # f-string) rather than a plain path -- confirm against the config file.
    # eval() runs in this function's scope, so the expression can see the
    # local names defined above; renaming or removing them may break it.
    # This is only safe while the configuration is trusted.
    stx.io.save(model, eval(CONFIG.PATH.MNIST.MODEL_SVM), symlink_to="./data/mnist")
    return 0


# Run the session only when this file is executed as a script, not on import.
# main() is called without arguments, so its parameters keep their
# stx.INJECTED defaults at the call site.
if __name__ == "__main__":
    main()

# EOF