#!/usr/bin/env python3
##############################################################################
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

##############################################################################

"""
verify_against_config_template.py

Validate per-architecture panel YAMLs against a shared config template.
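- Validates structure and ordering only: panel ids, titles, Panel Config keys,
  and data source type/id/title are compared; metric definitions are not.
- Treats any deviation as an error.
- Collects all errors and reports them at the end.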

Template format (generated by parse_config_template.py):
  latest_arch: gfx###   (optional)
  panels:
    - file: <filename without numeric prefix>
      panel_id: <normalized panel id>
      panel_title: <title>
      panel_alias: <optional>
      data_sources:
        - type: metric_table|raw_csv_table|...
          id: <normalized table id>
          title: <title>

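Usage:
  python verify_against_config_template.py <analysis_configs_dir> <template_yaml>
      [--allow-panel-key KEY]   (repeatable; permits an extra 'Panel Config' key)

Exits with status 1 if any check fails.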
"""

from __future__ import annotations

import argparse
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config_management import utils_ruamel as cm_utils  # noqa: E402

# Keys every "Panel Config" mapping must define; missing ones are reported.
REQUIRED_PANEL_KEYS = ("id", "title", "data source", "metrics_description")
# Keys that may appear under "Panel Config" but are not mandatory.
OPTIONAL_PANEL_KEYS = ("alias",)
# Every key accepted by default; main() unions this with --allow-panel-key values.
DEFAULT_ALLOWED_PANEL_KEYS = {*REQUIRED_PANEL_KEYS, *OPTIONAL_PANEL_KEYS}


def normalize_panel_id(panel_id: int) -> int:
    """Collapse a raw panel id to its short form.

    Ids of 100 or more are floor-divided by 100 (e.g. 1000 -> 10, 100 -> 1);
    smaller ids are returned unchanged.
    """
    if panel_id < 100:
        return panel_id
    return panel_id // 100


def normalize_table_id(table_id: int) -> int:
    """Return *table_id* modulo 100, i.e. the table's index within its panel.

    For example 1001 -> 1 and 203 -> 3.
    """
    _, index_in_panel = divmod(table_id, 100)
    return index_in_panel


@dataclass(frozen=True)
class TemplateDataSource:
    """One expected data source (table) of a template panel.

    Compared field-by-field against the data sources extracted from each
    panel YAML.
    """

    type: str  # data-source kind from the template, e.g. metric_table
    id: int  # table id after normalize_table_id
    title: str  # table title; compared verbatim


@dataclass(frozen=True)
class TemplatePanel:
    """One panel entry parsed from the config template."""

    file: str  # template 'file' value ("" if absent); not used by the checks here
    panel_id: int  # panel id after normalize_panel_id
    panel_title: str  # expected 'Panel Config' title
    panel_alias: Any  # raw 'panel_alias' value (may be None); not used by the checks here
    data_sources: tuple[TemplateDataSource, ...]  # expected data sources, in order


def _as_str(v: Any) -> str:
    """Stringify *v*, mapping None to the empty string."""
    if v is None:
        return ""
    return str(v)


def load_template(
    template_file: Path,
) -> tuple[list[TemplatePanel], dict[int, TemplatePanel]]:
    """Parse the shared config template.

    Panel ids and data-source ids are normalized with ``normalize_panel_id`` /
    ``normalize_table_id``, the same way ``extract_panel_info`` normalizes
    panel YAMLs, so the two can be compared directly.

    Returns:
        ``(panels, by_id)``: panels in template order, and the same panels
        keyed by normalized panel id.

    Raises:
        ValueError: if the document is not a mapping with a ``panels`` list,
            any panel or data source is malformed, or a normalized panel id
            repeats.
    """
    data = cm_utils.load_yaml(template_file)
    # A missing key, an empty file, or a non-mapping document all mean there
    # is no usable panel list; fail loudly rather than validate against nothing.
    panels_raw = data.get("panels") if isinstance(data, dict) else None
    if not isinstance(panels_raw, list):
        raise ValueError("Template YAML must contain a top-level 'panels' list")

    panels: list[TemplatePanel] = []
    by_id: dict[int, TemplatePanel] = {}

    for idx, p in enumerate(panels_raw):
        where = f"Template panels[{idx}]"
        if not isinstance(p, dict):
            raise ValueError(f"{where} must be a mapping")
        if "panel_id" not in p or "panel_title" not in p:
            raise ValueError(f"{where} missing 'panel_id' or 'panel_title'")

        pid_raw = p["panel_id"]
        # bool is a subclass of int; YAML true/false is not a valid id.
        if isinstance(pid_raw, bool) or not isinstance(pid_raw, int):
            raise ValueError(f"{where}.panel_id must be int, got {pid_raw!r}")
        pid = normalize_panel_id(pid_raw)

        ds_list = p.get("data_sources", []) or []
        if not isinstance(ds_list, list):
            raise ValueError(f"{where}.data_sources must be list")

        ds_out: list[TemplateDataSource] = []
        for j, ds in enumerate(ds_list):
            ds_where = f"{where}.data_sources[{j}]"
            if not isinstance(ds, dict):
                raise ValueError(f"{ds_where} must be mapping")
            for k in ("type", "id", "title"):
                if k not in ds:
                    raise ValueError(f"{ds_where} missing '{k}'")

            ds_id = ds["id"]
            if isinstance(ds_id, bool) or not isinstance(ds_id, int):
                raise ValueError(f"{ds_where}.id must be int, got {ds_id!r}")

            ds_out.append(
                TemplateDataSource(
                    type=_as_str(ds["type"]),
                    id=normalize_table_id(ds_id),
                    title=_as_str(ds["title"]),
                )
            )

        if pid in by_id:
            raise ValueError(f"Duplicate panel_id {pid} in template")

        panel = TemplatePanel(
            file=_as_str(p.get("file", "")),
            panel_id=pid,
            panel_title=_as_str(p["panel_title"]),
            panel_alias=p.get("panel_alias"),
            data_sources=tuple(ds_out),
        )
        panels.append(panel)
        by_id[pid] = panel

    return panels, by_id


def extract_panel_info(
    yaml_file: Path,
) -> tuple[Optional[int], dict[str, Any], list[dict[str, Any]]]:
    """Return (panel_id, panel_config, extracted_data_sources) for one panel YAML.

    - panel_id: ``Panel Config.id`` passed through ``normalize_panel_id``, or
      None when it is missing or not an integer.
    - panel_config: the raw ``Panel Config`` mapping, or ``{}`` when the file
      has none (including when the document's top level is not a mapping).
    - extracted_data_sources: one ``{"type", "id", "title"}`` dict per entry
      under ``data source`` whose value is a mapping with an integer ``id``
      and a ``title``; the id is passed through ``normalize_table_id``.
      Entries of any other shape are skipped.
    """
    data = cm_utils.load_yaml(yaml_file)
    # A YAML whose top level is a list or scalar has no Panel Config; treat it
    # like any other file lacking one so it is reported, not a crash.
    if not isinstance(data, dict):
        return None, {}, []

    panel_config = data.get("Panel Config")
    if not isinstance(panel_config, dict):
        return None, {}, []

    pid_raw = panel_config.get("id")
    # bool is a subclass of int; YAML true/false is not a valid id.
    pid_is_int = isinstance(pid_raw, int) and not isinstance(pid_raw, bool)
    pid = normalize_panel_id(pid_raw) if pid_is_int else None

    ds_extracted: list[dict[str, Any]] = []
    ds_list = panel_config.get("data source", [])
    if not isinstance(ds_list, list):
        return pid, panel_config, ds_extracted

    for item in ds_list:
        if not isinstance(item, dict):
            continue
        for ds_type, value in item.items():
            if not isinstance(value, dict) or "title" not in value:
                continue
            ds_id = value.get("id")
            if isinstance(ds_id, bool) or not isinstance(ds_id, int):
                continue
            ds_extracted.append({
                "type": str(ds_type),
                "id": normalize_table_id(ds_id),
                "title": _as_str(value.get("title")),
            })

    return pid, panel_config, ds_extracted


def _check_panel_keys(
    rel: str, panel_config: dict[str, Any], allowed_panel_keys: set[str]
) -> list[str]:
    """Report missing required keys and unknown keys in one Panel Config."""
    errors: list[str] = []
    missing = [k for k in REQUIRED_PANEL_KEYS if k not in panel_config]
    if missing:
        errors.append(
            f"ERROR [{rel}]: Missing required Panel Config keys: "
            f"{', '.join(missing)}"
        )
    for k in panel_config:
        if k not in allowed_panel_keys:
            errors.append(
                f"ERROR [{rel}]: Prohibited/unknown Panel Config key '{k}' "
                f"(allowed: {sorted(allowed_panel_keys)})"
            )
    return errors


def _check_against_template(
    rel: str,
    pid: int,
    panel_config: dict[str, Any],
    ds_actual: list[dict[str, Any]],
    expected: TemplatePanel,
) -> list[str]:
    """Compare one panel's title and data sources with its template entry."""
    errors: list[str] = []

    actual_title = _as_str(panel_config.get("title"))
    if actual_title != expected.panel_title:
        errors.append(
            f"ERROR [{rel}]: Panel title mismatch for id {pid}: "
            f"expected '{expected.panel_title}', got '{actual_title}'"
        )

    # Data sources must match count and order strictly.
    if len(ds_actual) != len(expected.data_sources):
        errors.append(
            f"ERROR [{rel}]: Data source count mismatch for panel "
            f"{pid}: expected {len(expected.data_sources)}, "
            f"got {len(ds_actual)}"
        )

    # zip stops at the shorter list; surplus or missing entries are already
    # covered by the count check above.
    for i, (exp_ds, act) in enumerate(zip(expected.data_sources, ds_actual)):
        if (
            act["type"] != exp_ds.type
            or act["id"] != exp_ds.id
            or act["title"] != exp_ds.title
        ):
            errors.append(
                f"ERROR [{rel}]: Data source #{i + 1} mismatch "
                f"for panel {pid}: expected {exp_ds.type} id={exp_ds.id} "
                f"title='{exp_ds.title}', got {act['type']} "
                f"id={act['id']} title='{act['title']}'"
            )
    return errors


def _check_panel_order(
    arch_name: str, actual_order: list[int], expected_ids: list[int]
) -> list[str]:
    """Report the first position where on-disk panel order deviates from the template.

    Only panels present in both the template and the directory are compared,
    and positions count only those panels. Extra panels are reported
    separately, so they must neither shift positions (a spurious mismatch)
    nor hide a genuine reordering.
    """
    expected_set = set(expected_ids)
    actual_set = set(actual_order)
    actual_known = [pid for pid in actual_order if pid in expected_set]
    expected_present = [pid for pid in expected_ids if pid in actual_set]
    for i, (a, e) in enumerate(zip(actual_known, expected_present)):
        if a != e:
            return [
                f"ERROR [{arch_name}]: Panel file order mismatch at position "
                f"{i + 1}: expected panel id {e}, got {a} "
                "(files must follow template order)"
            ]
    return []


def validate_arch(
    arch_dir: Path,
    template_panels: list[TemplatePanel],
    template_by_id: dict[int, TemplatePanel],
    allowed_panel_keys: set[str],
) -> list[str]:
    """Validate one architecture directory. Returns list of errors.

    Each ``*.yaml`` file directly in *arch_dir* (visited in sorted order) is
    checked for the following:

    - an integer ``Panel Config.id``;
    - required and allowed ``Panel Config`` keys;
    - title and data-source (type, id, title) agreement with the template
      panel of the same id;
    - duplicate ids.

    The directory as a whole is then checked for the following:

    - template panels that are missing;
    - panels that are not in the template;
    - file order that does not follow template order.

    An empty result means the directory passed.
    """
    errors: list[str] = []
    actual_by_id: dict[int, Path] = {}
    actual_order: list[int] = []

    for f in sorted(arch_dir.glob("*.yaml")):
        pid, panel_config, ds_actual = extract_panel_info(f)
        rel = f"{arch_dir.name}/{f.name}"

        if pid is None:
            errors.append(f"ERROR [{rel}]: Missing or non-integer Panel Config.id")
            continue

        errors.extend(_check_panel_keys(rel, panel_config, allowed_panel_keys))

        expected = template_by_id.get(pid)
        if expected is None:
            errors.append(f"ERROR [{rel}]: Panel ID {pid} not found in template")
        else:
            errors.extend(
                _check_against_template(rel, pid, panel_config, ds_actual, expected)
            )

        if pid in actual_by_id:
            errors.append(
                f"ERROR [{rel}]: Duplicate panel id {pid} "
                f"(also in {arch_dir.name}/{actual_by_id[pid].name})"
            )
        else:
            actual_by_id[pid] = f
            actual_order.append(pid)

    expected_ids = [p.panel_id for p in template_panels]
    expected_set = set(expected_ids)

    # Panels required by the template but absent from this directory.
    for pid in expected_ids:
        if pid not in actual_by_id:
            errors.append(
                f"ERROR [{arch_dir.name}]: Missing panel id {pid} required by template"
            )

    # Panels present here but unknown to the template.
    for pid in sorted(actual_by_id.keys() - expected_set):
        errors.append(
            f"ERROR [{arch_dir.name}/{actual_by_id[pid].name}]: "
            f"Extra panel id {pid} not present in template"
        )

    errors.extend(_check_panel_order(arch_dir.name, actual_order, expected_ids))
    return errors


def main() -> None:
    """Command-line entry point: validate every architecture subdirectory.

    Prints per-architecture results and a summary. Exits with status 1 when
    the inputs are invalid, the template is malformed, or any validation
    error is found; otherwise returns normally.
    """
    parser = argparse.ArgumentParser(
        description="Validate per-arch panel YAMLs against a shared config template."
    )
    parser.add_argument(
        "analysis_configs_dir", help="Directory containing architecture subdirs"
    )
    parser.add_argument("template_yaml", help="Template YAML (config_template.yaml)")
    parser.add_argument(
        "--allow-panel-key",
        action="append",
        default=[],
        help="Allow an additional key under 'Panel Config' (repeatable)",
    )
    args = parser.parse_args()

    configs_dir = Path(args.analysis_configs_dir)
    template_file = Path(args.template_yaml)

    if not configs_dir.is_dir():
        print(f"Error: {configs_dir} is not a directory")
        sys.exit(1)
    if not template_file.is_file():
        print(f"Error: {template_file} is not a file")
        sys.exit(1)

    # Announce before loading so a failure is attributed to the right file.
    print(f"Loading template from {template_file}")
    try:
        template_panels, template_by_id = load_template(template_file)
    except ValueError as exc:
        # load_template signals malformed templates with ValueError; report it
        # as a clean error rather than a traceback.
        print(f"Error: invalid template {template_file}: {exc}")
        sys.exit(1)
    print(f"Template loaded: {len(template_panels)} panels\n")

    allowed_panel_keys = set(DEFAULT_ALLOWED_PANEL_KEYS) | set(args.allow_panel_key)

    all_errors: list[str] = []
    total_arches = 0

    for arch_dir in sorted(configs_dir.iterdir()):
        if not arch_dir.is_dir():
            continue
        total_arches += 1
        print(f"{'=' * 80}\nValidating architecture: {arch_dir.name}\n{'=' * 80}")
        arch_errors = validate_arch(
            arch_dir=arch_dir,
            template_panels=template_panels,
            template_by_id=template_by_id,
            allowed_panel_keys=allowed_panel_keys,
        )
        if arch_errors:
            for e in arch_errors:
                print(e)
            all_errors.extend(arch_errors)
        else:
            print(f"PASS [{arch_dir.name}]: All panel YAMLs match template")
        print()

    if total_arches == 0:
        # Nothing was checked; surface this so a wrong path is not mistaken
        # for a clean pass.
        print(f"Warning: no architecture subdirectories found in {configs_dir}")

    print(f"{'=' * 80}\nVALIDATION SUMMARY\n{'=' * 80}")
    print(f"Architectures checked: {total_arches}")
    print(f"Total errors: {len(all_errors)}")

    if all_errors:
        print("\nValidation FAILED")
        sys.exit(1)
    else:
        print("\nValidation PASSED")


if __name__ == "__main__":
    main()
