#!/usr/bin/env python3
"""policy_snapshot.py — helper for the Satsignal policy-snapshot pattern.

What was the agent allowed to do at the time it acted? An auditor
investigating an incident later — a regulator, a customer, a court —
will want to know:

  1. what system policy was in force,
  2. what user instruction the agent was responding to,
  3. which tools it could call,
  4. what budget caps were set,
  5. which model + config it was running.

A *policy snapshot* is a JSON object with a sha256 fingerprint of each
of those five components. The whole snapshot is canonicalized + sha256'd,
and that hash is anchored on chain via /api/v1/anchors with
``category: "policy_snapshot"``. The on-chain commitment binds to all
five fingerprints simultaneously: an auditor with the original system
prompt can verify the system_policy_hash matches the receipt without
ever seeing the budget config — selective disclosure of the agent's
configuration.

JSON shape (v1):

    {
      "version": "satsignal-policy-snapshot-v1",
      "snapshot_at_utc": "<RFC3339Z>",
      "agent": { "name": "<label>", "version": "<version>" },
      "system_policy_hash":      "<sha256 hex of system prompt / policy doc>",
      "user_instruction_hash":   "<sha256 hex of user instruction>",
      "tool_permissions_hash":   "<sha256 hex of canonicalize(tools)>",
      "budget_limits_hash":      "<sha256 hex of canonicalize(budgets)>",
      "model_config_hash":       "<sha256 hex of canonicalize(model_cfg)>",
      "extra":                   { ... optional forward-compat ... }
    }

Any of the five hashes can be omitted if the field doesn't apply to
your setup; an omitted key is simply left out of the snapshot, so the
on-chain commitment does not bind to that component. ``extra`` is a
free-form bag for additional fingerprints that don't fit the five
canonical slots.

Stdlib only. No Satsignal repo dependency. Copy freely.

Usage:

    # 1. Hash each component (each command prints {"sha256_hex": ...})
    python3 policy_snapshot.py hash-component --file system_prompt.txt
    python3 policy_snapshot.py hash-component --text "summarise this thread"
    python3 policy_snapshot.py hash-component --json-file tools.json
    python3 policy_snapshot.py hash-component --json-string '{"max_usd":5}'

    # 2. Build the snapshot. Pass the 5 hashes (or fewer); the script writes
    #    snapshot.json and prints the anchor (sha256_hex + file_size) to
    #    feed /api/v1/anchors.
    python3 policy_snapshot.py build \\
        --agent-name my-evaluator-bot \\
        --agent-version 1.4.2 \\
        --system-policy-hash <sha256> \\
        --user-instruction-hash <sha256> \\
        --tool-permissions-hash <sha256> \\
        --budget-limits-hash <sha256> \\
        --model-config-hash <sha256> \\
        --out snapshot.json

    # 3. Anchor (separately, against your scoped API key)
    SHA=$(jq -r .anchor.sha256_hex snapshot.json)
    SIZE=$(jq -r .anchor.file_size snapshot.json)
    curl -H "Authorization: Bearer sk_..." -H "Content-Type: application/json" \\
         -d "{\\"matter_slug\\":\\"agent-runs\\",\\"sha256_hex\\":\\"$SHA\\", \\
              \\"file_size\\":$SIZE,\\"category\\":\\"policy_snapshot\\", \\
              \\"label\\":\\"$(date -u +%FT%TZ)\\"}" \\
         https://app.satsignal.cloud/api/v1/anchors

    # 4. Verify (auditor side): given snapshot.json + the original system
    #    prompt, recompute and check.
    python3 policy_snapshot.py verify \\
        --snapshot snapshot.json \\
        --system-policy-file system_prompt.txt
"""
from __future__ import annotations

import argparse
import hashlib
import json
import sys
import time
import unicodedata
from typing import Any, Optional


# ---- JCS-style canonicalize (matches Satsignal notary.canonicalize) ----

def _nfc_deep(value: Any) -> Any:
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Real-world model configs commonly carry floats (temperature,
        # top_p, ...). The strict JCS rule rejects them; we relax so
        # ``hash_canonical`` is useful on actual configs. Encoding uses
        # Python's repr() (shortest round-trip), which matches
        # JS Number.prototype.toString() for IEEE 754 doubles. Reject
        # NaN/Infinity (no defined JSON shape).
        import math
        if not math.isfinite(value):
            raise ValueError(
                f"NaN/Infinity not allowed in canonical form: {value!r}"
            )
        return value
    if isinstance(value, str):
        return unicodedata.normalize("NFC", value)
    if isinstance(value, list):
        return [_nfc_deep(v) for v in value]
    if isinstance(value, dict):
        return {unicodedata.normalize("NFC", k): _nfc_deep(v) for k, v in value.items()}
    raise TypeError(
        f"non-canonicalizable type {type(value).__name__}: {value!r}"
    )


def canonicalize(doc: Any) -> bytes:
    """Serialize ``doc`` to JCS-style canonical UTF-8 JSON bytes.

    The document is first run through ``_nfc_deep`` (NFC-normalized
    strings, JSON types only), then dumped with sorted keys, no
    insignificant whitespace, non-ASCII characters emitted literally,
    and NaN/Infinity rejected.

    NOTE: this is described as *slightly* more lenient than
    ``notary.canonicalize`` in the Satsignal repo — finite floats are
    permitted here because model configs (the canonical Phase-8d use
    case) routinely contain them. json.dumps renders floats with
    float.__repr__ (shortest round-trip).
    NOTE(review): that formatting is Python's; it is not identical to
    JS Number.toString() for every value (e.g. ``1.0`` is emitted as
    ``1.0``), so float-bearing components should be re-hashed with this
    helper. A snapshot built without ``extra`` contains only strings,
    so its on-chain anchor hash is unaffected by this relaxation.
    """
    normalized = _nfc_deep(doc)
    text = json.dumps(
        normalized,
        ensure_ascii=False,
        allow_nan=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    return text.encode("utf-8")


# ---- component hashing ----

def hash_text(text: str) -> str:
    """Return the sha256 hex digest of ``text`` encoded as UTF-8.

    Intended for single blocks of text such as system prompts and user
    instructions. The string is hashed exactly as given — no NFC
    normalization — so an auditor holding the original gets a
    byte-exact comparison.
    """
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()


def hash_file(path: str) -> str:
    """Return the sha256 hex digest of the raw bytes of the file at ``path``.

    The file is opened in binary mode and fed to the hash in 1 MiB
    reads, so the whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: stream.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()


def hash_canonical(obj: Any) -> str:
    """Return the sha256 hex digest of ``canonicalize(obj)``.

    Meant for structured values — tool lists, budget configs, model
    config dicts — where key order and whitespace in the source
    representation should not influence the fingerprint.
    """
    canonical_bytes = canonicalize(obj)
    return hashlib.sha256(canonical_bytes).hexdigest()


# ---- snapshot construction ----

# Wire-format identifier; build_snapshot stores it under the snapshot's
# "version" key.
WIRE_VERSION = "satsignal-policy-snapshot-v1"

# Key names of the five component fingerprints a v1 snapshot may carry.
# NOTE(review): not referenced elsewhere in this file as shown —
# build_snapshot spells the same five names out inline; keep in sync.
_HASH_FIELDS = (
    "system_policy_hash",
    "user_instruction_hash",
    "tool_permissions_hash",
    "budget_limits_hash",
    "model_config_hash",
)


def build_snapshot(
    *,
    agent_name: Optional[str] = None,
    agent_version: Optional[str] = None,
    system_policy_hash: Optional[str] = None,
    user_instruction_hash: Optional[str] = None,
    tool_permissions_hash: Optional[str] = None,
    budget_limits_hash: Optional[str] = None,
    model_config_hash: Optional[str] = None,
    extra: Optional[dict] = None,
    snapshot_at_utc: Optional[str] = None,
) -> dict:
    """Build a policy snapshot together with its anchor metadata.

    Every argument is keyword-only and optional. Each ``*_hash`` that is
    supplied is stripped and lower-cased, then must be exactly 64 hex
    characters; hashes left as ``None`` are omitted from the snapshot.
    ``agent`` is included only when a name or version is given, and
    ``extra`` only when it is non-empty. ``snapshot_at_utc`` defaults to
    the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.

    Returns:
        {
          "snapshot": { ...the snapshot object... },
          "anchor": { "sha256_hex": <hex>, "file_size": <int> }
        }

        ``snapshot`` is what auditors receive (alongside whichever
        original components they want to verify). ``anchor`` holds the
        sha256 and byte length of ``canonicalize(snapshot)`` — the
        values to POST to /api/v1/anchors with category=policy_snapshot.

    Raises:
        ValueError: if a supplied hash is not 64 hex characters.
    """
    timestamp = snapshot_at_utc
    if timestamp is None:
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    snapshot: dict = {"version": WIRE_VERSION, "snapshot_at_utc": timestamp}

    # Only truthy agent attributes are recorded; an empty agent object
    # is left out altogether.
    agent_info = {
        key: val
        for key, val in (("name", agent_name), ("version", agent_version))
        if val
    }
    if agent_info:
        snapshot["agent"] = agent_info

    hex_digits = frozenset("0123456789abcdef")
    supplied = {
        "system_policy_hash": system_policy_hash,
        "user_instruction_hash": user_instruction_hash,
        "tool_permissions_hash": tool_permissions_hash,
        "budget_limits_hash": budget_limits_hash,
        "model_config_hash": model_config_hash,
    }
    for field_name, raw in supplied.items():
        if raw is None:
            continue
        digest = raw.strip().lower()
        well_formed = len(digest) == 64 and hex_digits.issuperset(digest)
        if not well_formed:
            raise ValueError(
                f"{field_name}: must be a 64-char lowercase sha256 hex"
            )
        snapshot[field_name] = digest

    if extra:
        snapshot["extra"] = extra

    # The anchor commits to the canonical bytes, not to any pretty-printed
    # file the caller may later write.
    payload = canonicalize(snapshot)
    anchor = {
        "sha256_hex": hashlib.sha256(payload).hexdigest(),
        "file_size": len(payload),
    }
    return {"snapshot": snapshot, "anchor": anchor}


def verify_component(
    snapshot: dict, *, field: str, candidate_hash: str,
) -> bool:
    """Check whether ``candidate_hash`` matches ``snapshot[field]``.

    The comparison is case-insensitive and uses
    ``secrets.compare_digest``. Returns False when the field is absent
    from the snapshot or is not a string (not present means not
    committed), and also when the declared value is malformed — it
    never raises on odd snapshot contents.
    """
    declared = snapshot.get(field)
    if not isinstance(declared, str):
        return False
    import secrets as _secrets
    # compare_digest raises TypeError on str inputs containing non-ASCII
    # characters. The declared value comes from a file we do not control,
    # so compare encoded bytes instead; "surrogatepass" keeps lone
    # surrogates (which JSON can carry) from raising on encode.
    return _secrets.compare_digest(
        declared.lower().encode("utf-8", "surrogatepass"),
        candidate_hash.lower().encode("utf-8", "surrogatepass"),
    )


# ---- CLI ----

def _cmd_hash_component(args: argparse.Namespace) -> int:
    n_set = sum(1 for x in (args.file, args.text, args.json_file, args.json_string) if x is not None)
    if n_set != 1:
        sys.stderr.write(
            "[hash-component] exactly one of --file / --text / --json-file / "
            "--json-string is required\n"
        )
        return 2
    if args.file is not None:
        sha = hash_file(args.file)
    elif args.text is not None:
        sha = hash_text(args.text)
    elif args.json_file is not None:
        with open(args.json_file, "r", encoding="utf-8") as f:
            sha = hash_canonical(json.load(f))
    else:
        sha = hash_canonical(json.loads(args.json_string))
    print(json.dumps({"sha256_hex": sha}))
    return 0


def _cmd_build(args: argparse.Namespace) -> int:
    """Handle ``build``: assemble a snapshot and write it to disk.

    Passes the CLI hashes and agent labels to ``build_snapshot``, writes
    the wrapped result (snapshot + anchor) as indented, key-sorted JSON
    to ``--out`` (default ``snapshot.json``), logs progress on stderr,
    and prints the anchor object on stdout.

    Returns 0 on success, or 2 when a supplied hash is rejected by
    ``build_snapshot`` (the reason is written to stderr).
    """
    try:
        out = build_snapshot(
            agent_name=args.agent_name,
            agent_version=args.agent_version,
            system_policy_hash=args.system_policy_hash,
            user_instruction_hash=args.user_instruction_hash,
            tool_permissions_hash=args.tool_permissions_hash,
            budget_limits_hash=args.budget_limits_hash,
            model_config_hash=args.model_config_hash,
            snapshot_at_utc=args.snapshot_at_utc,
        )
    except ValueError as exc:
        # A malformed --*-hash is a usage error, not a crash: report it
        # the same way hash-component reports bad invocations.
        sys.stderr.write(f"[build] {exc}\n")
        return 2
    out_path = args.out or "snapshot.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, sort_keys=True)
        f.write("\n")
    sys.stderr.write(f"[build] wrote {out_path}\n")
    sys.stderr.write(
        "[build] anchor with: POST /api/v1/anchors "
        "{category: 'policy_snapshot'}\n"
    )
    print(json.dumps(out["anchor"]))
    return 0


def _cmd_verify(args: argparse.Namespace) -> int:
    with open(args.snapshot, "r", encoding="utf-8") as f:
        wrapped = json.load(f)
    snapshot = wrapped.get("snapshot") if isinstance(wrapped.get("snapshot"), dict) else wrapped
    if not isinstance(snapshot, dict):
        sys.stderr.write("[verify] snapshot file is not a dict\n")
        return 2

    checks = []  # (field, status, candidate_hash, declared_hash)

    def _check(field, candidate):
        if candidate is None:
            return
        ok = verify_component(snapshot, field=field, candidate_hash=candidate)
        checks.append({
            "field": field, "verified": ok,
            "candidate_sha256_hex": candidate,
            "declared_sha256_hex": snapshot.get(field, ""),
        })

    if args.system_policy_file:
        _check("system_policy_hash", hash_file(args.system_policy_file))
    if args.system_policy_text is not None:
        _check("system_policy_hash", hash_text(args.system_policy_text))
    if args.user_instruction_file:
        _check("user_instruction_hash", hash_file(args.user_instruction_file))
    if args.user_instruction_text is not None:
        _check("user_instruction_hash", hash_text(args.user_instruction_text))
    if args.tool_permissions_json_file:
        with open(args.tool_permissions_json_file, "r", encoding="utf-8") as f:
            _check("tool_permissions_hash", hash_canonical(json.load(f)))
    if args.budget_limits_json_file:
        with open(args.budget_limits_json_file, "r", encoding="utf-8") as f:
            _check("budget_limits_hash", hash_canonical(json.load(f)))
    if args.model_config_json_file:
        with open(args.model_config_json_file, "r", encoding="utf-8") as f:
            _check("model_config_hash", hash_canonical(json.load(f)))

    print(json.dumps({
        "snapshot_at_utc": snapshot.get("snapshot_at_utc"),
        "checks": checks,
        "all_verified": bool(checks) and all(c["verified"] for c in checks),
    }, indent=2))
    return 0 if (checks and all(c["verified"] for c in checks)) else 1


def main(argv=None) -> int:
    """Parse the command line and dispatch to the chosen subcommand.

    ``argv`` is the argument list to parse; ``None`` lets argparse read
    ``sys.argv``. Three subcommands are registered — ``hash-component``,
    ``build`` and ``verify`` — and the return value is whatever the
    selected handler returns.
    """
    parser = argparse.ArgumentParser(
        prog="policy_snapshot",
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # hash-component: four alternative input sources, each with help text.
    hash_cmd = commands.add_parser(
        "hash-component",
        help="hash one of the five canonical components",
    )
    hash_inputs = (
        ("--file",
         "path to a file (system prompt, instructions, etc.) to sha256"),
        ("--text",
         "literal text to sha256 (mutually exclusive with --file)"),
        ("--json-file",
         "path to a JSON file; canonicalize-then-sha256"),
        ("--json-string",
         "literal JSON string; canonicalize-then-sha256"),
    )
    for flag, help_text in hash_inputs:
        hash_cmd.add_argument(flag, help=help_text)
    hash_cmd.set_defaults(func=_cmd_hash_component)

    # build: agent labels and the five hashes, then the two documented options.
    build_cmd = commands.add_parser(
        "build", help="assemble snapshot.json + anchor info",
    )
    for flag in (
        "--agent-name",
        "--agent-version",
        "--system-policy-hash",
        "--user-instruction-hash",
        "--tool-permissions-hash",
        "--budget-limits-hash",
        "--model-config-hash",
    ):
        build_cmd.add_argument(flag)
    build_cmd.add_argument(
        "--snapshot-at-utc",
        help="(testing only) override the timestamp; default is now",
    )
    build_cmd.add_argument(
        "--out",
        help="path for snapshot.json (default snapshot.json)",
    )
    build_cmd.set_defaults(func=_cmd_build)

    # verify: the snapshot is mandatory; every original component is optional.
    verify_cmd = commands.add_parser(
        "verify", help="recompute hashes from originals + check",
    )
    verify_cmd.add_argument(
        "--snapshot",
        required=True,
        help="path to a snapshot.json (or the wrapped output of the build "
             "command)",
    )
    for flag in (
        "--system-policy-file",
        "--system-policy-text",
        "--user-instruction-file",
        "--user-instruction-text",
        "--tool-permissions-json-file",
        "--budget-limits-json-file",
        "--model-config-json-file",
    ):
        verify_cmd.add_argument(flag)
    verify_cmd.set_defaults(func=_cmd_verify)

    parsed = parser.parse_args(argv)
    return parsed.func(parsed)


# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
