#!/usr/bin/env python3
##############################################################################
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

##############################################################################

"""
Metric description manager.
Syncs metric descriptions between config YAMLs and documentation files.

Usage:
    python metric_description_manager.py --sync-arch <arch_name> <configs_dir>
    python metric_description_manager.py --sync-all <configs_dir>
    python metric_description_manager.py --validate <arch_name> <configs_dir>
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Union

import yaml

# PROJECT_ROOT is the directory one level above the folder containing this
# file. It is pushed to the front of sys.path (once) before the import below,
# presumably so `config_management` resolves when this script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Must come after the sys.path tweak above, hence the E402 suppression.
from config_management import utils_ruamel as cm_utils  # noqa: E402

# Section to panel ID mapping for organizing descriptions.
# Keys are section titles; values are the numeric ids that
# extract_descriptions_from_arch() compares against the `id` field of a
# panel's "data source" tables (via the inverted map defined below).
SECTION_PANEL_MAP: dict[str, int] = {
    "Wavefront launch stats": 701,
    "Wavefront runtime stats": 702,
    "Overall instruction mix": 1001,
    "VALU arithmetic instruction mix": 1002,
    "MFMA instruction mix": 1004,
    "Compute Speed-of-Light": 1101,
    "Pipeline statistics": 1102,
    "Arithmetic operations": 1103,
    "LDS Speed-of-Light": 1201,
    "LDS Statistics": 1202,
    "vL1D Speed-of-Light": 1601,
    "Busy / stall metrics": 1501,
    "Instruction counts": 1502,
    "Spill / stack metrics": 1503,
    "L1 Unified Translation Cache (UTCL1)": 1605,
    "vL1D cache stall metrics": 1602,
    "vL1D cache access metrics": 1603,
    "Vector L1 data-return path or Texture Data (TD)": 1504,
    "L2 Speed-of-Light": 1701,
    "L2 cache accesses": 1703,
    "L2-Fabric interface metrics": 1702,
    "L2 - Fabric interface detailed metrics": 1706,
    "L2 - Fabric Interface stalls": 1705,
    "Scalar L1D Speed-of-Light": 1401,
    "Scalar L1D cache accesses": 1402,
    "Scalar L1D Cache - L2 Interface": 1403,
    "L1I Speed-of-Light": 1301,
    "L1I cache accesses": 1302,
    "L1I <-> L2 interface": 1303,
    "Workgroup manager utilizations": 601,
    "Workgroup Manager - Resource Allocation": 602,
    "Command processor fetcher (CPF)": 501,
    "Command processor packet processor (CPC)": 502,
    "System Speed-of-Light": 201,
}

# Reverse lookup (id -> section title). The ids above must stay unique:
# with a duplicate id, the entry listed last would silently win here.
PANEL_ID_TO_SECTION: dict[int, str] = {v: k for k, v in SECTION_PANEL_MAP.items()}


def merge_docs_rst_as_default(descs: dict, docs_file: Path) -> dict:
    """
    Fill missing 'rst' text from the docs metrics-description YAML.

    For each metric that does NOT explicitly carry a non-empty 'rst' from the
    panel YAMLs, copy 'rst' from ``docs_file`` when the docs have one. Metrics
    whose panel entry already provides 'rst' are left untouched, so the docs
    act purely as the default RST source.

    Args:
        descs: ``{section: {metric: {"plain": ..., "rst": ..., ...}}}``;
            updated in place.
        docs_file: Path to the docs YAML, laid out as
            ``{section: {metric: {"rst": ...}}}``. A missing file is treated
            as empty.

    Returns:
        The same ``descs`` object, for call chaining.
    """
    docs_path = Path(docs_file)
    docs: dict = {}
    if docs_path.exists():
        with open(docs_path, "r", encoding="utf-8") as f:
            docs = yaml.safe_load(f) or {}

    for section, metrics in descs.items():
        docs_section = docs.get(section) or {}
        for metric_name, d in metrics.items():
            if d.get("rst"):
                # The panel supplied its own RST: that is an explicit
                # override and must not be replaced by the docs default.
                continue
            doc_entry = docs_section.get(metric_name) or {}
            if doc_entry.get("rst"):
                d["rst"] = doc_entry["rst"]
    return descs


def merge_units_as_default(descs: dict, docs_file: Path, per_arch_file: Path) -> dict:
    """
    Fill 'unit' ONLY when it is missing from the panel extraction.

    Lookup order for a metric without a unit:
      1) the existing per-arch file, if it has a unit for that metric,
      2) else the docs file,
      3) else the metric is left without a unit.

    A unit already present on the metric (i.e. not ``None``) is never
    replaced, and ``None`` units found in either file are ignored.

    Args:
        descs: ``{section: {metric: {...}}}``; updated in place.
        docs_file: Path to the docs metrics-description YAML.
        per_arch_file: Path to this architecture's previously generated
            metrics-description YAML.

    Returns:
        The same ``descs`` object, for call chaining.
    """

    def _load_mapping(path: Path) -> dict:
        # A missing file, an empty document or a non-mapping top level all
        # count as "no data" rather than an error.
        path = Path(path)
        if not path.exists():
            return {}
        with open(path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        return loaded if isinstance(loaded, dict) else {}

    per_arch = _load_mapping(per_arch_file)
    docs = _load_mapping(docs_file)

    for section, metrics in descs.items():
        # Ordered by priority: per-arch first, docs second.
        fallbacks = (per_arch.get(section) or {}, docs.get(section) or {})
        for metric, data in metrics.items():
            if data.get("unit") is not None:
                continue
            for source in fallbacks:
                entry = source.get(metric)
                if isinstance(entry, dict) and entry.get("unit") is not None:
                    data["unit"] = entry["unit"]
                    break
    return descs


def panel_rst_override_keys(descs: dict) -> set:
    """
    Collect the metrics whose entry carries a non-empty 'rst' value.

    Returns a set of ``(section, metric)`` tuples; entries with a missing or
    falsy 'rst' are not included.
    """
    return {
        (section, metric_name)
        for section, metrics in descs.items()
        for metric_name, entry in metrics.items()
        if "rst" in entry and entry["rst"]
    }


def panel_unit_override_keys(descs: dict) -> set[tuple[str, str]]:
    """
    Collect the metrics whose entry has a 'unit' key that is not ``None``.

    Returns a set of ``(section, metric)`` tuples. Falsy but non-``None``
    units (for example an empty string) still count.
    """
    return {
        (section, metric)
        for section, metrics in descs.items()
        for metric, entry in metrics.items()
        if "unit" in entry and entry["unit"] is not None
    }


def validate_rst_syntax(text: str) -> tuple[bool, str]:
    """
    Run basic, count-based sanity checks on a snippet of RST.

    This is a heuristic, not a parser. It reports:
      * an odd total number of backticks (some inline markup is unbalanced);
      * ``:ref:`` / ``:doc:`` roles for which too few backticks follow to
        close every occurrence of that role.

    Args:
        text: RST text to check. Empty or ``None`` is considered valid.

    Returns:
        ``(ok, message)``. ``message`` is an empty string when ``ok`` is
        true, otherwise the detected problems joined by ``"; "``.
    """
    if not text:
        return True, ""

    errors: list[str] = []

    # Double-backtick literals always contribute an even number of backticks,
    # so checking the parity of the total once is sufficient.
    if text.count("`") % 2 != 0:
        errors.append("Unmatched single backticks")

    for role in ("ref", "doc"):
        opener = f":{role}:`"
        first = text.find(opener)
        if first == -1:
            continue
        openers = text.count(opener)
        # Each opener itself contains one backtick; only the remaining
        # backticks from the first opener onwards can act as closers.
        closers = text.count("`", first) - openers
        if closers < openers:
            errors.append(f"Unclosed :{role}: directive")

    return not errors, "; ".join(errors)


def extract_descriptions_from_arch(
    arch_dir: Union[str, Path],
) -> dict[str, dict[str, dict]]:
    """
    Extract metric descriptions from all config YAMLs in an arch directory.

    Every ``*.yaml`` file directly inside ``arch_dir`` is read in sorted
    order. Files without a mapping under "Panel Config" are skipped. Within a
    file, a described metric is filed under the section whose table (matched
    by ``id`` through ``PANEL_ID_TO_SECTION``) lists that metric with a
    non-empty unit; every other described metric goes to "General".

    Args:
        arch_dir: Directory holding the panel config YAMLs of one arch.

    Returns:
        ``{section: {metric: {"plain": ..., "rst": ..., ["unit": ...]}}}``.
        'unit' is present only when the description itself provides one.
        If the same section/metric appears in several files, the entry from
        the file read last wins.
    """
    arch_path = Path(arch_dir)
    descriptions_by_section: dict[str, dict[str, dict]] = {}

    for yaml_file in sorted(arch_path.glob("*.yaml")):
        # Mirror validate_descriptions(): an empty document loads as a falsy
        # value and is treated as "no panel config" instead of crashing.
        data = cm_utils.load_yaml(yaml_file) or {}

        panel_config = data.get("Panel Config")
        if not isinstance(panel_config, dict):
            continue

        # `or` also covers keys that are present but explicitly null.
        panel_descriptions: dict = panel_config.get("metrics_description") or {}

        # metric name -> section, for metrics listed with a unit in a table
        # whose id is known to PANEL_ID_TO_SECTION.
        metric_sections: dict[str, str] = {}
        for ds in panel_config.get("data source") or []:
            if not isinstance(ds, dict):
                continue
            for value in ds.values():
                if not (isinstance(value, dict) and "metric" in value):
                    continue
                section_name = PANEL_ID_TO_SECTION.get(value.get("id"))
                if not section_name:
                    continue
                for metric_name, metric_data in (value["metric"] or {}).items():
                    if isinstance(metric_data, dict) and metric_data.get("unit"):
                        metric_sections[metric_name] = section_name

        for metric_name, description in panel_descriptions.items():
            section_name = metric_sections.get(metric_name, "General")

            if isinstance(description, dict):
                desc_data = {
                    "plain": description.get("plain", ""),
                    "rst": description.get("rst", ""),
                }
                unit = description.get("unit")
                if unit is not None:
                    desc_data["unit"] = unit
            else:
                # A bare value is used as the plain-text description.
                desc_data = {"plain": description, "rst": ""}

            descriptions_by_section.setdefault(section_name, {})[metric_name] = (
                desc_data
            )

    return descriptions_by_section


def update_per_arch_metrics_file(
    arch_name: str, descriptions: dict, output_dir: Union[str, Path]
) -> None:
    """
    Write ``<arch_name>_metrics_description.yaml`` into ``output_dir``.

    Each metric is reduced to its 'rst' text plus its 'unit' when the
    description has one. The output directory is created if needed, and the
    written path is printed.
    """

    def _as_entry(desc_data: dict) -> dict:
        # Keep only the fields that belong in the per-arch file.
        entry = {"rst": desc_data["rst"]}
        if "unit" in desc_data:
            entry["unit"] = desc_data["unit"]
        return entry

    target = Path(output_dir) / f"{arch_name}_metrics_description.yaml"
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        section: {name: _as_entry(desc) for name, desc in metrics.items()}
        for section, metrics in descriptions.items()
    }

    cm_utils.save_yaml(payload, target)
    print(f"Updated: {target}")


def update_docs_metrics_file(
    descriptions: dict,
    docs_file: str,
    panel_rst_overrides: set,
    panel_unit_overrides: set,
) -> bool:
    """
    Merge panel-provided overrides into the docs metrics-description YAML.

    The current docs content (if the file exists) is loaded, an entry is
    ensured for every section/metric in ``descriptions``, and then:
      * 'rst' is replaced only for keys in ``panel_rst_overrides`` whose
        description has a non-empty 'rst';
      * 'unit' is replaced only for keys in ``panel_unit_overrides`` whose
        description has a 'unit' key.
    The result is written back to ``docs_file``, creating its parent
    directory if needed.

    Returns:
        Always ``True``.
    """
    target = Path(docs_file)
    merged: dict = {}
    if target.exists():
        with open(target, "r", encoding="utf-8") as handle:
            merged = yaml.safe_load(handle) or {}

    for section, metrics in descriptions.items():
        section_entries = merged.setdefault(section, {})
        for metric_name, desc_data in metrics.items():
            entry = section_entries.setdefault(metric_name, {})
            key = (section, metric_name)
            # RST is only taken from the panel when it was an explicit override.
            if key in panel_rst_overrides and desc_data.get("rst"):
                entry["rst"] = desc_data["rst"]
            # Likewise, units are only taken from explicit panel overrides.
            if key in panel_unit_overrides and "unit" in desc_data:
                entry["unit"] = desc_data["unit"]

    target.parent.mkdir(parents=True, exist_ok=True)

    cm_utils.save_yaml(merged, target)
    return True


def validate_descriptions(
    arch_dir: Union[str, Path],
) -> tuple[bool, list[str], list[str]]:
    """
    Validate the panel YAMLs of one arch directory.

    For every ``*.yaml`` file with a mapping under "Panel Config":
      * metrics listed in "data source" tables but absent from
        "metrics_description" produce one warning per file;
      * each description's RST (the 'rst' field of a mapping description, or
        the description itself otherwise) is checked with
        ``validate_rst_syntax`` and produces an error when invalid.

    Args:
        arch_dir: Directory holding the panel config YAMLs of one arch.

    Returns:
        ``(is_valid, warnings, errors)``; ``is_valid`` is true when there are
        no errors (warnings do not affect it).
    """
    arch_path = Path(arch_dir)
    warnings: list[str] = []
    errors: list[str] = []

    for yaml_file in sorted(arch_path.glob("*.yaml")):
        # Explicit UTF-8, like every other file read in this module, so the
        # result does not depend on the platform's default encoding.
        with open(yaml_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            continue

        panel_config = data.get("Panel Config")
        if not isinstance(panel_config, dict):
            continue

        # `or` also covers keys that are present but explicitly null.
        panel_descriptions: dict = panel_config.get("metrics_description") or {}
        all_metrics: set[str] = set()

        for ds in panel_config.get("data source") or []:
            if not isinstance(ds, dict):
                continue
            for value in ds.values():
                if isinstance(value, dict) and "metric" in value:
                    all_metrics.update(value["metric"] or ())

        missing = sorted(all_metrics - set(panel_descriptions))
        if missing:
            warnings.append(
                f"{yaml_file.name}: Missing descriptions "
                f"for metrics: {', '.join(missing)}"
            )

        for metric_name, description in panel_descriptions.items():
            rst_text = (
                description.get("rst", "")
                if isinstance(description, dict)
                else description
            )
            ok, err = validate_rst_syntax(rst_text)
            if not ok:
                errors.append(
                    f"{yaml_file.name}: Metric '{metric_name}' has invalid RST: {err}"
                )

    return len(errors) == 0, warnings, errors


def sync_arch(
    arch_name: str,
    configs_dir: str,
    per_arch_metrics_dir: str,
    docs_metrics_file: str,
    is_latest: bool,
) -> bool:
    """
    Sync descriptions for a single architecture.

    Args:
        arch_name: Name of the arch sub-directory inside ``configs_dir``.
        configs_dir: Directory containing one sub-directory per arch.
        per_arch_metrics_dir: Output directory for the per-arch YAML.
        docs_metrics_file: Path to the docs metrics-description YAML.
        is_latest: When true, explicit panel overrides are also written back
            to ``docs_metrics_file``.

    Returns:
        ``False`` if the arch directory does not exist or the docs update
        reports failure; ``True`` otherwise (including when the arch has no
        descriptions at all).
    """
    arch_dir = Path(configs_dir) / arch_name
    docs_file = Path(docs_metrics_file)
    per_arch_file = Path(per_arch_metrics_dir) / f"{arch_name}_metrics_description.yaml"

    if not arch_dir.is_dir():
        print(f"Error: {arch_dir} is not a directory")
        return False

    print(f"Syncing descriptions for {arch_name}...")

    # Validation is advisory here: findings are reported so they are not
    # silently lost, but they do not abort the sync.
    _, warnings, errors = validate_descriptions(arch_dir)
    for warning in warnings:
        print(f"   Warning: {warning}")
    for error in errors:
        print(f"   Error: {error}")

    # 1) Extract descriptions from panel YAMLs (source for 'plain', optional 'rst')
    descriptions = extract_descriptions_from_arch(arch_dir)
    if not descriptions:
        print(f"No descriptions found in {arch_name}")
        return True

    # 2) Capture which metrics had explicit panel RST/unit (BEFORE merging
    #    defaults, which would make them indistinguishable from overrides)
    panel_rst_overrides = panel_rst_override_keys(descriptions)
    panel_unit_overrides = panel_unit_override_keys(descriptions)

    # 3) Merge default RST and units from the docs / per-arch files
    descriptions = merge_docs_rst_as_default(descriptions, docs_file)
    descriptions = merge_units_as_default(descriptions, docs_file, per_arch_file)

    # 4) Write per-arch file ('rst' and optional 'unit' for every metric)
    update_per_arch_metrics_file(arch_name, descriptions, per_arch_metrics_dir)

    # 5) Only when latest: update docs, and only if there is any override
    if is_latest and (panel_rst_overrides or panel_unit_overrides):
        if not update_docs_metrics_file(
            descriptions,
            docs_metrics_file,
            panel_rst_overrides,
            panel_unit_overrides,
        ):
            return False

    return True


def main(argv: Union[list[str], None] = None) -> int:
    """
    Command-line entry point.

    Exactly one mode is acted on, checked in this order: ``--sync-arch``,
    ``--sync-all``, ``--validate``. With no mode, the help text is printed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        Process exit code: 0 on success, 1 on any failure or when no mode
        was selected.
    """
    parser = argparse.ArgumentParser(description="Manage metric descriptions")
    parser.add_argument(
        "--sync-arch",
        metavar="ARCH",
        help="Sync descriptions for specific architecture",
    )
    parser.add_argument(
        "--sync-all",
        action="store_true",
        help="Sync descriptions for all architectures",
    )
    parser.add_argument(
        "--validate",
        metavar="ARCH",
        help="Validate descriptions for specific architecture",
    )
    parser.add_argument(
        "--latest-arch", help="Specify which arch is latest (for docs update)"
    )
    parser.add_argument("configs_dir", help="Path to analysis_configs directory")
    parser.add_argument(
        "--per-arch-output",
        default="tools/per_arch_metric_definitions",
        help="Output directory for per-arch files",
    )
    parser.add_argument(
        "--docs-file",
        default="docs/data/metrics_description.yaml",
        help="Path to docs metrics description file",
    )

    args = parser.parse_args(argv)

    if args.sync_arch:
        # Without --latest-arch this comparison is False, so docs stay untouched.
        is_latest = args.latest_arch == args.sync_arch
        ok = sync_arch(
            args.sync_arch,
            args.configs_dir,
            args.per_arch_output,
            args.docs_file,
            is_latest,
        )
        return 0 if ok else 1

    if args.sync_all:
        configs_path = Path(args.configs_dir)
        # Fail with a message, as the other modes do, instead of letting
        # iterdir() raise on a missing or non-directory path.
        if not configs_path.is_dir():
            print(f"Error: {configs_path} is not a directory")
            return 1

        archs = sorted(
            d.name
            for d in configs_path.iterdir()
            if d.is_dir() and d.name.startswith("gfx")
        )
        if not archs:
            print("No architecture directories found")
            return 1

        # Default "latest" is the last directory name in sorted (string) order.
        latest_arch = args.latest_arch if args.latest_arch else archs[-1]
        for arch in archs:
            ok = sync_arch(
                arch,
                args.configs_dir,
                args.per_arch_output,
                args.docs_file,
                arch == latest_arch,
            )
            if not ok:
                return 1
        return 0

    if args.validate:
        arch_dir = Path(args.configs_dir) / args.validate
        if not arch_dir.is_dir():
            print(f"Error: {arch_dir} is not a directory")
            return 1

        is_valid, warnings, errors = validate_descriptions(arch_dir)
        print(f"Validation results for {args.validate}:\n{'=' * 80}")

        if warnings:
            print("\nWarnings:")
            for w in warnings:
                print(f"   {w}")

        if errors:
            print("\nErrors:")
            for e in errors:
                print(f"   {e}")

        if is_valid and not warnings:
            print("\nAll validations passed")

        # Warnings alone do not make the run fail.
        return 0 if is_valid else 1

    parser.print_help()
    return 1


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
