Loading...
No commits yet
Not committed History
Blame
check_references.py • 13.3 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File: scripts/python/check_references.py
# Purpose: Validate all cross-references, citations, and labels in LaTeX manuscripts
# Usage:
#   python check_references.py [project_dir] [--doc-type manuscript|supplementary|all]
#   python check_references.py [project_dir] --log  # Also parse .log for LaTeX warnings
#
# Checks:
#   1. Undefined references: \ref{X} where \label{X} doesn't exist
#   2. Undefined citations: \cite{X} where X not in any .bib file
#   3. Multiply defined labels: \label{X} defined more than once
#   4. Orphan labels: \label{X} never referenced
#   5. Orphan bib entries: @article{X,...} never cited (info only)
#   6. LaTeX log warnings (optional --log)

import argparse
import re
import sys
from collections import defaultdict
from pathlib import Path

# ANSI SGR escape sequences used to colorize terminal output; NC resets
# all attributes back to the terminal default ("no color").
GREEN = "\033[0;32m"
YELLOW = "\033[1;33m"
RED = "\033[0;31m"
DIM = "\033[0;90m"
BOLD = "\033[1m"
NC = "\033[0m"

# Running tallies, incremented by log_pass / log_warn / log_fail and read by
# main() to print the summary and choose the exit code.
PASS_COUNT = WARN_COUNT = FAIL_COUNT = 0


def log_pass(msg):
    """Report a passing check: print a green ``[PASS]`` line and count it."""
    global PASS_COUNT
    tag = f"{GREEN}[PASS]{NC}"
    print(f"  {tag} {msg}")
    PASS_COUNT += 1


def log_warn(msg):
    """Report a warning: print a yellow ``[WARN]`` line and count it."""
    global WARN_COUNT
    tag = f"{YELLOW}[WARN]{NC}"
    print(f"  {tag} {msg}")
    WARN_COUNT += 1


def log_fail(msg):
    """Report a failed check: print a red ``[FAIL]`` line and count it."""
    global FAIL_COUNT
    tag = f"{RED}[FAIL]{NC}"
    print(f"  {tag} {msg}")
    FAIL_COUNT += 1


def log_detail(msg):
    """Print *msg* dimmed, indented one level deeper than the status lines."""
    indent = " " * 4
    print(f"{indent}{DIM}{msg}{NC}")


def collect_tex_files(doc_dir):
    """Collect the hand-written SOURCE .tex files of one document directory.

    Scans ``contents/*.tex`` (excluding archived ``*_vN.tex`` and
    ``*_diff.tex`` copies), ``contents/figures/caption_and_media/*.tex``,
    ``contents/tables/caption_and_media/*.tex`` and ``base.tex``. Compiled
    files that sit at the top of *doc_dir* (e.g. ``manuscript.tex``) are
    never scanned simply because only the locations above are visited.

    Args:
        doc_dir: Path to a document directory (e.g. ``01_manuscript``).

    Returns:
        A de-duplicated list of Paths, sorted so that every run reports files
        and occurrences in the same order.
    """
    # Patterns for generated/archived files to skip
    skip_patterns = re.compile(r"_v\d+\.tex$|_diff\.tex$")

    files = set()
    content_dir = doc_dir / "contents"
    if content_dir.exists():
        files.update(
            f for f in content_dir.glob("*.tex") if not skip_patterns.search(f.name)
        )
        for subdir in ("figures/caption_and_media", "tables/caption_and_media"):
            d = content_dir / subdir
            if d.exists():
                files.update(d.glob("*.tex"))
    # Include base.tex (structural template) but NOT generated files
    base = doc_dir / "base.tex"
    if base.exists():
        files.add(base)
    # Iterating a set of Paths follows hash order, which changes between runs
    # (str hashing is randomized per process); sort for reproducible output.
    return sorted(files)


def extract_refs(tex_files):
    """Extract every cross-reference key used in *tex_files*.

    Recognizes ``\\ref``, ``\\eqref``, ``\\autoref``, ``\\pageref``,
    ``\\nameref``, ``\\vref``, ``\\cref`` and ``\\Cref`` (each optionally
    starred). ``\\cref``/``\\Cref`` accept comma-separated key lists, which
    are split into individual keys. Macro-argument placeholders such as
    ``#1`` are ignored, as is text after an unescaped ``%`` comment marker
    (an escaped ``\\%`` is literal text, not a comment).

    Args:
        tex_files: Iterable of Paths to .tex files.

    Returns:
        dict: ref_key -> [(file, line_no), ...] with 1-based line numbers.
    """
    refs = defaultdict(list)
    pattern = re.compile(
        r"\\(ref|eqref|autoref|pageref|nameref|vref|cref|Cref)\*?\{([^}]+)\}"
    )
    # A % starts a comment unless escaped as \%; an even run of backslashes
    # before it (e.g. the line break \\) does not escape it.
    comment_re = re.compile(r"(?<!\\)(?:\\\\)*%")
    multi_key_cmds = {"cref", "Cref"}
    for f in tex_files:
        text = f.read_text(encoding="utf-8", errors="replace")
        for line_no, line in enumerate(text.splitlines(), 1):
            cm = comment_re.search(line)
            stripped = line[: cm.end() - 1] if cm else line
            for m in pattern.finditer(stripped):
                cmd, arg = m.groups()
                if cmd in multi_key_cmds:
                    keys = [k.strip() for k in arg.split(",")]
                else:
                    keys = [arg]
                for key in keys:
                    # Skip empty list items and LaTeX macro arguments like #1
                    if not key or key.startswith("#"):
                        continue
                    refs[key].append((f, line_no))
    return dict(refs)


def extract_labels(tex_files):
    """Extract all \\label{...} definitions from *tex_files*.

    Text after an unescaped ``%`` comment marker is ignored; an escaped
    ``\\%`` is literal text and does not hide labels that follow it.

    Args:
        tex_files: Iterable of Paths to .tex files.

    Returns:
        dict: label_key -> [(file, line_no), ...] with 1-based line numbers.
        A key with more than one location is defined multiple times.
    """
    labels = defaultdict(list)
    pattern = re.compile(r"\\label\{([^}]+)\}")
    # A % starts a comment unless escaped as \%; an even run of backslashes
    # before it (e.g. the line break \\) does not escape it.
    comment_re = re.compile(r"(?<!\\)(?:\\\\)*%")
    for f in tex_files:
        text = f.read_text(encoding="utf-8", errors="replace")
        for line_no, line in enumerate(text.splitlines(), 1):
            cm = comment_re.search(line)
            stripped = line[: cm.end() - 1] if cm else line
            for m in pattern.finditer(stripped):
                labels[m.group(1)].append((f, line_no))
    return dict(labels)


def infer_auto_labels(doc_dir):
    """Predict the labels that scitex-writer preprocessing generates itself.

    Every caption file whose name starts with a digit in
    ``contents/figures/caption_and_media`` or
    ``contents/tables/caption_and_media`` yields ``\\label{fig:STEM}`` or
    ``\\label{tab:STEM}`` respectively. Panel files such as ``01a_name.tex``
    are skipped.

    Returns dict: label_key -> [(file, 0), ...]; line 0 marks a label with
    no source line of its own.
    """
    inferred = defaultdict(list)
    contents = doc_dir / "contents"
    if not contents.exists():
        return dict(inferred)

    # Panel stems look like "01a_name": digits, one letter, underscore.
    panel_stem = re.compile(r"^\d+[a-zA-Z]_")
    for prefix, folder in (("fig", "figures"), ("tab", "tables")):
        media = contents / folder / "caption_and_media"
        if not media.exists():
            continue
        for caption in media.glob("[0-9]*.tex"):
            if panel_stem.match(caption.stem):
                continue
            inferred[f"{prefix}:{caption.stem}"].append((caption, 0))
    return dict(inferred)


def extract_citations(tex_files):
    """Extract every citation key used in *tex_files*.

    Recognizes the natbib commands ``\\cite``, ``\\citep``, ``\\citet``,
    ``\\citealt``, ``\\citealp``, ``\\citeauthor``, ``\\citeyear``,
    ``\\citeyearpar``, ``\\citenum`` and their capitalized forms, plus the
    biblatex ``\\parencite``, ``\\textcite``, ``\\autocite`` and
    ``\\footcite``. Each may be starred and may carry up to two optional
    ``[...]`` arguments, as in ``\\citep[see][p.~3]{Key}``. Multi-key
    citations like ``\\citep{Key1, Key2}`` are split into individual keys.
    Text after an unescaped ``%`` comment marker is ignored.

    Args:
        tex_files: Iterable of Paths to .tex files.

    Returns:
        dict: cite_key -> [(file, line_no), ...] with 1-based line numbers.
    """
    cites = defaultdict(list)
    pattern = re.compile(
        r"\\(?:cite|citep|citet|citealt|citealp|citeauthor|citeyear|citeyearpar"
        r"|citenum|Cite|Citep|Citet|Citealt|Citealp|Citeauthor"
        r"|parencite|textcite|autocite|footcite)"
        # optional star, then up to two optional [...] arguments before {keys}
        r"\*?(?:\s*\[[^\]]*\]){0,2}\s*\{([^}]+)\}"
    )
    # A % starts a comment unless escaped as \%; an even run of backslashes
    # before it (e.g. the line break \\) does not escape it.
    comment_re = re.compile(r"(?<!\\)(?:\\\\)*%")
    for f in tex_files:
        text = f.read_text(encoding="utf-8", errors="replace")
        for line_no, line in enumerate(text.splitlines(), 1):
            cm = comment_re.search(line)
            stripped = line[: cm.end() - 1] if cm else line
            for m in pattern.finditer(stripped):
                for key in m.group(1).split(","):
                    key = key.strip()
                    if key:
                        cites[key].append((f, line_no))
    return dict(cites)


def extract_bib_keys(bib_dir):
    """Extract all entry keys from the .bib files directly inside *bib_dir*.

    Accepts both ``@type{key,`` and ``@type(key,`` delimiters and tolerates
    whitespace around the key (``@article{ Key,``). ``@comment``,
    ``@string`` and ``@preamble`` blocks are not bibliography entries and
    are ignored.

    Args:
        bib_dir: Directory searched (non-recursively) for ``*.bib`` files.

    Returns:
        dict: bib_key -> Path of the .bib file defining it. If a key appears
        in several files, the last file in sorted order wins. Empty when
        *bib_dir* does not exist.
    """
    keys = {}
    if not bib_dir.exists():
        return keys
    pattern = re.compile(r"@(\w+)\s*[{(]\s*([^,\s{}()]+)")
    non_entries = {"comment", "string", "preamble"}
    # Sorted so duplicate keys resolve to the same file on every run.
    for f in sorted(bib_dir.glob("*.bib")):
        text = f.read_text(encoding="utf-8", errors="replace")
        for m in pattern.finditer(text):
            entry_type, key = m.groups()
            if entry_type.lower() in non_entries:
                continue
            keys[key] = f
    return keys


def parse_log_warnings(log_file):
    """Collect reference/citation warnings from a LaTeX ``.log`` file.

    A line is kept (whitespace-stripped) when it mentions ``Reference`` or
    ``Citation`` together with ``undefined``, or contains
    ``multiply defined``. A missing log file yields an empty list.

    Returns list of warning strings, in file order.
    """
    if not log_file.exists():
        return []
    content = log_file.read_text(encoding="utf-8", errors="replace")
    return [
        ln.strip()
        for ln in content.splitlines()
        if (("Reference" in ln or "Citation" in ln) and "undefined" in ln)
        or "multiply defined" in ln
    ]


def check_undefined_refs(refs, labels, doc_label):
    """Fail when some \\ref{X} has no matching \\label{X}.

    Args:
        refs: Mapping ref key -> list of (Path, line_no) occurrences.
        labels: Mapping label key -> occurrences; only membership is used.
        doc_label: Human-readable scope shown in the report.
    """
    unresolved = {key: locs for key, locs in refs.items() if key not in labels}
    if unresolved:
        log_fail(f"Undefined references ({doc_label}): {len(unresolved)} broken")
        for key in sorted(unresolved):
            for path, line_no in unresolved[key]:
                log_detail(f"{path.name}:{line_no}: \\ref{{{key}}} -> ?? (no \\label)")
    else:
        log_pass(f"All references resolved ({doc_label}): {len(refs)} refs")


def check_undefined_cites(cites, bib_keys, doc_label):
    """Fail when some \\cite{X} key is absent from every .bib file.

    Args:
        cites: Mapping cite key -> list of (Path, line_no) occurrences.
        bib_keys: Mapping bib key -> .bib Path; only membership is used.
        doc_label: Human-readable scope shown in the report.
    """
    not_in_bib = {key: locs for key, locs in cites.items() if key not in bib_keys}
    if not_in_bib:
        log_fail(f"Undefined citations ({doc_label}): {len(not_in_bib)} missing from .bib")
        for key in sorted(not_in_bib):
            for path, line_no in not_in_bib[key]:
                log_detail(f"{path.name}:{line_no}: \\cite{{{key}}} -> not in bibliography")
    else:
        log_pass(f"All citations resolved ({doc_label}): {len(cites)} citations")


def check_multiply_defined(labels, doc_label):
    """Warn about any \\label{X} that is defined in more than one place.

    Args:
        labels: Mapping label key -> list of (Path, line_no) definitions.
        doc_label: Human-readable scope shown in the report.
    """
    duplicated = {key: locs for key, locs in labels.items() if len(locs) > 1}
    if not duplicated:
        log_pass(f"No multiply-defined labels ({doc_label})")
        return
    log_warn(f"Multiply-defined labels ({doc_label}): {len(duplicated)} duplicates")
    for key in sorted(duplicated):
        for path, line_no in duplicated[key]:
            log_detail(f"{path.name}:{line_no}: \\label{{{key}}}")


def check_orphan_labels(refs, labels, doc_label):
    """Warn about labels that are defined but never referenced.

    Labels whose lower-cased name starts with one of the structural
    prefixes below are exempt.

    Args:
        refs: Mapping ref key -> occurrences; only membership is used.
        labels: Mapping label key -> list of (Path, line_no) definitions.
        doc_label: Human-readable scope shown in the report.
    """
    # Structural labels, presumably referenced by the template or LaTeX
    # internals rather than by an explicit \ref.
    exempt_prefixes = (
        "star ",
        "acknowledgment",
        "author ",
        "declaration",
        "data and code",
        "figures",
        "tables",
    )
    orphans = {
        key: locs
        for key, locs in labels.items()
        if key not in refs and not key.lower().startswith(exempt_prefixes)
    }

    if not orphans:
        log_pass(f"No orphan labels ({doc_label})")
        return
    log_warn(
        f"Orphan labels ({doc_label}): {len(orphans)} defined but never referenced"
    )
    for key in sorted(orphans):
        for path, line_no in orphans[key]:
            log_detail(f"{path.name}:{line_no}: \\label{{{key}}} never referenced")


def check_orphan_bib(cites, bib_keys):
    """Summarize how many bibliography entries are actually cited.

    A fully cited bibliography is logged as a pass. Otherwise an ``[INFO]``
    line is printed, which is not counted as a warning or a failure.
    """
    cited_keys = set(cites)
    total = len(bib_keys)
    n_unused = sum(1 for key in bib_keys if key not in cited_keys)
    if n_unused == 0:
        log_pass(f"All {total} bib entries cited")
        return
    n_used = total - n_unused
    # Info only: unused bib entries are fine.
    print(
        f"  {DIM}[INFO]{NC} Bibliography: {n_used}/{total} entries cited, {n_unused} unused"
    )

def check_log_warnings(log_file, doc_label):
    """Report the reference-related warnings found in one LaTeX log file.

    Only the first 20 warnings are echoed; the rest are summarized by count.
    """
    shown_limit = 20
    found = parse_log_warnings(log_file)
    if not found:
        log_pass(f"No LaTeX warnings ({doc_label})")
        return
    log_warn(f"LaTeX warnings ({doc_label}): {len(found)}")
    for warning in found[:shown_limit]:
        log_detail(warning)
    remaining = len(found) - shown_limit
    if remaining > 0:
        log_detail(f"... and {remaining} more")


def main(argv=None):
    """Run every reference/citation check and print a colored summary.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Exit code: 1 if any check failed or no document directory was found,
        otherwise 0 (warnings alone do not fail the run).
    """
    global PASS_COUNT, WARN_COUNT, FAIL_COUNT
    # Start from zero so a second call in the same process (e.g. from a test
    # harness) does not inherit tallies, and exit codes, from an earlier run.
    PASS_COUNT = WARN_COUNT = FAIL_COUNT = 0

    parser = argparse.ArgumentParser(
        description="Check cross-references, citations, and labels in LaTeX manuscripts"
    )
    parser.add_argument(
        "project_dir",
        nargs="?",
        default=".",
        help="Project root directory (default: current directory)",
    )
    parser.add_argument(
        "--doc-type",
        choices=["manuscript", "supplementary", "all"],
        default="all",
        help="Which document type to check (default: all)",
    )
    parser.add_argument(
        "--log",
        action="store_true",
        help="Also parse LaTeX .log files for warnings",
    )
    args = parser.parse_args(argv)

    project_dir = Path(args.project_dir).resolve()
    bib_dir = project_dir / "00_shared" / "bib_files"

    # Collect document directories
    doc_dirs = []
    if args.doc_type in ("manuscript", "all"):
        d = project_dir / "01_manuscript"
        if d.exists():
            doc_dirs.append(("manuscript", d))
    if args.doc_type in ("supplementary", "all"):
        d = project_dir / "02_supplementary"
        if d.exists():
            doc_dirs.append(("supplementary", d))

    if not doc_dirs:
        print(f"{RED}No document directories found in {project_dir}{NC}")
        return 1

    print(f"\n{BOLD}=== Reference Check ==={NC}\n")

    # Global bib keys
    bib_keys = extract_bib_keys(bib_dir)

    # Aggregate refs/labels/cites across all doc types so cross-document
    # references resolve.
    all_refs = {}
    all_labels = {}
    all_cites = {}

    for _, doc_dir in doc_dirs:
        tex_files = collect_tex_files(doc_dir)
        refs = extract_refs(tex_files)
        labels = extract_labels(tex_files)
        auto_labels = infer_auto_labels(doc_dir)
        cites = extract_citations(tex_files)

        for k, v in refs.items():
            all_refs.setdefault(k, []).extend(v)
        for k, v in labels.items():
            all_labels.setdefault(k, []).extend(v)
        # Only add auto-labels if not already explicitly defined
        for k, v in auto_labels.items():
            if k not in all_labels:
                all_labels.setdefault(k, []).extend(v)
        for k, v in cites.items():
            all_cites.setdefault(k, []).extend(v)

    # Run checks on aggregated data
    check_undefined_refs(all_refs, all_labels, "all documents")
    check_undefined_cites(all_cites, bib_keys, "all documents")
    check_multiply_defined(all_labels, "all documents")
    check_orphan_labels(all_refs, all_labels, "all documents")
    check_orphan_bib(all_cites, bib_keys)

    # Optionally check LaTeX logs
    if args.log:
        print()
        for doc_label, doc_dir in doc_dirs:
            log_dir = doc_dir / "logs"
            if not log_dir.exists():
                continue
            # Sorted so log reports appear in the same order on every run.
            for log_file in sorted(log_dir.glob("*.log")):
                if log_file.name.startswith(("manuscript", "supplementary")):
                    check_log_warnings(log_file, f"{doc_label}/{log_file.name}")

    # Summary
    print()
    print(
        f"{BOLD}Summary:{NC} "
        f"{GREEN}{PASS_COUNT} passed{NC}, "
        f"{YELLOW}{WARN_COUNT} warnings{NC}, "
        f"{RED}{FAIL_COUNT} errors{NC}"
    )

    if FAIL_COUNT > 0:
        print(f"\n{RED}Broken references will show as ?? in the compiled PDF.{NC}")
        return 1
    elif WARN_COUNT > 0:
        print(f"\n{YELLOW}Warnings may indicate issues worth reviewing.{NC}")
        return 0
    else:
        print(f"\n{GREEN}All references and citations are valid.{NC}")
        return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)