#!/usr/bin/env python3
"""commit_reveal.py — helper for the Satsignal commit-reveal pattern.

The pattern lets two parties commit to values BEFORE either reveals,
with no trusted clock and no after-the-fact editing. Each party:

  1. Wraps their payload with a fresh 32-byte nonce.
  2. JCS-canonicalizes + sha256's the wrapped form.
  3. Anchors the resulting hash with category=commitment via
     POST https://app.satsignal.cloud/api/v1/anchors.
  4. Holds the wrapped form private until the reveal deadline.
  5. After the deadline, publishes the wrapped form. Anyone re-hashes
     and verifies against the receipt — and the chain timestamp proves
     the commit predates the reveal.

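The nonce is what makes the commit hiding: without it, anyone could
guess a low-entropy payload (e.g. a score from 0-100) by hashing every
candidate and comparing against your anchored hash. The hash itself is
what makes the commit binding: you can't change your payload after
seeing someone else's reveal, because the anchored hash already commits
to one specific canonical form.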

This script needs only the Python standard library. No Satsignal repo
dependency. Copy it freely; pin its sha256 if you care about supply
chain. Spec compatibility: emits canonical bytes that match
notary.canonicalize() in the Satsignal repo (NFC normalize + sorted
keys + minimal JSON, no whitespace).

Usage:
    # 1. Commit (inline JSON payload via stdin, or pass a path)
    echo '{"agent_id": "alpha", "score": 73}' | \\
        python3 commit_reveal.py commit \\
            --payload-json - \\
            --out commit_record.json \\
            --out-anchor anchor_body.json
    # → writes commit_record.json — KEEP PRIVATE until reveal
    # → writes anchor_body.json — ready-to-curl body for /api/v1/anchors
    # → prints sha256_hex + file_size to stdout (for scripted callers)

    # 2. Anchor the commit (--out-anchor lets you POST it directly)
    curl -H "Authorization: Bearer sk_..." \\
         -H "Content-Type: application/json" \\
         -d @anchor_body.json \\
         https://app.satsignal.cloud/api/v1/anchors

    # NOTE: anchor_body.json carries {sha256_hex, file_size, category}.
    # Add matter_slug + label/filename if your API key targets multiple
    # matters. NEVER anchor sha256(commit_record.json) — that file
    # carries timestamps + the wrapper, its sha is NOT the commitment.

    # 3. Reveal: share the canonical bytes from commit_record.json
    python3 commit_reveal.py reveal \\
        --commit-record commit_record.json \\
        --out reveal_payload.bin

    # 4. Verify a counterparty's reveal against their receipt.
    # Pass either the record itself (auto-extracts canonical bytes)
    # or the bare canonical-bytes file:
    python3 commit_reveal.py verify \\
        --commit-record their_record.json \\
        --expected-sha256 <their on-chain commitment>
    # or
    python3 commit_reveal.py verify \\
        --canonical-bytes reveal_payload.bin \\
        --expected-sha256 <their on-chain commitment>
"""
from __future__ import annotations

import argparse
import base64
import hashlib
import json
import secrets
import sys
import time
import unicodedata
from typing import Any


# ---- JCS-style canonicalization (matches Satsignal notary.canonicalize) ----

def _nfc_deep(value: Any) -> Any:
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # JCS forbids floats; the Satsignal canonicalizer raises here.
        # Caller should pre-quantize to integers (sat amounts, etc.)
        # or strings if a decimal is meaningful.
        raise ValueError(f"floats are not allowed in canonical form: {value!r}")
    if isinstance(value, str):
        return unicodedata.normalize("NFC", value)
    if isinstance(value, list):
        return [_nfc_deep(v) for v in value]
    if isinstance(value, dict):
        return {unicodedata.normalize("NFC", k): _nfc_deep(v) for k, v in value.items()}
    raise TypeError(
        f"non-canonicalizable type {type(value).__name__}: {value!r}"
    )


def canonicalize(doc: Any) -> bytes:
    """Serialize a JSON-shaped ``doc`` to its canonical UTF-8 bytes.

    All strings (keys included) are NFC-normalized first. The result is
    then encoded with sorted keys, the compact ``,``/``:`` separators, raw
    (non-ASCII-escaped) characters, and NaN/Infinity disallowed. Intended
    to match the Satsignal verifier's JS canonicalize().

    NOTE(review): ``sort_keys`` orders keys by Unicode code point, whereas
    RFC 8785 orders by UTF-16 code units. The two differ only for keys that
    mix astral-plane characters with U+E000..U+FFFF ones — confirm the JS
    side agrees if such keys can occur.
    """
    encoder = json.JSONEncoder(
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
        allow_nan=False,
    )
    text = encoder.encode(_nfc_deep(doc))
    return text.encode("utf-8")


# ---- commit / reveal / verify ----

# Wire-format tag written into every commit record's "version" field.
# The `reveal` and `verify` subcommands reject records carrying any other
# value, so bump it only together with a format change.
WIRE_VERSION = "satsignal-commit-reveal-v1"


def make_commit(payload: Any, *, nonce_hex: str = "") -> dict:
    """Build a private commit record for ``payload``.

    The payload is wrapped as ``{"nonce_hex": ..., "payload": ...}``, then
    canonicalized and hashed. ``nonce_hex`` defaults to 32 fresh random
    bytes from :mod:`secrets`. A caller-supplied nonce (testing only) must
    be exactly 64 lowercase hex characters, otherwise ValueError is raised.

    Record fields:
        version        wire-format tag (for forward compatibility)
        nonce_hex      the 64-hex-char nonce
        payload        your original object, as passed in (for your records)
        canonical_b64  base64 of the exact bytes that hash to sha256_hex
        sha256_hex     the value to anchor with category=commitment
        file_size      byte length of the canonical form, sent alongside it
        created_utc    RFC3339 "Z" timestamp for your records

    Keep the entire record private until the reveal deadline. Publishing
    the decoded canonical_b64 afterwards lets anyone verify the commit.
    """
    nonce = nonce_hex or secrets.token_hex(32)
    allowed = "0123456789abcdef"
    well_formed = len(nonce) == 64 and all(ch in allowed for ch in nonce)
    if not well_formed:
        raise ValueError("nonce_hex must be exactly 64 lowercase hex chars")

    canonical_bytes = canonicalize({"nonce_hex": nonce, "payload": payload})
    digest = hashlib.sha256(canonical_bytes).hexdigest()
    encoded = base64.b64encode(canonical_bytes).decode("ascii")
    return dict(
        version=WIRE_VERSION,
        nonce_hex=nonce,
        payload=payload,
        canonical_b64=encoded,
        sha256_hex=digest,
        file_size=len(canonical_bytes),
        created_utc=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    )


def verify_reveal(canonical_bytes: bytes, expected_sha256_hex: str) -> bool:
    """Return True iff sha256(canonical_bytes) equals ``expected_sha256_hex``.

    The expected value is compared case-insensitively, with surrounding
    whitespace ignored (copy-pasted hashes often carry a trailing newline).
    Malformed expected values — wrong length, non-hex, even non-ASCII
    characters — yield False rather than raising.

    The comparison is constant-time via secrets.compare_digest. That is
    probably overkill here, but harmless.
    """
    got = hashlib.sha256(canonical_bytes).hexdigest().encode("ascii")
    # compare_digest raises TypeError for str arguments containing non-ASCII
    # characters. Comparing bytes makes arbitrary CLI input safe.
    # "surrogatepass" covers lone surrogates that argv decoding can produce.
    want = expected_sha256_hex.strip().lower().encode("utf-8", "surrogatepass")
    return secrets.compare_digest(got, want)


# ---- CLI ----

def _cmd_commit(args: argparse.Namespace) -> int:
    """CLI handler for ``commit``: build a commit record from a JSON payload.

    Reads the payload from ``args.payload_json`` ('-' means stdin). Writes
    the full record to ``args.out`` (default commit_record.json) and, if
    ``args.out_anchor`` is set, a ready-to-POST anchor body. Prints
    ``{"sha256_hex", "file_size"}`` as one JSON line on stdout; human
    guidance goes to stderr.

    Returns:
        0 on success.
    """
    import os

    if args.payload_json == "-":
        # Read raw bytes: json.loads() detects UTF-8/16/32 (and a UTF-8 BOM)
        # itself. The console's locale encoding could otherwise mangle
        # non-ASCII payloads before they are hashed.
        stream = getattr(sys.stdin, "buffer", sys.stdin)
        payload = json.loads(stream.read())
    else:
        with open(args.payload_json, "r", encoding="utf-8") as f:
            payload = json.load(f)
    record = make_commit(payload, nonce_hex=args.nonce_hex or "")
    out_path = args.out or "commit_record.json"
    # The record contains the payload and nonce. Leaking it before the
    # reveal defeats the commitment, so a newly created file is owner-only
    # (0o600) rather than umask-default. An already-existing file keeps its
    # current permissions, and the mode is largely ignored on Windows.
    fd = os.open(out_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(record, f, indent=2, sort_keys=True)
        f.write("\n")
    sys.stderr.write(
        f"[commit] wrote {out_path} (KEEP PRIVATE until reveal)\n"
    )
    if args.out_anchor:
        # Only the public commitment goes here — never the nonce/payload.
        anchor_body = {
            "sha256_hex": record["sha256_hex"],
            "file_size": record["file_size"],
            "category": "commitment",
        }
        with open(args.out_anchor, "w", encoding="utf-8") as f:
            json.dump(anchor_body, f, indent=2, sort_keys=True)
            f.write("\n")
        sys.stderr.write(
            f"[commit] wrote {args.out_anchor} (ready-to-curl anchor body;\n"
            f"[commit]   POST it to /api/v1/anchors with -d @{args.out_anchor})\n"
        )
    sys.stderr.write(
        f"[commit] DO NOT sha256 {out_path} — its sha256 is NOT the\n"
        f"[commit]   commitment. The commitment is the sha256 below\n"
        f"[commit]   (or use --out-anchor for a ready-to-curl JSON body).\n"
    )
    print(json.dumps({
        "sha256_hex": record["sha256_hex"],
        "file_size": record["file_size"],
    }))
    return 0


def _cmd_reveal(args: argparse.Namespace) -> int:
    """CLI handler for ``reveal``: extract the canonical bytes from a record.

    Reads the commit record at ``args.commit_record`` and decodes its
    canonical_b64. Before writing anything, it re-hashes those bytes and
    checks them against the record's own sha256_hex, so a corrupted or
    hand-edited record is caught here rather than after the bytes have
    been published. On success, writes the bytes to ``args.out`` (default
    reveal_payload.bin) and prints ``{"sha256_hex", "file_size", "out"}``
    as one JSON line on stdout.

    Returns:
        0 on success; 2 if the record has an unknown version, no decodable
        canonical_b64, or a sha256_hex that does not match its bytes.
    """
    with open(args.commit_record, "r", encoding="utf-8") as f:
        record = json.load(f)
    if record.get("version") != WIRE_VERSION:
        sys.stderr.write(
            f"[reveal] unknown record version: {record.get('version')!r}\n"
        )
        return 2
    try:
        canonical_bytes = base64.b64decode(record["canonical_b64"])
    except (KeyError, TypeError, ValueError) as e:
        # binascii.Error (e.g. bad padding) is a ValueError subclass.
        sys.stderr.write(f"[reveal] record has no usable canonical_b64: {e!r}\n")
        return 2
    actual_sha = hashlib.sha256(canonical_bytes).hexdigest()
    claimed_sha = str(record.get("sha256_hex") or "").strip().lower()
    if actual_sha != claimed_sha:
        # Publishing bytes that don't match the anchored hash would make the
        # reveal unverifiable; refuse rather than write them.
        sys.stderr.write(
            f"[reveal] record is inconsistent: canonical bytes hash to "
            f"{actual_sha}, but sha256_hex is {record.get('sha256_hex')!r}.\n"
            "[reveal]   Refusing to write; compare the bytes against your "
            "on-chain receipt with `verify --expected-sha256` first.\n"
        )
        return 2
    out = args.out or "reveal_payload.bin"
    with open(out, "wb") as f:
        f.write(canonical_bytes)
    sys.stderr.write(f"[reveal] wrote {out} ({len(canonical_bytes)} bytes)\n")
    sys.stderr.write(
        f"[reveal] expected sha256: {actual_sha}\n"
    )
    print(json.dumps({
        "sha256_hex": actual_sha,
        "file_size": len(canonical_bytes),
        "out": out,
    }))
    return 0


def _lookup_hash(host: str, sha_hex: str, *, timeout: float = 10.0) -> dict:
    """Probe lookup_hash to confirm a sha was actually anchored on
    chain. Returns one of:

      {"state": "confirmed", "txid": "...", "bundle_id": "..."}
      {"state": "missing"}                          # 200 {} from server
      {"state": "error", "note": "<reason>"}        # network / 4xx / 5xx
    """
    import urllib.request
    import urllib.error
    url = host.rstrip("/") + "/lookup_hash?sha=" + sha_hex
    try:
        with urllib.request.urlopen(url, timeout=timeout) as r:
            body = json.loads(r.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        return {"state": "error",
                "note": f"HTTP {e.code} from {url}"}
    except urllib.error.URLError as e:
        return {"state": "error",
                "note": f"network error reaching {url}: {e.reason}"}
    except Exception as e:  # noqa: BLE001
        return {"state": "error",
                "note": f"lookup_hash failed: {e}"}
    if not body or not body.get("txid"):
        return {"state": "missing"}
    out = {"state": "confirmed", "txid": body["txid"]}
    if body.get("bundle_id"):
        out["bundle_id"] = body["bundle_id"]
    return out


def _cmd_verify(args: argparse.Namespace) -> int:
    """CLI handler for ``verify``: check revealed bytes against a sha256.

    The candidate bytes come from ``--commit-record`` (its canonical_b64)
    or from ``--canonical-bytes``. The expected hash is ``--expected-sha256``
    when given. With ``--commit-record`` alone, it falls back to the
    record's own sha256_hex. In that case, unless ``--no-chain-confirm``
    was passed, the hash must also be found via lookup_hash.

    Prints one JSON object on stdout describing the outcome.

    Returns:
        0 if verified; 1 on hash mismatch, or when the chain anchor is
        missing or could not be confirmed; 2 on an unknown record version
        or when no expected hash is available.
    """
    # `using_record_self_hash` flags the forgery-vulnerable mode: the
    # record self-claims its own sha256, the user has not supplied an
    # external chain hash. Cryptographic checks alone won't catch a
    # locally-fabricated record. Chain-confirm activates here.
    using_record_self_hash = False
    if args.commit_record:
        with open(args.commit_record, "r", encoding="utf-8") as f:
            record = json.load(f)
        if record.get("version") != WIRE_VERSION:
            sys.stderr.write(
                f"[verify] unknown record version: {record.get('version')!r}\n"
            )
            return 2
        data = base64.b64decode(record["canonical_b64"])
        # An explicit --expected-sha256 always wins over the record's claim.
        expected = args.expected_sha256 or record.get("sha256_hex")
        if not expected:
            sys.stderr.write(
                "[verify] record has no sha256_hex; pass --expected-sha256\n"
            )
            return 2
        if not args.expected_sha256:
            using_record_self_hash = True
    else:
        with open(args.canonical_bytes, "rb") as f:
            data = f.read()
        expected = args.expected_sha256
        if not expected:
            sys.stderr.write(
                "[verify] --expected-sha256 is required when verifying "
                "raw --canonical-bytes\n"
            )
            return 2
    ok = verify_reveal(data, expected)
    if not ok:
        # Report both hashes so the caller can see what actually differed.
        got = hashlib.sha256(data).hexdigest()
        print(json.dumps({
            "verified": False,
            "expected_sha256_hex": expected,
            "got_sha256_hex": got,
            "hint": "Either the canonical bytes were modified after commit, "
                    "or this isn't the canonical form that produced the receipt.",
        }))
        return 1

    out: dict = {"verified": True, "sha256_hex": expected}

    if not using_record_self_hash:
        # User supplied the chain hash via --expected-sha256. Crypto
        # check confirmed the canonical bytes match it. Trust anchor
        # is the user's chain-hash claim, not the helper's network.
        # Chain-confirm doesn't apply — print + done.
        print(json.dumps(out))
        return 0

    # Forgery-vulnerable mode (--commit-record without
    # --expected-sha256). The record self-claims its sha256_hex; an
    # attacker could fabricate {canonical_b64, sha256_hex} pair that
    # passes crypto. lookup_hash confirms the sha is actually anchored.
    if args.chain_confirm:
        chain = _lookup_hash(args.lookup_host, expected)
        out["chain_check"] = chain
        if chain["state"] == "confirmed":
            print(json.dumps(out))
            return 0
        if chain["state"] == "missing":
            # Self-consistent but not on chain: downgrade to unverified.
            out["verified"] = False
            out["forgery_suspected"] = True
            out["forgery_note"] = (
                "Record self-consistency holds, but lookup_hash shows "
                "no on-chain anchor for sha256_hex. Either the anchor "
                "has not yet been broadcast, or the record was "
                "fabricated locally. Treat as unverified until an "
                "on-chain anchor exists, or pass --expected-sha256 "
                "with the chain hash you trust."
            )
            print(json.dumps(out))
            return 1
        # state == "error": fail closed by default.
        out["verified"] = False
        out["chain_check_error"] = chain["note"]
        sys.stderr.write(
            "[verify] could not confirm chain anchor: "
            f"{chain['note']}\n"
            "[verify]   re-run when network is available, pass "
            "--expected-sha256 with the trusted chain hash, or pass "
            "--no-chain-confirm to accept self-consistency only.\n"
        )
        print(json.dumps(out))
        return 1

    # --no-chain-confirm: explicit offline self-consistency check.
    # Still exits 0, but the output and stderr both record the caveat.
    out["chain_check"] = {
        "state": "skipped",
        "note": "--no-chain-confirm: record self-consistency only; "
                "the on-chain anchor was not confirmed. A locally-"
                "fabricated record would also pass.",
    }
    sys.stderr.write(
        "[verify] WARNING: --no-chain-confirm set; record self-\n"
        "[verify]   consistency holds but no on-chain confirmation\n"
        "[verify]   was performed. A locally-fabricated record would\n"
        "[verify]   also pass.\n"
    )
    print(json.dumps(out))
    return 0


def main(argv: list[str] | None = None) -> int:
    """Parse ``argv`` and dispatch to the selected subcommand's handler.

    ``argv`` of None makes argparse read sys.argv[1:]. Each subparser sets
    ``func`` via set_defaults(), and the handler's integer return value is
    returned as the exit code.
    """
    p = argparse.ArgumentParser(
        prog="commit_reveal", description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # required=True: running with no subcommand is a usage error.
    sub = p.add_subparsers(dest="cmd", required=True)

    # --- commit -----------------------------------------------------------
    pc = sub.add_parser("commit", help="produce a commit record")
    pc.add_argument("--payload-json", required=True,
                    help="path to JSON file with the payload to commit, "
                         "or '-' to read from stdin")
    pc.add_argument("--out", default=None,
                    help="path to write the commit record (default "
                         "commit_record.json). KEEP PRIVATE until reveal.")
    pc.add_argument("--out-anchor", default=None,
                    help="optional path to write a ready-to-curl JSON "
                         "body for /api/v1/anchors with {sha256_hex, "
                         "file_size, category: 'commitment'}. POST it "
                         "directly with -d @<path>.")
    pc.add_argument("--nonce-hex", default=None,
                    help="(testing only) supply a deterministic 64-hex "
                         "nonce; default is 32 random bytes from secrets")
    pc.set_defaults(func=_cmd_commit)

    # --- reveal -----------------------------------------------------------
    pr = sub.add_parser("reveal", help="extract the canonical payload bytes "
                                       "from a commit record")
    pr.add_argument("--commit-record", required=True,
                    help="path to commit_record.json")
    pr.add_argument("--out", default=None,
                    help="path to write the canonical bytes (default "
                         "reveal_payload.bin)")
    pr.set_defaults(func=_cmd_reveal)

    # --- verify -----------------------------------------------------------
    pv = sub.add_parser("verify", help="verify revealed canonical bytes "
                                       "match an expected sha256")
    # Exactly one byte source must be given.
    pv_src = pv.add_mutually_exclusive_group(required=True)
    pv_src.add_argument("--canonical-bytes",
                        help="path to the candidate canonical bytes "
                             "(produced by `reveal --out`)")
    pv_src.add_argument("--commit-record",
                        help="path to a commit_record.json (auto-extracts "
                             "canonical_b64 — no need to run `reveal` first)")
    pv.add_argument("--expected-sha256", default=None,
                    help="the on-chain committed sha256 (hex). Required "
                         "with --canonical-bytes; optional with "
                         "--commit-record (without it, falls back to "
                         "record.sha256_hex and chain-confirms via "
                         "lookup_hash).")
    # Chain-confirmation defends against the local-forgery vector
    # auditor flagged: --commit-record without --expected-sha256
    # falls back to record.sha256_hex, which the record self-attests.
    # Cryptographic checks alone won't catch a fabricated record;
    # lookup_hash confirms the sha is actually on chain. Default on.
    # Both flags write the same dest; the group forbids passing both.
    pv_chain = pv.add_mutually_exclusive_group()
    pv_chain.add_argument(
        "--chain-confirm", dest="chain_confirm",
        action="store_true", default=True,
        help="confirm on-chain anchor via lookup_hash when "
             "--expected-sha256 is omitted (default: on)",
    )
    pv_chain.add_argument(
        "--no-chain-confirm", dest="chain_confirm",
        action="store_false",
        help="skip lookup_hash; verify record self-consistency only "
             "(unsafe — locally-fabricated records will pass)",
    )
    # Hidden from --help; overrides the lookup_hash endpoint host.
    pv.add_argument(
        "--lookup-host", default="https://proof.satsignal.cloud",
        help=argparse.SUPPRESS,
    )
    pv.set_defaults(func=_cmd_verify)

    args = p.parse_args(argv)
    return args.func(args)


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
