#!/usr/bin/env python3
"""
ee_mesh_build.py

Build a first-pass spherical triangle mesh for the Expanding Earth Authoring
Engine (EEAE) from an ee-project-v1 build JSON file.

This script is intentionally a MESH BUILDER, not the reverse-time deformation
solver. Its job is to convert the current EEAE line/constraint authoring state
into a connected spherical surface substrate that later solver scripts can use.

Primary outputs:
  1. build1.mesh.v1.json           Full mesh and constraint-anchor sidecar
  2. build1.mesh.preview.geojson   Lightweight triangle preview for EEAE/deck.gl
  3. build1.mesh.diagnostics.json  Mesh-quality and ingestion diagnostics

Core geometric idea:
  - Convert authored lon/lat linework and control points to 3D unit vectors.
  - Add a deterministic background point field so the whole sphere is covered.
  - Optionally add deterministic Steiner refinement points to reduce long edges
    and skinny triangles before final hull construction.
  - Use scipy.spatial.ConvexHull on the unit vectors. For points on a sphere,
    the hull facets are the spherical Delaunay triangles.
  - Use scipy.spatial.SphericalVoronoi to generate the dual Voronoi cells.

This does not yet close oceans. It only creates the topological substrate.
"""

from __future__ import annotations

import argparse
import bisect
import collections
import datetime as _dt
import hashlib
import json
import math
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Counter, DefaultDict, Iterable, Iterator, Optional

import numpy as np
from scipy.spatial import ConvexHull, QhullError, SphericalVoronoi, cKDTree


# Identifier of the mesh output schema. Presumably stamped into the emitted
# mesh JSON; the writer is outside this part of the file -- TODO confirm.
SCHEMA_VERSION = "ee-mesh-v1"
# Default sphere radius in kilometres (6371 km matches Earth's mean radius).
DEFAULT_RADIUS_KM = 6371.0
# Feature-property keys probed for a crust age. coerce_age_ma walks them in
# this order and returns the first one that yields a finite number, so the
# ordering here is a priority ordering.
AGE_KEYS = (
    "age_Ma",
    "age_ma",
    "AGE_MA",
    "age_min",
    "age",
    "Age",
    "crust_age_ma",
    "crust_age",
    "crustAgeMa",
    "header_age",
)


# -----------------------------------------------------------------------------
# Data containers
# -----------------------------------------------------------------------------


@dataclass
class SeedPoint:
    """A single candidate point fed into the mesh build, with its provenance.

    cap_seed_points buckets seeds by ``is_mor``, ``material``, ``kind`` and
    ``group_id``; the remaining fields are carried along as metadata.
    """

    # 3D position vector; presumably unit length (as produced by
    # lonlat_to_unit) -- TODO confirm at the construction sites.
    unit: np.ndarray
    # Geographic position; presumably degrees, matching lonlat_to_unit.
    lon: float
    lat: float
    # Seed category label; cap_seed_points tests for "isochron" and
    # "continentalLine".
    kind: str
    # Material label; cap_seed_points tests for "mor", "oceanic",
    # "continental" and "youngContinental".
    material: str = "unknown"
    # Age when known (presumably in Ma, per the field name); None otherwise.
    age_ma: Optional[float] = None
    # Identifier of the group this seed belongs to, if any.
    group_id: Optional[str] = None
    # True for seeds flagged as mid-ocean-ridge.
    is_mor: bool = False
    # Provenance records; default_factory gives every instance its own list.
    source_refs: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class VertexRecord:
    """Mesh vertex together with the tallies of everything merged into it.

    The Counter/list fields accumulate per-seed information; the properties
    below reduce those tallies to a single representative value on demand.
    """

    id: str
    unit: np.ndarray
    lon: float
    lat: float
    seed_kinds: Counter[str]
    materials: Counter[str]
    ages: list[float]
    groups: Counter[str]
    source_refs: list[dict[str, Any]]
    source_ref_overflow_count: int = 0
    is_mor: bool = False
    young_continental_weight: float = 0.0

    @property
    def material(self) -> str:
        """Representative material, as decided by choose_material."""
        return choose_material(self.materials)

    @property
    def age_ma(self) -> Optional[float]:
        """Median of the recorded ages, or None when no age was recorded."""
        if self.ages:
            samples = np.asarray(self.ages, dtype=float)
            return float(np.median(samples))
        return None

    @property
    def group_id(self) -> Optional[str]:
        """Most frequently recorded group id, or None when there is none."""
        if self.groups:
            (winner, _count), = self.groups.most_common(1)
            return winner
        return None


# -----------------------------------------------------------------------------
# Math helpers
# -----------------------------------------------------------------------------


def normalize_lon(lon: float) -> float:
    """Wrap a longitude in degrees into the closed range [-180, 180].

    Values wrap modulo 360. The antimeridian keeps the sign of the input:
    positive inputs that land on it (180, 540, ...) come back as +180.0,
    non-positive ones (-180, -540, ...) as -180.0.

    Args:
        lon: Longitude in degrees; anything ``float()`` accepts.

    Returns:
        The wrapped longitude as a float.
    """
    # Coerce once so the sign test below sees the same number as the wrap
    # (the raw argument may be a numeric string or numpy scalar).
    value = float(lon)
    wrapped = ((value + 180.0) % 360.0) - 180.0
    # Avoid emitting -180 for values that were exactly +180 where possible.
    if wrapped == -180.0 and value > 0:
        return 180.0
    return wrapped


def clamp(x: float, lo: float, hi: float) -> float:
    """Limit ``x`` to the interval [lo, hi]."""
    # Written as two explicit selections; a value that is not strictly below
    # ``hi`` (including NaN) resolves to ``hi`` before the lower bound applies.
    capped = x if x < hi else hi
    return capped if capped > lo else lo


def safe_unit(vec: np.ndarray) -> np.ndarray:
    """Scale ``vec`` to unit Euclidean length.

    Raises:
        ValueError: if the length is zero, negative, NaN or infinite.
    """
    length = float(np.linalg.norm(vec))
    if math.isfinite(length) and length > 0.0:
        return vec / length
    raise ValueError("Cannot normalize zero/non-finite vector")


def lonlat_to_unit(lon: float, lat: float) -> np.ndarray:
    """Convert longitude/latitude in degrees to a 3D Cartesian unit vector.

    The x axis points at lon=0/lat=0, the y axis at lon=90/lat=0 and the
    z axis at the north pole.
    """
    lam = math.radians(float(lon))
    phi = math.radians(float(lat))
    # Distance of the point from the rotation (z) axis.
    horizontal = math.cos(phi)
    x = horizontal * math.cos(lam)
    y = horizontal * math.sin(lam)
    z = math.sin(phi)
    return np.array([x, y, z], dtype=float)


def unit_to_lonlat(unit: np.ndarray) -> tuple[float, float]:
    """Convert a 3D vector to ``(lon, lat)`` in degrees.

    The vector is normalized first (so safe_unit's ValueError propagates for
    zero/non-finite input) and the longitude is passed through normalize_lon.
    """
    direction = safe_unit(unit)
    # Clamp guards asin against components a hair outside [-1, 1].
    sin_lat = clamp(float(direction[2]), -1.0, 1.0)
    lat_deg = math.degrees(math.asin(sin_lat))
    lon_deg = math.degrees(math.atan2(float(direction[1]), float(direction[0])))
    return normalize_lon(lon_deg), lat_deg


def angular_distance_rad(a: np.ndarray, b: np.ndarray) -> float:
    """Return the angle in radians between two vectors, in [0, pi].

    Uses atan2(|a x b|, a . b), which keeps precision for very small
    separations where arccos of the dot product alone would lose it.
    """
    cosine_term = clamp(float(np.dot(a, b)), -1.0, 1.0)
    sine_term = float(np.linalg.norm(np.cross(a, b)))
    return math.atan2(sine_term, cosine_term)


def chord_to_angle_rad(chord: float) -> float:
    """Convert a chord length on the unit sphere to the subtended angle.

    The half-chord is clamped to [0, 1] before asin, so out-of-range chords
    saturate at 0 or pi radians instead of raising.
    """
    half_chord = clamp(float(chord) / 2.0, 0.0, 1.0)
    return 2.0 * math.asin(half_chord)


def km_to_chord(km: float, radius_km: float) -> float:
    """Convert an arc length in km to a chord length on the unit sphere.

    The arc subtends ``km / radius_km`` radians; the matching unit-sphere
    chord is ``2 * sin(angle / 2)``.
    """
    half_angle = (float(km) / float(radius_km)) / 2.0
    return 2.0 * math.sin(half_angle)


def slerp_unit(a: np.ndarray, b: np.ndarray, t: float) -> np.ndarray:
    """Spherically interpolate between ``a`` and ``b`` at fraction ``t``.

    ``t = 0`` gives ``a`` and ``t = 1`` gives ``b``. When the two vectors are
    (numerically) coincident the interpolation degrades to a normalized
    linear blend. The result always goes through safe_unit, so its ValueError
    propagates if the blend collapses to a zero vector.
    """
    frac = float(t)
    cos_omega = clamp(float(np.dot(a, b)), -1.0, 1.0)
    omega = math.acos(cos_omega)
    if omega < 1e-12:
        blended = (1.0 - frac) * a + frac * b
    else:
        sin_omega = math.sin(omega)
        weight_a = math.sin((1.0 - frac) * omega) / sin_omega
        weight_b = math.sin(frac * omega) / sin_omega
        blended = weight_a * a + weight_b * b
    return safe_unit(blended)


def spherical_triangle_area_sr(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> float:
    """Return spherical triangle solid angle on the unit sphere in steradians.

    Evaluates ``2 * atan2(|a . (b x c)|, 1 + a.b + b.c + c.a)``. Because the
    triple product enters through its absolute value, the result does not
    depend on vertex orientation and always lies in [0, 2*pi]: a negative
    denominator (triangles approaching a hemisphere) simply moves atan2 into
    its (pi/2, pi] branch, so no sign correction is needed afterwards.

    Args:
        a, b, c: Triangle corners as 3D vectors (expected to be unit length).

    Returns:
        The solid angle in steradians as a Python float.
    """
    triple_product = float(np.dot(a, np.cross(b, c)))
    denom = 1.0 + float(np.dot(a, b)) + float(np.dot(b, c)) + float(np.dot(c, a))
    return float(2.0 * math.atan2(abs(triple_product), denom))


def spherical_triangle_angles_rad(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> tuple[float, float, float]:
    """Return the interior angles at corners a, b and c, in radians.

    Each angle comes from the spherical law of cosines applied to the three
    great-circle side lengths. A corner whose two adjacent sides have a
    vanishing sine product (degenerate triangle) is reported as 0.0.
    """
    # Great-circle side lengths, each named after the corner it faces.
    facing_a = angular_distance_rad(b, c)
    facing_b = angular_distance_rad(c, a)
    facing_c = angular_distance_rad(a, b)

    def corner_angle(facing: float, adjacent_1: float, adjacent_2: float) -> float:
        sin_product = math.sin(adjacent_1) * math.sin(adjacent_2)
        if abs(sin_product) < 1e-15:
            return 0.0
        cosine = (math.cos(facing) - math.cos(adjacent_1) * math.cos(adjacent_2)) / sin_product
        # Clamp absorbs round-off that would push acos out of its domain.
        return math.acos(clamp(cosine, -1.0, 1.0))

    at_a = corner_angle(facing_a, facing_b, facing_c)
    at_b = corner_angle(facing_b, facing_c, facing_a)
    at_c = corner_angle(facing_c, facing_a, facing_b)
    return (at_a, at_b, at_c)


def triangle_quality_metrics(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray,
    radius_km: float,
) -> dict[str, Any]:
    """Summarize edge lengths and corner angles of one spherical triangle.

    Edge lengths are listed in the order (a-b, b-c, c-a) and scaled by
    ``radius_km``. ``aspectRatio`` is longest/shortest edge, replaced by the
    sentinel 1.0e12 when the shortest edge is at most 1e-9 km.
    """
    edges_km = [
        angular_distance_rad(p, q) * radius_km
        for p, q in ((a, b), (b, c), (c, a))
    ]
    corners_deg = [math.degrees(angle) for angle in spherical_triangle_angles_rad(a, b, c)]
    shortest_km = min(edges_km)
    longest_km = max(edges_km)
    if shortest_km > 1e-9:
        aspect = longest_km / shortest_km
    else:
        aspect = 1.0e12
    return {
        "edgeLengthsKm": edges_km,
        "minAngleDeg": min(corners_deg),
        "maxAngleDeg": max(corners_deg),
        "aspectRatio": aspect,
        "longestEdgeKm": longest_km,
        "shortestEdgeKm": shortest_km,
    }


def spherical_circumcenter(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> Optional[np.ndarray]:
    """Return the unit normal of the plane through a, b, c as circumcenter.

    Of the two opposite normals, the one on the same side as the vertex sum
    ``a + b + c`` is chosen. Returns None when the three points are
    (numerically) collinear and the plane normal cannot be formed.
    """
    plane_normal = np.cross(b - a, c - a)
    magnitude = float(np.linalg.norm(plane_normal))
    if not math.isfinite(magnitude) or magnitude < 1e-14:
        return None
    axis = plane_normal / magnitude
    # Flip towards the triangle so the centre is not the antipodal pole.
    if float(np.dot(axis, a + b + c)) < 0.0:
        axis = -axis
    return safe_unit(axis)


def make_quality_vertex(unit: np.ndarray, index: int, source: str, iteration: int) -> VertexRecord:
    """Build the VertexRecord for a Steiner point added by quality refinement.

    The vertex is tagged as a ``qualityRefinement`` seed with ``background``
    material, carries no ages or groups, and records which refinement rule
    (``source``) and which ``iteration`` produced it.
    """
    lon_deg, lat_deg = unit_to_lonlat(unit)
    provenance = {
        "kind": "qualityRefinement",
        "source": source,
        "iteration": int(iteration),
    }
    return VertexRecord(
        id=f"v_{index:06d}",
        unit=safe_unit(unit),
        lon=lon_deg,
        lat=lat_deg,
        seed_kinds=collections.Counter({"qualityRefinement": 1}),
        materials=collections.Counter({"background": 1}),
        ages=[],
        groups=collections.Counter(),
        source_refs=[provenance],
        source_ref_overflow_count=0,
        is_mor=False,
    )


def auto_refine_long_edge_km(args: argparse.Namespace) -> float:
    """Pick the edge length (km) above which refinement splits an edge.

    An explicit positive ``args.refine_long_edge_km`` wins. Otherwise the
    threshold is three times ``args.target_spacing_km``, but never below
    420 km.
    """
    explicit = getattr(args, "refine_long_edge_km", 0.0)
    if explicit and explicit > 0.0:
        return float(explicit)
    # Validation uses 4x target spacing by default (720 km for a 180 km mesh).
    # Staying at 3x leaves the post-build validator some headroom.
    derived = float(args.target_spacing_km) * 3.0
    return max(derived, 420.0)


def candidate_is_far_enough(
    unit: np.ndarray,
    tree: cKDTree,
    selected_units: list[np.ndarray],
    min_chord: float,
) -> bool:
    """Tell whether ``unit`` keeps at least ``min_chord`` from all other points.

    Two sets are checked: the points indexed by ``tree`` (nearest-neighbour
    query) and the already accepted candidates in ``selected_units`` (linear
    scan). Distances are straight-line (chord) distances.
    """
    nearest_chord, _nearest_index = tree.query(unit, k=1)
    if float(nearest_chord) < min_chord:
        return False
    return not any(
        float(np.linalg.norm(unit - accepted)) < min_chord
        for accepted in selected_units
    )


def refine_vertices_for_quality(
    vertices: list[VertexRecord],
    args: argparse.Namespace,
) -> tuple[list[VertexRecord], dict[str, Any]]:
    """Add deterministic Steiner points to reduce long edges and skinny cells.

    This is a lightweight spherical Delaunay-refinement pass. It does not move or
    delete authored source/zipper vertices. It only adds background-quality
    points at long-edge midpoints and selected triangle centers, then rebuilds
    the hull. The solver can still use the original source and zipper anchors,
    while the triangulation gains better numerical conditioning.

    Args:
        vertices: Current mesh vertices. The list itself is not modified; new
            vertices are appended to a copy.
        args: Parsed options. Reads ``quality_refine``, ``refine_long_edge_km``
            and ``target_spacing_km`` (via auto_refine_long_edge_km),
            ``refine_small_angle_deg``, ``refine_aspect_ratio``,
            ``refine_min_candidate_distance_km``, ``dedupe_km``, ``radius_km``,
            ``refine_max_new_points_per_iteration``, ``refine_iterations`` and
            ``qhull_options``.

    Returns:
        ``(vertices_out, stats)``. ``vertices_out`` is the input list followed
        by the added refinement vertices (or the input list object itself when
        refinement is disabled). ``stats`` records the thresholds used and one
        entry per executed iteration.
    """
    if not getattr(args, "quality_refine", True):
        return vertices, {"enabled": False}

    long_edge_km = auto_refine_long_edge_km(args)
    small_angle_deg = float(args.refine_small_angle_deg)
    aspect_threshold = float(args.refine_aspect_ratio)
    # New points must stay at least this far from every existing vertex and
    # from each other; never tighter than twice the dedupe distance.
    min_candidate_km = max(float(args.refine_min_candidate_distance_km), float(args.dedupe_km) * 2.0)
    min_chord = km_to_chord(min_candidate_km, args.radius_km)
    max_new = max(0, int(args.refine_max_new_points_per_iteration))
    iterations = max(0, int(args.refine_iterations))

    stats: dict[str, Any] = {
        "enabled": True,
        "iterationsRequested": iterations,
        "longEdgeThresholdKm": long_edge_km,
        "smallAngleThresholdDeg": small_angle_deg,
        "aspectRatioThreshold": aspect_threshold,
        "minCandidateDistanceKm": min_candidate_km,
        "maxNewPointsPerIteration": max_new,
        "initialVertexCount": len(vertices),
        "iterations": [],
    }

    current = list(vertices)
    for iteration in range(1, iterations + 1):
        # A zero per-iteration budget means nothing can ever be added.
        if max_new <= 0:
            break
        points = np.vstack([v.unit for v in current])
        # build_hull and orient_triangle_outward are defined elsewhere in this
        # file; presumably the former returns the hull plus the Qhull option
        # string that finally worked -- TODO confirm.
        hull, qhull_options_used = build_hull(points, args.qhull_options)
        # Rows are (severity, rule name, candidate unit vector).
        candidate_rows: list[tuple[float, str, np.ndarray]] = []
        long_edge_count = 0
        skinny_count = 0

        for simplex in hull.simplices:
            tri = orient_triangle_outward(simplex, points)
            a, b, c = points[tri[0]], points[tri[1]], points[tri[2]]
            metrics = triangle_quality_metrics(a, b, c, args.radius_km)
            edge_lengths = metrics["edgeLengthsKm"]
            max_edge = float(metrics["longestEdgeKm"])
            min_angle = float(metrics["minAngleDeg"])
            aspect = float(metrics["aspectRatio"])

            if max_edge > long_edge_km:
                long_edge_count += 1
                # Same (a-b, b-c, c-a) order as metrics["edgeLengthsKm"].
                edge_pairs = ((a, b, edge_lengths[0]), (b, c, edge_lengths[1]), (c, a, edge_lengths[2]))
                for p0, p1, length in edge_pairs:
                    if length > long_edge_km:
                        # Base 10 ranks long-edge splits above the skinny
                        # rules for ordinary (non-degenerate) triangles.
                        severity = 10.0 + (float(length) / long_edge_km)
                        candidate_rows.append((severity, "longEdgeMidpoint", slerp_unit(p0, p1, 0.5)))

            if min_angle < small_angle_deg or aspect > aspect_threshold:
                skinny_count += 1
                # Circumcenter tends to improve triangle shape when valid.
                center = spherical_circumcenter(a, b, c)
                if center is not None:
                    severity = 5.0 + max(
                        (small_angle_deg / max(min_angle, 1e-6)),
                        (aspect / max(aspect_threshold, 1e-6)),
                    )
                    candidate_rows.append((severity, "skinnyCircumcenter", center))
                # Longest-edge midpoint is a conservative backup for near-linear
                # bands where the circumcenter may be far from useful.
                # NOTE(review): triangle_quality_metrics reports aspectRatio as
                # 1.0e12 for a near-zero shortest edge, so for such triangles
                # this "backup" outranks every long-edge candidate.
                longest_idx = int(np.argmax(np.array(edge_lengths, dtype=float)))
                pair = ((a, b), (b, c), (c, a))[longest_idx]
                candidate_rows.append((4.0 + aspect, "skinnyLongestEdgeMidpoint", slerp_unit(pair[0], pair[1], 0.5)))

        # Worst offenders first; list.sort is stable, so equal severities keep
        # their hull traversal order.
        candidate_rows.sort(key=lambda row: row[0], reverse=True)
        # The tree holds only the vertices that existed at the start of this
        # iteration; points accepted below are tracked in selected_units.
        tree = cKDTree(points)
        selected_units: list[np.ndarray] = []
        selected_sources: list[str] = []
        rejected_close = 0

        for _severity, source, unit in candidate_rows:
            if len(selected_units) >= max_new:
                break
            unit = safe_unit(unit)
            # An edge shared by two triangles proposes the same midpoint
            # twice; with a positive min_chord the repeat is rejected here and
            # is included in rejected_close.
            if not candidate_is_far_enough(unit, tree, selected_units, min_chord):
                rejected_close += 1
                continue
            selected_units.append(unit)
            selected_sources.append(source)

        # New vertex ids continue from the current list length.
        start_index = len(current)
        for offset, (unit, source) in enumerate(zip(selected_units, selected_sources)):
            current.append(make_quality_vertex(unit, start_index + offset, source, iteration))

        stats["iterations"].append({
            "iteration": iteration,
            "qhullOptionsUsed": qhull_options_used,
            "trianglesAnalyzed": int(len(hull.simplices)),
            "longTriangleCount": int(long_edge_count),
            "skinnyTriangleCount": int(skinny_count),
            "candidateCount": int(len(candidate_rows)),
            "rejectedAsTooClose": int(rejected_close),
            "addedVertexCount": int(len(selected_units)),
            "vertexCountAfter": int(len(current)),
            "candidateSources": dict(collections.Counter(selected_sources)),
        })

        # Converged: no candidate survived, so further hulls would be identical.
        if not selected_units:
            break

    stats["finalVertexCount"] = len(current)
    stats["addedVertexCountTotal"] = len(current) - len(vertices)
    return current, stats


def round_float(value: float, digits: int = 8) -> float:
    """Coerce ``value`` to float and round it to ``digits`` decimal places."""
    as_float = float(value)
    return round(as_float, digits)


# -----------------------------------------------------------------------------
# Project extraction helpers
# -----------------------------------------------------------------------------


def read_json(path: Path) -> Any:
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
    return json.loads(path.read_text(encoding="utf-8"))


def write_json(path: Path, payload: Any, *, pretty: bool = False) -> None:
    """Serialize ``payload`` as UTF-8 JSON to ``path``, newline-terminated.

    Missing parent directories are created. ``pretty`` selects a 2-space
    indented layout; otherwise the most compact separators are used.
    Non-ASCII text is written verbatim, and NaN/Infinity values are rejected
    by ``json.dump`` (``allow_nan=False``).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    layout: dict[str, Any] = {"indent": 2} if pretty else {"separators": (",", ":")}
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, allow_nan=False, **layout)
        handle.write("\n")


def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in pieces of ``chunk_size`` bytes so that it never has
    to be held in memory as a whole.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if piece == b"":
                break
            digest.update(piece)
    return digest.hexdigest()


def source_feature_id(project: dict[str, Any], index: int) -> str:
    """Build the stable id of the source feature at position ``index``.

    The id has the form ``"<sourceId>:idx:<index zero-padded to 6 digits>"``.
    ``sourceId`` is taken from ``project["source"]``, then from
    ``project["project"]``; the literal ``"source"`` is used when neither
    provides a truthy value.

    Sections that are missing, ``null`` or not JSON objects are skipped
    instead of raising, so a project file with ``"source": null`` still
    yields ids.
    """
    source_id: Any = None
    for section_name in ("source", "project"):
        section = project.get(section_name)
        if isinstance(section, dict):
            source_id = section.get("sourceId")
            if source_id:
                break
    if not source_id:
        source_id = "source"
    return f"{source_id}:idx:{index:06d}"


def iter_feature_paths(feature: dict[str, Any]) -> Iterator[tuple[str, list[Any]]]:
    """Yield ``(path_id, coordinates)`` for each line in a GeoJSON feature.

    A ``LineString`` yields a single ``"line_0"`` path. A ``MultiLineString``
    yields ``"line_<i>"`` for every member that is a list, where ``i`` is the
    member's position (non-list members are skipped, indices are not
    renumbered). Any other geometry yields nothing.
    """
    geometry = feature.get("geometry") or {}
    coords = geometry.get("coordinates")
    if not isinstance(coords, list):
        return
    geometry_type = geometry.get("type")
    if geometry_type == "LineString":
        yield "line_0", coords
        return
    if geometry_type == "MultiLineString":
        for position, part in enumerate(coords):
            if isinstance(part, list):
                yield f"line_{position}", part


def lonlat_from_coord(coord: Any) -> Optional[tuple[float, float]]:
    """Validate one coordinate and return ``(lon, lat)``, or None if unusable.

    A usable coordinate is a list/tuple whose first two entries convert to
    finite floats with the latitude inside [-90, 90]. Extra entries (such as
    an elevation) are ignored. The longitude is passed through normalize_lon.
    """
    if not (isinstance(coord, (list, tuple)) and len(coord) >= 2):
        return None
    try:
        lon, lat = float(coord[0]), float(coord[1])
    except (TypeError, ValueError):
        return None
    if not math.isfinite(lon) or not math.isfinite(lat):
        return None
    if not -90.0 <= lat <= 90.0:
        return None
    return normalize_lon(lon), lat


def coerce_age_ma(props: dict[str, Any]) -> Optional[float]:
    """Return the first finite age found under any of the AGE_KEYS.

    Keys are tried in AGE_KEYS order. Booleans are ignored, numbers are taken
    as they are, and non-blank strings are parsed with ``float()``. Values
    that are missing, unparsable or non-finite make the search move on to the
    next key. Returns None when no key yields a finite number.
    """
    for key in AGE_KEYS:
        raw = props.get(key)
        # bool is an int subclass; True/False must not be read as 1/0 Ma.
        if isinstance(raw, bool):
            continue
        if isinstance(raw, (int, float)):
            candidate = float(raw)
        elif isinstance(raw, str) and raw.strip():
            try:
                candidate = float(raw)
            except ValueError:
                continue
        else:
            continue
        if math.isfinite(candidate):
            return candidate
    return None


def is_mor_feature(props: dict[str, Any]) -> bool:
    """Guess from name/layer properties whether a feature is a mid-ocean ridge.

    The name (``name`` or ``Name``) is matched case-insensitively against the
    substrings "MOR" and "MID-OCEAN"; the layer (``layer`` or ``Layer``) only
    against "MOR".

    NOTE(review): these are plain substring tests, so any name or layer that
    merely contains the letters "mor" (for example "Timor") is also reported
    as a ridge -- confirm this is acceptable for the source datasets.
    """
    name_text = str(props.get("name") or props.get("Name") or "").upper()
    layer_text = str(props.get("layer") or props.get("Layer") or "").upper()
    if "MOR" in name_text:
        return True
    if "MID-OCEAN" in name_text:
        return True
    return "MOR" in layer_text


def classify_feature_material(props: dict[str, Any]) -> str:
    """Classify a feature as ``"mor"``, ``"oceanic"`` or ``"continental"``.

    Ridge detection (is_mor_feature) takes precedence. Otherwise a feature
    carrying a usable age (coerce_age_ma) counts as oceanic and everything
    else as continental.
    """
    if is_mor_feature(props):
        return "mor"
    has_age = coerce_age_ma(props) is not None
    return "oceanic" if has_age else "continental"


def get_scene(project: dict[str, Any], scene_ma: int | float) -> dict[str, Any]:
    """Return the scene authored for time ``scene_ma``.

    Lookup order:
      1. the first scene whose ``timeMa`` equals ``scene_ma`` numerically;
      2. otherwise the first scene in the project, whatever its time;
      3. otherwise a fresh empty scene skeleton for ``scene_ma``.

    Scene entries that are not JSON objects, or whose ``timeMa`` is ``null``
    or not numeric, are skipped during the match instead of aborting the
    lookup with a TypeError/ValueError.
    """
    scenes = project.get("scenes") or []
    for scene in scenes:
        if not isinstance(scene, dict):
            continue
        try:
            # A missing timeMa maps to a sentinel that never matches.
            scene_time = float(scene.get("timeMa", -999999.0))
        except (TypeError, ValueError):
            continue
        if scene_time == float(scene_ma):
            return scene
    if scenes:
        return scenes[0]
    return {
        "timeMa": scene_ma,
        "groups": [],
        "selectedSegments": [],
        "hiddenFeatureIds": [],
        "kinematicLinks": {
            "morAssignments": [],
            "groupAssociations": [],
            "zipperControls": [],
            "morMigrationPlans": [],
        },
    }


def build_group_segment_index(
    scene: dict[str, Any],
) -> DefaultDict[tuple[str, str], list[tuple[int, int, str]]]:
    """Index the scene's selected segments by ``(featureId, pathId)``.

    Each entry is a list of ``(lo, hi, groupId)`` spans with ``lo <= hi``,
    whichever way round the segment stored its start/end index. ``pathId``
    defaults to ``"line_0"`` when absent. Segments whose ids are not strings
    or whose indices do not convert to int are dropped.
    """
    spans_by_path: DefaultDict[tuple[str, str], list[tuple[int, int, str]]] = collections.defaultdict(list)
    for segment in scene.get("selectedSegments") or []:
        feature_id = segment.get("featureId")
        path_id = segment.get("pathId", "line_0")
        group_id = segment.get("groupId")
        if not all(isinstance(value, str) for value in (feature_id, path_id, group_id)):
            continue
        try:
            first = int(segment.get("startIndex"))
            last = int(segment.get("endIndex"))
        except (TypeError, ValueError):
            continue
        span = (min(first, last), max(first, last), group_id)
        spans_by_path[(feature_id, path_id)].append(span)
    return spans_by_path


def group_for_sample(
    group_index: DefaultDict[tuple[str, str], list[tuple[int, int, str]]],
    feature_id: str,
    path_id: str,
    segment_index: int,
    vertex_index: Optional[int],
) -> Optional[str]:
    """Find the group whose selected spans cover a sample on a path.

    A sample sitting on a source vertex (``vertex_index`` given) matches a
    span inclusively at both ends. A sample inside a segment matches when the
    segment starts within ``[lo, hi)``. The group collecting the most
    matching spans wins; on a tie the group matched first wins. Returns None
    when the path has no spans or none of them matches.
    """
    spans = group_index.get((feature_id, path_id))
    if not spans:
        return None
    if vertex_index is not None:
        matching = (gid for lo, hi, gid in spans if lo <= vertex_index <= hi)
    else:
        matching = (gid for lo, hi, gid in spans if lo <= segment_index < hi)
    tally: Counter[str] = collections.Counter(matching)
    if not tally:
        return None
    return tally.most_common(1)[0][0]


def extract_young_crust_points(project: dict[str, Any]) -> list[tuple[float, float, str]]:
    """Extract young-crust coordinate cloud points from annotation layers and overlays.

    Two sources are scanned, in this order:
      * ``project["annotationLayers"]`` entries of type
        ``"youngContinentalCrust"`` (points under ``data.points``);
      * ``ee_reference_overlays`` in the source GeoJSON properties, when the
        overlay's id/name/kind text mentions both "young" and "continental"
        or its ``target.crustType`` is ``"continental"``.

    Points are validated (finite, latitude within [-90, 90]), longitudes are
    normalized, and duplicates are dropped after rounding to 6 decimals; the
    first occurrence wins. Containers that are ``null`` or of an unexpected
    JSON type are skipped rather than raising.

    Returns:
        ``(lon, lat, point_id)`` tuples in discovery order.
    """
    points: list[tuple[float, float, str]] = []
    # Keys are (lon, lat) rounded to 6 decimals.
    seen: set[tuple[float, float]] = set()

    def add(lon: Any, lat: Any, point_id: str) -> None:
        try:
            lon_f = normalize_lon(float(lon))
            lat_f = float(lat)
        except (TypeError, ValueError):
            return
        if not (math.isfinite(lon_f) and math.isfinite(lat_f) and -90.0 <= lat_f <= 90.0):
            return
        key = (round(lon_f, 6), round(lat_f, 6))
        if key in seen:
            return
        seen.add(key)
        points.append((lon_f, lat_f, point_id))

    for layer in project.get("annotationLayers") or []:
        # Same defensive type check the overlay loop below applies.
        if not isinstance(layer, dict) or layer.get("type") != "youngContinentalCrust":
            continue
        data = layer.get("data")
        layer_points = data.get("points") if isinstance(data, dict) else None
        if not isinstance(layer_points, (list, tuple)):
            continue
        for idx, point in enumerate(layer_points):
            if isinstance(point, dict):
                add(point.get("lon"), point.get("lat"), str(point.get("id") or f"young_layer_{idx}"))

    source = project.get("source")
    source_geojson = source.get("geojson") if isinstance(source, dict) else None
    props = source_geojson.get("properties") if isinstance(source_geojson, dict) else None
    overlays = props.get("ee_reference_overlays") if isinstance(props, dict) else None
    if not isinstance(overlays, (list, tuple)):
        overlays = []

    for overlay in overlays:
        if not isinstance(overlay, dict):
            continue
        overlay_id = str(overlay.get("id") or "reference_overlay")
        text = f"{overlay_id} {overlay.get('name') or ''} {overlay.get('kind') or ''}".lower()
        target = overlay.get("target") if isinstance(overlay.get("target"), dict) else {}
        is_young = (
            "young" in text and "continental" in text
        ) or target.get("crustType") == "continental"
        if not is_young:
            continue
        # The coordinate list may live under any of three keys; the first
        # non-empty one is used.
        coord_list = overlay.get("coords") or overlay.get("points") or overlay.get("coordinates") or []
        if not isinstance(coord_list, (list, tuple)):
            continue
        for idx, coord in enumerate(coord_list):
            if isinstance(coord, dict):
                add(
                    coord.get("lon", coord.get("longitude")),
                    coord.get("lat", coord.get("latitude")),
                    str(coord.get("id") or f"{overlay_id}_{idx}"),
                )
            elif isinstance(coord, (list, tuple)) and len(coord) >= 2:
                add(coord[0], coord[1], f"{overlay_id}_{idx}")

    return points


# -----------------------------------------------------------------------------
# Seed generation
# -----------------------------------------------------------------------------


def resample_path_units(
    coords: list[Any],
    *,
    spacing_km: float,
    radius_km: float,
    keep_source_vertices: bool,
) -> list[dict[str, Any]]:
    """Return resampled path points with segment/t provenance.

    The path is walked along great-circle segments and sampled at both ends
    and at every multiple of ``spacing_km`` in between; with
    ``keep_source_vertices`` every valid source vertex is sampled as well.

    Args:
        coords: Raw path coordinates. Entries rejected by lonlat_from_coord
            are dropped before any distance is measured.
        spacing_km: Target distance between samples; values below 1e-6 are
            raised to 1e-6.
        radius_km: Sphere radius used to turn angles into kilometres.
        keep_source_vertices: Also emit a sample at every valid source vertex.

    Returns:
        Samples ordered by distance along the path. Each is a dict with keys
        ``unit``, ``lon``, ``lat``, ``segmentIndex``, ``vertexIndex`` (None
        for samples strictly inside a segment), ``t`` (position within the
        segment, 0..1) and ``pathDistanceKm``. An empty list is returned when
        no coordinate is valid; a single sample when the path has one valid
        point or (numerically) zero length.

    NOTE(review): ``segmentIndex``/``vertexIndex`` count positions in the
    filtered list of valid coordinates. If any entry of ``coords`` was
    dropped they no longer equal positions in ``coords`` -- confirm that
    consumers comparing them with authored indices expect this.
    """
    lonlats: list[tuple[float, float]] = []
    units: list[np.ndarray] = []
    for coord in coords:
        lonlat = lonlat_from_coord(coord)
        if lonlat is None:
            continue
        lonlats.append(lonlat)
        units.append(lonlat_to_unit(*lonlat))

    if len(units) == 0:
        return []
    if len(units) == 1:
        lon, lat = lonlats[0]
        return [
            {
                "unit": units[0],
                "lon": lon,
                "lat": lat,
                "segmentIndex": 0,
                "vertexIndex": 0,
                "t": 0.0,
                "pathDistanceKm": 0.0,
            }
        ]

    # cumulative[i] is the distance from the path start to vertex i, so it
    # has one more entry than seg_lengths.
    seg_lengths: list[float] = []
    cumulative = [0.0]
    for a, b in zip(units[:-1], units[1:]):
        length = angular_distance_rad(a, b) * radius_km
        if not math.isfinite(length):
            length = 0.0
        seg_lengths.append(length)
        cumulative.append(cumulative[-1] + length)
    total = cumulative[-1]

    # All vertices coincide: collapse the path to its first point.
    if total <= 1e-9:
        lon, lat = lonlats[0]
        return [
            {
                "unit": units[0],
                "lon": lon,
                "lat": lat,
                "segmentIndex": 0,
                "vertexIndex": 0,
                "t": 0.0,
                "pathDistanceKm": 0.0,
            }
        ]

    # A set removes exactly repeated distances (for example a spacing
    # multiple that lands exactly on a vertex); nearly equal distances are
    # kept as separate samples.
    sample_distances: set[float] = {0.0, total}
    spacing = max(1e-6, float(spacing_km))
    n_steps = int(math.floor(total / spacing))
    for step in range(1, n_steps + 1):
        d = step * spacing
        if 0.0 < d < total:
            sample_distances.add(d)

    if keep_source_vertices:
        for d in cumulative:
            sample_distances.add(d)

    samples: list[dict[str, Any]] = []
    for d in sorted(sample_distances):
        # Find the segment containing distance d.
        if d >= total:
            # Path end: attribute the sample to the last segment at t = 1.
            seg_idx = len(seg_lengths) - 1
            t = 1.0
            vertex_idx: Optional[int] = len(units) - 1
            unit = units[-1]
        else:
            # bisect_right puts a distance that equals a vertex distance at
            # the start of the following segment (t == 0), and skips over any
            # zero-length segments sharing that distance.
            seg_idx = max(0, bisect.bisect_right(cumulative, d) - 1)
            seg_idx = min(seg_idx, len(seg_lengths) - 1)
            seg_start = cumulative[seg_idx]
            seg_len = seg_lengths[seg_idx]
            if seg_len <= 1e-12:
                t = 0.0
            else:
                t = clamp((d - seg_start) / seg_len, 0.0, 1.0)
            vertex_idx = None
            # Snap to the source vertex when t is within 1e-8 of an end so
            # the sample reuses the vertex's exact unit vector.
            if abs(t) < 1e-8:
                vertex_idx = seg_idx
                unit = units[seg_idx]
            elif abs(t - 1.0) < 1e-8:
                vertex_idx = seg_idx + 1
                unit = units[seg_idx + 1]
            else:
                unit = slerp_unit(units[seg_idx], units[seg_idx + 1], t)
        # lon/lat are recomputed from the unit vector, also for snapped
        # vertices, rather than taken from the parsed input coordinates.
        lon, lat = unit_to_lonlat(unit)
        samples.append(
            {
                "unit": unit,
                "lon": lon,
                "lat": lat,
                "segmentIndex": seg_idx,
                "vertexIndex": vertex_idx,
                "t": t,
                "pathDistanceKm": d,
            }
        )
    return samples


def fibonacci_sphere_points(n: int) -> list[np.ndarray]:
    """Generate ``n`` deterministic, near-uniform points on the unit sphere.

    Points follow a Fibonacci (golden-angle) spiral running from the north
    towards the south pole. Returns an empty list for ``n <= 0``.
    """
    count = int(n)
    if count <= 0:
        return []
    golden_angle = math.pi * (3.0 - math.sqrt(5.0))

    def lattice_point(i: int) -> np.ndarray:
        # The half-step offset keeps samples off the exact poles.
        z = 1.0 - 2.0 * (i + 0.5) / count
        ring_radius = math.sqrt(max(0.0, 1.0 - z * z))
        azimuth = golden_angle * i
        return np.array(
            [ring_radius * math.cos(azimuth), ring_radius * math.sin(azimuth), z],
            dtype=float,
        )

    return [lattice_point(i) for i in range(count)]


def cap_seed_points(
    seeds: list[SeedPoint],
    max_source_points: int,
    *,
    strategy: str = "balanced",
) -> tuple[list[SeedPoint], dict[str, Any]]:
    """Cap source-line seeds without letting one dense class dominate the mesh.

    Legacy behavior prioritized every grouped isochron seed over all normal
    source seeds. In Build1 that meant almost the entire 6,000-source cap was
    consumed by grouped isochron linework, starving continent/MOR/context seeds
    and making long/skinny Delaunay bands more likely. The balanced strategy
    preserves MORs and group-bearing isochrons, but reserves capacity for
    continental and ungrouped oceanic context.

    Args:
        seeds: Candidate source seeds, in authoring order.
        max_source_points: Upper bound on returned seeds; zero or negative
            disables the cap.
        strategy: ``"legacy"`` selects the old two-tier behavior; any other
            value selects the bucketed quota behavior.

    Returns:
        ``(selected, stats)``. When no capping is needed ``selected`` is the
        input list object itself. ``stats`` always carries ``capped``,
        ``strategy``, ``inputCount`` and ``outputCount`` plus
        strategy-specific counters.
    """
    if max_source_points <= 0 or len(seeds) <= max_source_points:
        return seeds, {
            "capped": False,
            "strategy": strategy,
            "inputCount": len(seeds),
            "outputCount": len(seeds),
        }

    if strategy == "legacy":
        high_priority: list[SeedPoint] = []
        normal: list[SeedPoint] = []
        for seed in seeds:
            if seed.is_mor or seed.group_id or seed.material == "mor":
                high_priority.append(seed)
            else:
                normal.append(seed)

        # High-priority seeds are kept in full when they fit; only the
        # remainder of the budget goes to normal seeds.
        if len(high_priority) >= max_source_points:
            selected = even_sample(high_priority, max_source_points)
        else:
            remaining = max_source_points - len(high_priority)
            selected = high_priority + even_sample(normal, remaining)

        return selected, {
            "capped": True,
            "strategy": strategy,
            "inputCount": len(seeds),
            "outputCount": len(selected),
            "highPriorityInputCount": len(high_priority),
            "normalInputCount": len(normal),
        }

    # Balanced strategy: each seed lands in exactly one bucket; the first
    # matching test below wins.
    buckets: dict[str, list[SeedPoint]] = {
        "mor": [],
        "continental": [],
        "groupedIsochron": [],
        "ungroupedOceanic": [],
        "other": [],
    }
    for seed in seeds:
        if seed.is_mor or seed.material == "mor":
            buckets["mor"].append(seed)
        elif seed.material in {"continental", "youngContinental"} or seed.kind == "continentalLine":
            buckets["continental"].append(seed)
        elif seed.group_id or seed.kind == "isochron":
            if seed.group_id:
                buckets["groupedIsochron"].append(seed)
            else:
                buckets["ungroupedOceanic"].append(seed)
        elif seed.material == "oceanic":
            buckets["ungroupedOceanic"].append(seed)
        else:
            buckets["other"].append(seed)

    # Quotas are intentionally soft. Any unused capacity is redistributed to
    # remaining buckets in geologically useful priority order.
    # The fixed floors (250/500) mean the quotas can add up to more than
    # max_source_points for small caps; the guard further down handles that.
    quotas = {
        "mor": max(250, int(max_source_points * 0.08)),
        "continental": max(500, int(max_source_points * 0.14)),
        "groupedIsochron": int(max_source_points * 0.58),
        "ungroupedOceanic": int(max_source_points * 0.14),
        "other": int(max_source_points * 0.06),
    }

    selected_by_bucket: dict[str, list[SeedPoint]] = {}
    selected_total = 0
    for name, bucket in buckets.items():
        take = min(len(bucket), quotas.get(name, 0))
        # even_sample hands back the bucket list itself when the whole bucket
        # is taken, so selected_by_bucket[name] may alias buckets[name].
        selected_by_bucket[name] = even_sample(bucket, take)
        selected_total += take

    remaining = max_source_points - selected_total
    if remaining > 0:
        # Seeds are tracked by object identity, not by value.
        already_ids = {id(seed) for chosen in selected_by_bucket.values() for seed in chosen}
        # Keep extra capacity focused on the authored kinematic substrate.
        for name in ("groupedIsochron", "mor", "continental", "ungroupedOceanic", "other"):
            if remaining <= 0:
                break
            leftovers = [seed for seed in buckets[name] if id(seed) not in already_ids]
            take = min(remaining, len(leftovers))
            extra = even_sample(leftovers, take)
            selected_by_bucket[name].extend(extra)
            for seed in extra:
                already_ids.add(id(seed))
            remaining -= take

    selected: list[SeedPoint] = []
    for name in ("mor", "continental", "groupedIsochron", "ungroupedOceanic", "other"):
        selected.extend(selected_by_bucket.get(name, []))

    # A final guard for any quota/rounding overshoot.
    # NOTE(review): when this guard trims the selection,
    # "categorySelectedCounts" below still reports the pre-trim per-bucket
    # counts and therefore sums to more than "outputCount".
    if len(selected) > max_source_points:
        selected = even_sample(selected, max_source_points)

    return selected, {
        "capped": True,
        "strategy": strategy,
        "inputCount": len(seeds),
        "outputCount": len(selected),
        "categoryInputCounts": {name: len(bucket) for name, bucket in buckets.items()},
        "categorySelectedCounts": {name: len(chosen) for name, chosen in selected_by_bucket.items()},
        "unusedCapacityAfterRedistribution": max(0, max_source_points - len(selected)),
    }


def even_sample(items: list[SeedPoint], n: int) -> list[SeedPoint]:
    """Pick ``n`` items spread evenly over ``items``, keeping their order.

    The first and last item are always part of a sample of two or more.
    When ``items`` has at most ``n`` entries the input list object itself is
    returned (not a copy); ``n <= 0`` yields an empty list.
    """
    total = len(items)
    if n <= 0:
        return []
    if total <= n:
        return items
    if n == 1:
        return [items[0]]
    stride = (total - 1) / (n - 1)
    picked = {int(round(k * stride)) for k in range(n)}
    ordered = sorted(picked)
    # Safety net: if rounding ever merged two positions, top up with the
    # lowest unused indices so exactly n items come back.
    for candidate in range(total):
        if len(ordered) >= n:
            break
        if candidate not in picked:
            picked.add(candidate)
            ordered.append(candidate)
    ordered = sorted(ordered[:n])
    return [items[position] for position in ordered]


def downsample_young_points(
    young_points: list[tuple[float, float, str]],
    max_young_seed_points: int,
) -> list[tuple[float, float, str]]:
    """Thin ``young_points`` to at most ``max_young_seed_points`` entries.

    Args:
        young_points: ``(lon, lat, point_id)`` tuples in their original order.
        max_young_seed_points: Upper bound on the number of returned points.
            A value ``<= 0`` disables the cap.

    Returns:
        The input list itself when no thinning is needed; otherwise a new
        list of evenly spaced entries (first and last included) in the
        original order.
    """
    total = len(young_points)
    if max_young_seed_points <= 0 or total <= max_young_seed_points:
        return young_points
    if max_young_seed_points == 1:
        # A single slot cannot be spread across the list; the stride formula
        # below would divide by zero, so keep the first point (this mirrors
        # even_sample's n == 1 behaviour).
        return [young_points[0]]
    step = (total - 1) / (max_young_seed_points - 1)
    indices = sorted({int(round(i * step)) for i in range(max_young_seed_points)})
    return [young_points[i] for i in indices]


def build_seed_points(
    project: dict[str, Any],
    scene: dict[str, Any],
    args: argparse.Namespace,
) -> tuple[list[SeedPoint], dict[str, Any], list[tuple[float, float, str]]]:
    """Collect every seed point that will feed the mesh, in a fixed order.

    Seeds are appended in this order:
      1. resampled source GeoJSON linework (capped by ``cap_seed_points``),
      2. zipper-control anchors from ``scene["kinematicLinks"]``,
      3. young-crust points (only when ``args.include_young_crust_seeds``),
      4. background points from ``fibonacci_sphere_points``.

    Only the source linework is subject to ``args.max_source_points``; the
    zipper, young-crust and background seeds are added after the cap.

    Args:
        project: Project dict; ``project["source"]["geojson"]["features"]``
            is read, and the whole dict is passed to
            ``extract_young_crust_points``.
        scene: Scene dict; ``hiddenFeatureIds`` and ``kinematicLinks`` are
            read, and the dict is passed to ``build_group_segment_index``.
        args: Parsed CLI options. Reads ``include_hidden_source``,
            ``target_spacing_km``, ``radius_km``, ``keep_source_vertices``,
            ``max_source_points``, ``source_cap_strategy``,
            ``max_young_seed_points``, ``include_young_crust_seeds`` and
            ``background_points``.

    Returns:
        ``(seeds, stats, young_points)``, where ``young_points`` is the full
        list returned by ``extract_young_crust_points``, not the downsampled
        subset that was turned into seeds.
    """
    source_geojson = project.get("source", {}).get("geojson") or {}
    features = source_geojson.get("features") or []
    group_index = build_group_segment_index(scene)
    hidden_feature_ids = set(scene.get("hiddenFeatureIds") or [])

    source_seeds: list[SeedPoint] = []
    # Ingestion counters; "materialFeatureCounts" is a Counter while features
    # are scanned and is converted to a plain dict just before returning.
    stats: dict[str, Any] = {
        "featureCount": len(features),
        "skippedHiddenFeatures": 0,
        "skippedUnsupportedGeometries": 0,
        "sourcePathCount": 0,
        "sourceSeedCountBeforeCap": 0,
        "sourceSeedCap": None,
        "zipperSeedCount": 0,
        "youngCrustSeedCount": 0,
        "backgroundSeedCount": 0,
        "invalidCoordinateCount": 0,
        "materialFeatureCounts": collections.Counter(),
    }

    # --- 1. Source GeoJSON linework -----------------------------------------
    for feature_index, feature in enumerate(features):
        if not isinstance(feature, dict):
            continue
        feature_id = source_feature_id(project, feature_index)
        # Hidden features are dropped unless the caller explicitly opts in.
        if feature_id in hidden_feature_ids and not args.include_hidden_source:
            stats["skippedHiddenFeatures"] += 1
            continue
        props = feature.get("properties") or {}
        if not isinstance(props, dict):
            props = {}
        material = classify_feature_material(props)
        stats["materialFeatureCounts"][material] += 1
        age_ma = coerce_age_ma(props)
        is_mor = is_mor_feature(props)
        any_path = False
        for path_id, coords in iter_feature_paths(feature):
            any_path = True
            stats["sourcePathCount"] += 1
            samples = resample_path_units(
                coords,
                spacing_km=args.target_spacing_km,
                radius_km=args.radius_km,
                keep_source_vertices=args.keep_source_vertices,
            )
            for sample in samples:
                segment_index = int(sample["segmentIndex"])
                vertex_index = sample["vertexIndex"]
                # Only integer vertex indices are forwarded to the group
                # lookup; anything else is passed as None.
                group_id = group_for_sample(
                    group_index,
                    feature_id,
                    path_id,
                    segment_index,
                    vertex_index if isinstance(vertex_index, int) else None,
                )
                # Provenance record stored on the seed; optional keys are only
                # added when they carry information.
                ref: dict[str, Any] = {
                    "kind": "sourceGeoJSON",
                    "featureId": feature_id,
                    "pathId": path_id,
                    "segmentIndex": segment_index,
                    "t": round_float(sample["t"], 6),
                    "pathDistanceKm": round_float(sample["pathDistanceKm"], 3),
                }
                if vertex_index is not None:
                    ref["vertexIndex"] = int(vertex_index)
                if age_ma is not None:
                    ref["ageMa"] = age_ma
                if group_id:
                    ref["groupId"] = group_id
                if is_mor:
                    ref["isMor"] = True
                # MOR wins over age; a line with an age but no MOR flag is an
                # isochron; everything else is treated as continental linework.
                seed_kind = "mor" if is_mor else ("isochron" if age_ma is not None else "continentalLine")
                source_seeds.append(
                    SeedPoint(
                        unit=sample["unit"],
                        lon=float(sample["lon"]),
                        lat=float(sample["lat"]),
                        kind=seed_kind,
                        material=material,
                        age_ma=age_ma,
                        group_id=group_id,
                        is_mor=is_mor,
                        source_refs=[ref],
                    )
                )
        # A feature that yielded no paths at all is counted as unsupported.
        if not any_path:
            stats["skippedUnsupportedGeometries"] += 1

    stats["sourceSeedCountBeforeCap"] = len(source_seeds)
    source_seeds, cap_stats = cap_seed_points(source_seeds, args.max_source_points, strategy=args.source_cap_strategy)
    stats["sourceSeedCap"] = cap_stats
    seeds = list(source_seeds)

    # --- 2. Zipper controls -------------------------------------------------
    # Add zipper controls as explicit anchors. They are later snapped to final
    # mesh vertices for solver use.
    kinematic = scene.get("kinematicLinks") or {}
    for zipper in kinematic.get("zipperControls") or []:
        if not isinstance(zipper, dict):
            continue
        # Each zipper may reference up to three points; every tuple is
        # (key in the zipper dict, seed kind, material, MOR flag).
        for ref_name, kind, material, is_mor in (
            ("fromIsochron", "zipperIsochron", "oceanic", False),
            ("toMor", "zipperMor", "mor", True),
            ("oppositeIsochron", "zipperOppositeIsochron", "oceanic", False),
        ):
            raw_ref = zipper.get(ref_name)
            if not isinstance(raw_ref, dict):
                continue
            lonlat = raw_ref.get("lonLat")
            if not (isinstance(lonlat, list) and len(lonlat) >= 2):
                continue
            parsed = lonlat_from_coord(lonlat)
            if parsed is None:
                continue
            lon, lat = parsed
            # The per-reference age takes precedence over the zipper-level age.
            age = raw_ref.get("ageMa", zipper.get("ageMa"))
            try:
                age_ma = float(age) if age is not None and math.isfinite(float(age)) else None
            except (TypeError, ValueError):
                age_ma = None
            group_id = raw_ref.get("groupId") or zipper.get("groupId")
            source_ref = {
                "kind": kind,
                "zipperControlId": zipper.get("id"),
                "refName": ref_name,
                "featureId": raw_ref.get("featureId"),
                "pathId": raw_ref.get("pathId"),
                "vertexIndex": raw_ref.get("vertexIndex"),
                "groupId": group_id,
                "ageMa": age_ma,
                "isMor": is_mor,
            }
            seeds.append(
                SeedPoint(
                    unit=lonlat_to_unit(lon, lat),
                    lon=lon,
                    lat=lat,
                    kind=kind,
                    material=material,
                    age_ma=age_ma,
                    # The seed only keeps string group ids; the raw value is
                    # still recorded in source_ref above.
                    group_id=group_id if isinstance(group_id, str) else None,
                    is_mor=is_mor,
                    source_refs=[source_ref],
                )
            )
            stats["zipperSeedCount"] += 1

    # --- 3. Young continental crust ------------------------------------------
    # The full point list is always extracted (and returned); only the
    # downsampled subset is turned into seeds, and only when requested.
    young_points = extract_young_crust_points(project)
    young_seed_points = downsample_young_points(young_points, args.max_young_seed_points)
    if args.include_young_crust_seeds:
        for lon, lat, point_id in young_seed_points:
            seeds.append(
                SeedPoint(
                    unit=lonlat_to_unit(lon, lat),
                    lon=lon,
                    lat=lat,
                    kind="youngContinentalCrust",
                    material="youngContinental",
                    age_ma=None,
                    group_id=None,
                    is_mor=False,
                    source_refs=[
                        {
                            "kind": "youngContinentalCrust",
                            "pointId": point_id,
                        }
                    ],
                )
            )
            stats["youngCrustSeedCount"] += 1

    # --- 4. Background coverage ----------------------------------------------
    background = fibonacci_sphere_points(args.background_points)
    for idx, unit in enumerate(background):
        lon, lat = unit_to_lonlat(unit)
        seeds.append(
            SeedPoint(
                unit=unit,
                lon=lon,
                lat=lat,
                kind="background",
                material="background",
                source_refs=[{"kind": "background", "index": idx}],
            )
        )
    stats["backgroundSeedCount"] = len(background)
    # Replace the Counter with a plain dict before handing the stats back.
    stats["materialFeatureCounts"] = dict(stats["materialFeatureCounts"])

    return seeds, stats, young_points


# -----------------------------------------------------------------------------
# Deduplication and attribute assignment
# -----------------------------------------------------------------------------


def dedupe_seed_points(
    seeds: list[SeedPoint],
    *,
    dedupe_km: float,
    radius_km: float,
    max_source_refs_per_vertex: int,
) -> tuple[list[VertexRecord], dict[str, Any]]:
    """Merge near-coincident seeds into vertex records.

    Seeds are visited in input order. Each unclaimed seed claims every other
    unclaimed seed within the chord tolerance derived from ``dedupe_km``; the
    cluster becomes one ``VertexRecord`` placed at the renormalised mean of
    its members. Clustering is greedy and single-pass, not transitive.

    Args:
        seeds: Seed points to merge; must not be empty.
        dedupe_km: Merge distance, converted to a chord via ``km_to_chord``.
        radius_km: Sphere radius used for that conversion.
        max_source_refs_per_vertex: Maximum provenance refs kept per vertex;
            the remainder is only counted.

    Returns:
        ``(records, stats)``: the merged vertices and a summary dict.

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    if not seeds:
        raise ValueError("No seed points were generated")

    seed_total = len(seeds)
    points = np.vstack([safe_unit(seed.unit) for seed in seeds])
    tolerance_chord = km_to_chord(dedupe_km, radius_km)
    tree = cKDTree(points)
    claimed = np.zeros(seed_total, dtype=bool)
    ref_limit = max(0, max_source_refs_per_vertex)

    records: list[VertexRecord] = []
    multi_member_sizes: list[int] = []

    for anchor in range(seed_total):
        if claimed[anchor]:
            continue

        # Greedy ball query around the anchor: with a conservative tolerance
        # this removes duplicates without chaining nearby-but-distinct line
        # vertices together.
        cluster = [
            j
            for j in tree.query_ball_point(points[anchor], tolerance_chord)
            if not claimed[j]
        ] or [anchor]
        claimed[cluster] = True
        if len(cluster) > 1:
            multi_member_sizes.append(len(cluster))

        members = [seeds[j] for j in cluster]
        merged_unit = safe_unit(np.mean(points[cluster], axis=0))
        lon, lat = unit_to_lonlat(merged_unit)

        kinds: Counter[str] = collections.Counter(m.kind for m in members)
        materials: Counter[str] = collections.Counter(m.material for m in members)
        groups: Counter[str] = collections.Counter(
            m.group_id for m in members if m.group_id
        )
        ages = [
            float(m.age_ma)
            for m in members
            if m.age_ma is not None and math.isfinite(float(m.age_ma))
        ]

        # Keep the first ``ref_limit`` refs in member order; count the rest.
        every_ref = [ref for m in members for ref in m.source_refs]
        kept_refs = every_ref[:ref_limit]

        records.append(
            VertexRecord(
                id=f"v_{len(records):06d}",
                unit=merged_unit,
                lon=lon,
                lat=lat,
                seed_kinds=kinds,
                materials=materials,
                ages=ages,
                groups=groups,
                source_refs=kept_refs,
                source_ref_overflow_count=len(every_ref) - len(kept_refs),
                is_mor=any(m.is_mor for m in members),
            )
        )

    stats = {
        "inputSeedCount": seed_total,
        "outputVertexCount": len(records),
        "removedDuplicateSeedCount": seed_total - len(records),
        "duplicateGroupCount": len(multi_member_sizes),
        "maxDuplicateGroupSize": max(multi_member_sizes, default=1),
        "dedupeKm": dedupe_km,
        "dedupeChordTolerance": tolerance_chord,
    }
    return records, stats


def choose_material(counter: Counter[str]) -> str:
    """Return the winning material label from a vote counter.

    The label with the most votes wins. Ties among known labels are broken
    by a fixed priority (``mor`` first, ``unknown`` last). If none of the
    tied labels is known, the counter's own most-common entry is returned.
    An empty counter yields ``"unknown"``.
    """
    if not counter:
        return "unknown"

    # MOR should remain recognizable, but most actual cells around MORs will be
    # marked morInfluenced rather than pure mor at the triangle level.
    tie_break_order = (
        "mor",
        "youngContinental",
        "continental",
        "oceanic",
        "background",
        "unknown",
    )
    top_votes = max(counter.values())
    leaders = {label for label, votes in counter.items() if votes == top_votes}

    preferred = next((label for label in tie_break_order if label in leaders), None)
    if preferred is not None:
        return preferred
    (fallback_label, _votes), = counter.most_common(1)
    return fallback_label


def assign_young_crust_weights(
    vertices: list[VertexRecord],
    young_points: list[tuple[float, float, str]],
    *,
    radius_km: float,
    sigma_km: float,
    cutoff_km: float,
) -> dict[str, Any]:
    """Set ``young_continental_weight`` on every vertex and summarise it.

    The weight is a Gaussian of the distance to the nearest young-crust
    point, with standard deviation ``sigma_km``, and is zero beyond
    ``cutoff_km``. When there are no young points, or ``sigma_km`` is not
    positive, vertices are left untouched and only the short summary is
    returned.

    Returns:
        A summary dict; ``min``/``median``/``max`` are present only when
        weights were actually assigned.
    """
    summary: dict[str, Any] = {
        "youngPointCount": len(young_points),
        "sigmaKm": sigma_km,
        "cutoffKm": cutoff_km,
        "assignedNonzeroCount": 0,
    }
    if not young_points or sigma_km <= 0.0:
        return summary

    tree = cKDTree(np.vstack([lonlat_to_unit(lon, lat) for lon, lat, _ in young_points]))

    weights: list[float] = []
    for vertex in vertices:
        nearest_chord, _ = tree.query(vertex.unit, k=1)
        distance_km = chord_to_angle_rad(float(nearest_chord)) * radius_km
        # Gaussian radial basis, interpreted as a soft material field; hard
        # zero past the cutoff.
        weight = (
            0.0
            if distance_km > cutoff_km
            else math.exp(-0.5 * (distance_km / sigma_km) ** 2)
        )
        vertex.young_continental_weight = float(weight)
        weights.append(float(weight))

    arr = np.array(weights, dtype=float)
    has_values = bool(len(arr))
    summary["assignedNonzeroCount"] = sum(1 for w in weights if w > 0.0)
    summary["min"] = float(np.min(arr)) if has_values else 0.0
    summary["median"] = float(np.median(arr)) if has_values else 0.0
    summary["max"] = float(np.max(arr)) if has_values else 0.0
    return summary


# -----------------------------------------------------------------------------
# Mesh construction
# -----------------------------------------------------------------------------


def build_hull(points: np.ndarray, qhull_options: str) -> tuple[ConvexHull, str]:
    """Build the convex hull, retrying once with joggled-input options.

    Returns the hull together with the option string that actually produced
    it. If Qhull fails with the requested options, the build is retried with
    ``"QJ Qc Pp"``. If those were already the requested options, the
    original error propagates.
    """
    fallback_options = "QJ Qc Pp"
    try:
        hull = ConvexHull(points, qhull_options=qhull_options)
    except QhullError:
        if qhull_options.strip() == fallback_options:
            raise
        return ConvexHull(points, qhull_options=fallback_options), fallback_options
    return hull, qhull_options


def orient_triangle_outward(simplex: Iterable[int], points: np.ndarray) -> list[int]:
    """Return the simplex indices ordered so the face normal points outward.

    The normal of the first three vertices is compared with the direction of
    their summed position vectors. If they disagree, the second and third
    indices are swapped.
    """
    ordered = [int(index) for index in simplex]
    p0 = points[ordered[0]]
    p1 = points[ordered[1]]
    p2 = points[ordered[2]]

    face_normal = np.cross(p1 - p0, p2 - p0)
    outward_dir = p0 + p1 + p2
    faces_inward = float(np.dot(face_normal, outward_dir)) < 0.0
    if faces_inward:
        ordered[1], ordered[2] = ordered[2], ordered[1]
    return ordered


def build_source_attribute_tree(vertices: list[VertexRecord]) -> tuple[Optional[cKDTree], list[int]]:
    """Index the vertices that carry authored attributes.

    A vertex qualifies when its material is something other than
    ``background``/``unknown``, or it has an age, or it has a group id.

    Returns:
        ``(tree, source_indices)``: a KD-tree over the qualifying unit
        vectors and their positions in ``vertices``, or ``(None, [])`` when
        nothing qualifies.
    """
    uninformative = {"background", "unknown"}
    source_indices = [
        position
        for position, vertex in enumerate(vertices)
        if vertex.material not in uninformative
        or vertex.age_ma is not None
        or vertex.group_id
    ]
    if not source_indices:
        return None, []
    stacked_units = np.vstack([vertices[position].unit for position in source_indices])
    return cKDTree(stacked_units), source_indices


def nearest_source_vertex(
    unit: np.ndarray,
    tree: Optional[cKDTree],
    source_indices: list[int],
    vertices: list[VertexRecord],
    radius_km: float,
    max_km: float,
) -> Optional[VertexRecord]:
    """Look up the attributed vertex closest to ``unit``.

    Returns ``None`` when there is no tree, when ``source_indices`` is
    empty, or when the nearest hit is farther than ``max_km``. ``tree``
    and ``source_indices`` are expected to be the pair produced by
    ``build_source_attribute_tree``.
    """
    if not (tree is not None and source_indices):
        return None
    nearest_chord, tree_position = tree.query(unit, k=1)
    separation_km = chord_to_angle_rad(float(nearest_chord)) * radius_km
    if separation_km > max_km:
        return None
    vertex_position = source_indices[int(tree_position)]
    return vertices[vertex_position]


def mode_counter(values: Iterable[Optional[str]]) -> Optional[str]:
    """Return the most frequent non-empty string in ``values``.

    Non-string and empty entries are ignored. Returns ``None`` when nothing
    remains. Ties go to the value seen first.
    """
    tally: Counter[str] = collections.Counter()
    for value in values:
        if isinstance(value, str) and value:
            tally[value] += 1
    if tally:
        (winner, _occurrences), = tally.most_common(1)
        return winner
    return None


def median_optional(values: Iterable[Optional[float]]) -> Optional[float]:
    """Median of the finite entries in ``values``.

    ``None`` entries and non-finite numbers are dropped first. Returns
    ``None`` when nothing remains.
    """
    finite: list[float] = []
    for value in values:
        if value is None:
            continue
        number = float(value)
        if math.isfinite(number):
            finite.append(number)
    if finite:
        return float(np.median(np.array(finite, dtype=float)))
    return None


def triangle_cell_type(materials: list[str], is_mor: bool, nearest_material: Optional[str]) -> str:
    """Classify a triangle from its vertex materials and nearest source.

    Precedence:
      1. any MOR involvement -> ``"morInfluenced"``;
      2. a material held by at least two vertices, checked in the order
         young continental, continental, oceanic;
      3. the nearest attributed material, with ``mor`` mapped to
         ``"morInfluenced"``;
      4. the ``choose_material`` vote over the vertex materials.
    """
    votes = collections.Counter(materials)
    if is_mor:
        return "morInfluenced"

    # Continental/young continental should win when at least two triangle
    # vertices carry that material, because these cells are likely near outlines.
    for dominant in ("youngContinental", "continental", "oceanic"):
        if votes[dominant] >= 2:
            return dominant

    nearest_is_informative = bool(nearest_material) and nearest_material not in {
        "background",
        "unknown",
    }
    if nearest_is_informative:
        return "morInfluenced" if nearest_material == "mor" else nearest_material
    return choose_material(votes)


def build_mesh(
    project: dict[str, Any],
    scene: dict[str, Any],
    vertices: list[VertexRecord],
    args: argparse.Namespace,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Triangulate the vertices and assemble the mesh and diagnostics payloads.

    Steps, in order:
      1. Build the convex hull of the vertex unit vectors (``build_hull``).
      2. For every hull simplex, orient it outward, compute quality metrics
         and area, and derive ``cellType`` / ``ageMa`` / ``groupId`` from the
         triangle's vertices, falling back to the nearest attributed vertex
         within ``args.attribute_search_km`` of the centroid.
      3. Fill triangle adjacency from ``hull.neighbors``.
      4. Optionally compute the spherical Voronoi dual
         (``args.include_voronoi``); a failure is recorded, not raised.
      5. Serialise vertices, snap constraint anchors, and build the ``mesh``
         and ``diagnostics`` dicts.

    Args:
        project: Project dict; ``project["project"]`` and ``project["source"]``
            supply provenance metadata.
        scene: Scene dict; ``timeMa`` is read here and the dict is passed to
            ``build_constraint_anchors``.
        vertices: Deduplicated vertex records to triangulate.
        args: Parsed CLI options. Reads ``qhull_options``, ``radius_km``,
            ``attribute_search_km``, ``include_voronoi`` and
            ``voronoi_threshold``; the whole namespace is echoed via
            ``serializable_args``.

    Returns:
        ``(mesh, diagnostics)``. The anchors' ``diagnostics`` sub-dict is
        reported in ``diagnostics`` and removed from the mesh copy.
    """
    points = np.vstack([v.unit for v in vertices])
    hull, qhull_options_used = build_hull(points, args.qhull_options)

    source_tree, source_indices = build_source_attribute_tree(vertices)
    triangles: list[dict[str, Any]] = []
    # Flat accumulators across all triangles, summarised for diagnostics below.
    edge_lengths_all: list[float] = []
    area_all: list[float] = []
    min_angles_all: list[float] = []
    aspect_ratios_all: list[float] = []

    for tri_idx, simplex in enumerate(hull.simplices):
        tri_indices = orient_triangle_outward(simplex, points)
        tri_vertices = [vertices[i] for i in tri_indices]
        a, b, c = [v.unit for v in tri_vertices]
        centroid_unit = safe_unit(a + b + c)
        centroid_lon, centroid_lat = unit_to_lonlat(centroid_unit)
        tri_quality = triangle_quality_metrics(a, b, c, args.radius_km)
        edge_lengths = tri_quality["edgeLengthsKm"]
        edge_lengths_all.extend(edge_lengths)
        min_angles_all.append(float(tri_quality["minAngleDeg"]))
        aspect_ratios_all.append(float(tri_quality["aspectRatio"]))
        area_sr = spherical_triangle_area_sr(a, b, c)
        area_km2 = area_sr * args.radius_km * args.radius_km
        area_all.append(area_km2)

        # Nearest attributed vertex to the centroid; used only as a fallback
        # when the triangle's own vertices carry no age/group/material signal.
        nearest = nearest_source_vertex(
            centroid_unit,
            source_tree,
            source_indices,
            vertices,
            args.radius_km,
            args.attribute_search_km,
        )
        nearest_material = nearest.material if nearest else None
        age_ma = median_optional([v.age_ma for v in tri_vertices])
        if age_ma is None and nearest and nearest.age_ma is not None:
            age_ma = nearest.age_ma
        group_id = mode_counter([v.group_id for v in tri_vertices])
        if group_id is None and nearest and nearest.group_id:
            group_id = nearest.group_id
        # A single MOR vertex is enough to flag the whole triangle.
        is_mor = any(v.is_mor or v.material == "mor" for v in tri_vertices)
        materials = [v.material for v in tri_vertices]
        cell_type = triangle_cell_type(materials, is_mor, nearest_material)
        young_weight = float(np.mean([v.young_continental_weight for v in tri_vertices]))

        triangles.append(
            {
                "id": f"tri_{tri_idx:06d}",
                "v": [v.id for v in tri_vertices],
                "vi": tri_indices,
                "adjacent": [],  # filled after id assignment
                "centroidLonLat": [round_float(centroid_lon, 8), round_float(centroid_lat, 8)],
                "areaSteradians": round_float(area_sr, 12),
                "areaKm2AtR0": round_float(area_km2, 6),
                "edgeLengthsKmAtR0": [round_float(x, 4) for x in edge_lengths],
                "minAngleDeg": round_float(tri_quality["minAngleDeg"], 6),
                "maxAngleDeg": round_float(tri_quality["maxAngleDeg"], 6),
                "aspectRatio": round_float(tri_quality["aspectRatio"], 6),
                "cellType": cell_type,
                "ageMa": round_float(age_ma, 6) if age_ma is not None else None,
                "groupId": group_id,
                "youngContinentalWeight": round_float(young_weight, 6),
                "materialVotes": dict(collections.Counter(materials)),
                "nearestSourceVertexId": nearest.id if nearest else None,
            }
        )

    # ConvexHull.neighbors is aligned to hull.simplices. Convert neighbor indices
    # to triangle ids. -1 should not occur for a full spherical hull, but keep it
    # defensively.
    for tri_idx, neighbors in enumerate(hull.neighbors):
        triangles[tri_idx]["adjacent"] = [
            f"tri_{int(n):06d}" for n in neighbors if int(n) >= 0
        ]

    # Voronoi dual: best-effort. Any failure is captured into the payload and
    # the diagnostics so the triangle mesh is still written.
    voronoi_payload: dict[str, Any]
    voronoi_stats: dict[str, Any]
    if args.include_voronoi:
        try:
            sv = SphericalVoronoi(
                points,
                radius=1.0,
                center=np.array([0.0, 0.0, 0.0]),
                threshold=args.voronoi_threshold,
            )
            sv.sort_vertices_of_regions()
            voronoi_payload = {
                "status": "ok",
                "radius": 1.0,
                "center": [0.0, 0.0, 0.0],
                "vertices": [[round_float(x, 10) for x in row] for row in sv.vertices.tolist()],
                "regions": [[int(i) for i in region] for region in sv.regions],
            }
            region_lengths = [len(r) for r in sv.regions]
            voronoi_stats = {
                "status": "ok",
                "voronoiVertexCount": int(len(sv.vertices)),
                "regionCount": int(len(sv.regions)),
                "regionVertexCountMin": int(min(region_lengths)) if region_lengths else 0,
                "regionVertexCountMedian": float(np.median(region_lengths)) if region_lengths else 0.0,
                "regionVertexCountMax": int(max(region_lengths)) if region_lengths else 0,
            }
        except Exception as exc:  # noqa: BLE001 - recorded in diagnostics
            voronoi_payload = {
                "status": "failed",
                "error": f"{type(exc).__name__}: {exc}",
            }
            # Separate copy so the mesh payload and diagnostics do not alias.
            voronoi_stats = dict(voronoi_payload)
    else:
        voronoi_payload = {"status": "not_requested"}
        voronoi_stats = {"status": "not_requested"}

    vertex_payload = [vertex_to_json(v) for v in vertices]
    constraint_anchors = build_constraint_anchors(project, scene, vertices, args.radius_km)

    project_meta = project.get("project") or {}
    # The mesh id is derived from project identity, scene time and the
    # vertex/triangle counts (see make_mesh_id).
    mesh_id = make_mesh_id(project_meta, scene.get("timeMa", 0), len(vertices), len(triangles))
    # UTC timestamp with a trailing "Z" instead of "+00:00".
    generated_at = _dt.datetime.now(tz=_dt.timezone.utc).isoformat().replace("+00:00", "Z")

    mesh = {
        "schemaVersion": SCHEMA_VERSION,
        "meshId": mesh_id,
        "generatedAtUtc": generated_at,
        "sourceProject": {
            "id": project_meta.get("id"),
            "name": project_meta.get("name"),
            "updatedAtUtc": project_meta.get("updatedAtUtc"),
            "sourceName": project_meta.get("sourceName") or project.get("source", {}).get("sourceName"),
            "sourceId": project_meta.get("sourceId") or project.get("source", {}).get("sourceId"),
            "sceneMa": scene.get("timeMa", 0),
        },
        "earthModel": {
            "mathSpace": "unit-vector-sphere-for-topology",
            "renderSpace": "lonlat-unit-sphere-viewer",
            "radiusKm0": args.radius_km,
            "note": "This file builds the 0 Ma mesh substrate only. Reverse-time shrinking-radius embedding belongs in ee_solve_step.py.",
        },
        "parameters": serializable_args(args),
        "counts": {
            "vertices": len(vertices),
            "triangles": len(triangles),
            "convexHullSimplices": int(len(hull.simplices)),
            "convexHullCoplanarCount": int(len(getattr(hull, "coplanar", []))),
        },
        "vertices": vertex_payload,
        "triangles": triangles,
        "voronoi": voronoi_payload,
        "constraintAnchors": constraint_anchors,
    }

    quality_stats = summarize_quality(edge_lengths_all, area_all, min_angles_all, aspect_ratios_all)
    cell_counts = collections.Counter(t["cellType"] for t in triangles)
    age_values = [t["ageMa"] for t in triangles if t.get("ageMa") is not None]
    # NOTE: "counts" below is the same dict object as mesh["counts"].
    diagnostics = {
        "meshId": mesh_id,
        "generatedAtUtc": generated_at,
        "qhullOptionsRequested": args.qhull_options,
        "qhullOptionsUsed": qhull_options_used,
        "counts": mesh["counts"],
        "cellTypeCounts": dict(cell_counts),
        "triangleAgeMa": summarize_numeric(age_values),
        "edgeLengthKmAtR0": quality_stats["edgeLengthKmAtR0"],
        "triangleAreaKm2AtR0": quality_stats["triangleAreaKm2AtR0"],
        "triangleMinAngleDeg": quality_stats["triangleMinAngleDeg"],
        "triangleAspectRatio": quality_stats["triangleAspectRatio"],
        "voronoi": voronoi_stats,
        "constraintAnchors": constraint_anchors.get("diagnostics", {}),
        "warnings": build_warnings(vertices, triangles, quality_stats, voronoi_stats),
    }

    # Remove duplicated diagnostics subobject from mesh constraint anchors.
    if "diagnostics" in mesh["constraintAnchors"]:
        mesh["constraintAnchors"].pop("diagnostics", None)

    return mesh, diagnostics


def vertex_to_json(vertex: VertexRecord) -> dict[str, Any]:
    """Serialise one vertex record into its JSON-ready dict form.

    Coordinates, age and weight are rounded with ``round_float``. Counters
    are emitted as plain dicts.
    """
    age = vertex.age_ma
    rounded_age = round_float(age, 6) if age is not None else None
    rounded_unit = [round_float(component, 10) for component in vertex.unit.tolist()]

    payload: dict[str, Any] = {}
    payload["id"] = vertex.id
    payload["lon"] = round_float(vertex.lon, 8)
    payload["lat"] = round_float(vertex.lat, 8)
    payload["unit"] = rounded_unit
    payload["seedKinds"] = dict(vertex.seed_kinds)
    payload["material"] = vertex.material
    payload["materialVotes"] = dict(vertex.materials)
    payload["ageMa"] = rounded_age
    payload["groupId"] = vertex.group_id
    payload["isMor"] = bool(vertex.is_mor)
    payload["youngContinentalWeight"] = round_float(vertex.young_continental_weight, 6)
    payload["sourceRefs"] = vertex.source_refs
    payload["sourceRefOverflowCount"] = int(vertex.source_ref_overflow_count)
    return payload


def make_mesh_id(project_meta: dict[str, Any], scene_ma: Any, vertex_count: int, triangle_count: int) -> str:
    """Derive a short mesh id from project identity and mesh size.

    The id is ``"mesh_"`` followed by the first 12 hex digits of the SHA-1
    of the pipe-joined project id, project update stamp, scene time, vertex
    count and triangle count.
    """
    fingerprint = "{}|{}|{}|{}|{}".format(
        project_meta.get("id", "project"),
        project_meta.get("updatedAtUtc", ""),
        scene_ma,
        vertex_count,
        triangle_count,
    )
    short_hash = hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()[:12]
    return "mesh_" + short_hash


def serializable_args(args: argparse.Namespace) -> dict[str, Any]:
    """Return the JSON-safe subset of the parsed arguments.

    ``Path`` values become strings. ``str``/``int``/``float``/``bool`` and
    ``None`` pass through unchanged. Every other value is dropped.
    """
    passthrough_types = (str, int, float, bool)
    result: dict[str, Any] = {}
    for name, raw in vars(args).items():
        if isinstance(raw, Path):
            result[name] = str(raw)
            continue
        if raw is None or isinstance(raw, passthrough_types):
            result[name] = raw
    return result


def summarize_numeric(values: Iterable[Any]) -> dict[str, Any]:
    """Summarise the finite numbers in ``values``.

    ``None`` and non-finite entries are ignored. Returns ``{"count": 0}``
    when nothing remains, otherwise count, min, median, mean, 95th
    percentile and max as plain Python numbers.
    """
    finite = [
        number
        for number in (float(v) for v in values if v is not None)
        if math.isfinite(number)
    ]
    data = np.array(finite, dtype=float)
    if data.size == 0:
        return {"count": 0}

    summary: dict[str, Any] = {"count": int(data.size)}
    summary["min"] = float(np.min(data))
    summary["median"] = float(np.median(data))
    summary["mean"] = float(np.mean(data))
    summary["p95"] = float(np.percentile(data, 95))
    summary["max"] = float(np.max(data))
    return summary


def summarize_quality(
    edge_lengths: list[float],
    areas: list[float],
    min_angles_deg: Optional[list[float]] = None,
    aspect_ratios: Optional[list[float]] = None,
) -> dict[str, Any]:
    """Bundle numeric summaries for the four mesh-quality series.

    Each series is summarised with ``summarize_numeric``. A missing or empty
    angle or aspect-ratio list is summarised as an empty series.
    """
    angle_series = min_angles_deg or []
    ratio_series = aspect_ratios or []

    report: dict[str, Any] = {}
    report["edgeLengthKmAtR0"] = summarize_numeric(edge_lengths)
    report["triangleAreaKm2AtR0"] = summarize_numeric(areas)
    report["triangleMinAngleDeg"] = summarize_numeric(angle_series)
    report["triangleAspectRatio"] = summarize_numeric(ratio_series)
    return report


def build_warnings(
    vertices: list[VertexRecord],
    triangles: list[dict[str, Any]],
    quality: dict[str, Any],
    voronoi_stats: dict[str, Any],
) -> list[dict[str, Any]]:
    """Turn mesh statistics into a list of review warnings.

    Checks run in a fixed order: vertex count, longest edge, smallest
    triangle angle, aspect-ratio p95, Voronoi failure, and the share of
    background triangles. Each warning is a dict with ``kind``,
    ``severity`` and ``message``, plus check-specific detail keys.
    """
    found: list[dict[str, Any]] = []

    def flag(kind: str, severity: str, message: str, **details: Any) -> None:
        # Key order is kind, severity, message, then the detail keys.
        found.append({"kind": kind, "severity": severity, "message": message, **details})

    if len(vertices) < 1000:
        flag(
            "low_vertex_count",
            "review",
            "Mesh has fewer than 1,000 vertices; this is probably only suitable for smoke testing.",
        )

    edge_summary = quality.get("edgeLengthKmAtR0", {})
    if edge_summary.get("max", 0.0) > 1500.0:
        flag(
            "large_edge_length",
            "review",
            "Some Delaunay edges exceed 1,500 km. Increase background/source sampling for production meshes.",
            maxKm=edge_summary.get("max"),
        )

    angle_summary = quality.get("triangleMinAngleDeg", {})
    # The minimum is only checked when the summary also carries a p95 value.
    if angle_summary.get("p95") is not None and angle_summary.get("min", 90.0) < 1.0:
        flag(
            "very_small_triangle_angle",
            "review",
            "Some triangles have minimum angles below 1 degree. Inspect validation hotspots before solving.",
            minAngleDeg=angle_summary.get("min"),
        )

    aspect_summary = quality.get("triangleAspectRatio", {})
    if aspect_summary.get("p95", 0.0) > 12.0:
        flag(
            "high_triangle_aspect_ratio",
            "review",
            "The 95th percentile triangle aspect ratio is high. Consider more refinement or less dense line seeding.",
            p95AspectRatio=aspect_summary.get("p95"),
        )

    if voronoi_stats.get("status") == "failed":
        flag(
            "voronoi_failed",
            "high",
            "ConvexHull triangles were built, but SphericalVoronoi failed. Review duplicate tolerance and qhull stability.",
            error=voronoi_stats.get("error"),
        )

    triangle_total = len(triangles)
    background_total = len([t for t in triangles if t.get("cellType") == "background"])
    if background_total > 0.5 * triangle_total:
        flag(
            "many_background_cells",
            "review",
            "More than half the triangles are classified as background. This can be normal for coarse first-pass meshes, but attribute interpolation should be improved before solving.",
            backgroundTriangleCount=background_total,
            triangleCount=triangle_total,
        )

    return found


# -----------------------------------------------------------------------------
# Constraint anchor snapping
# -----------------------------------------------------------------------------


def snap_lonlat_to_vertex(
    lonlat: Any,
    tree: cKDTree,
    vertices: list[VertexRecord],
    radius_km: float,
) -> Optional[dict[str, Any]]:
    """Find the mesh vertex nearest to a lon/lat coordinate.

    Returns ``None`` when ``lonlat_from_coord`` rejects the coordinate.
    Otherwise returns the id of the nearest vertex in ``tree``, the snap
    distance in kilometres, and the rounded query coordinate.
    """
    parsed = lonlat_from_coord(lonlat)
    if parsed is None:
        return None

    nearest_chord, nearest_position = tree.query(lonlat_to_unit(*parsed), k=1)
    matched_vertex = vertices[int(nearest_position)]
    snap_km = chord_to_angle_rad(float(nearest_chord)) * radius_km

    result: dict[str, Any] = {}
    result["vertexId"] = matched_vertex.id
    result["distanceKm"] = round_float(snap_km, 6)
    result["lonLat"] = [round_float(parsed[0], 8), round_float(parsed[1], 8)]
    return result


def build_constraint_anchors(
    project: dict[str, Any],
    scene: dict[str, Any],
    vertices: list[VertexRecord],
    radius_km: float,
) -> dict[str, Any]:
    """Collect the scene's kinematic constraints and snap zipper endpoints to the mesh.

    Reads ``scene["kinematicLinks"]`` and produces three lists:

    * ``zipperControls`` -- one entry per zipper control; each of its
      ``fromIsochron`` / ``toMor`` / ``oppositeIsochron`` endpoints that is a
      dict with a parseable ``lonLat`` is annotated with the nearest mesh
      vertex id and the snap distance in km. Endpoints that are missing or
      cannot be snapped are simply left out of the entry.
    * ``groupAssociations`` and ``morAssignments`` -- selected fields copied
      through unchanged.

    Non-dict items in any of the three input lists are ignored.

    Args:
        project: Accepted for signature symmetry; not read by this function.
        scene: Scene dict holding the optional ``kinematicLinks`` mapping.
        vertices: Mesh vertices; their ``unit`` vectors populate the KD-tree.
        radius_km: Sphere radius used for the km snap distances.

    Returns:
        Dict with ``zipperControls``, ``groupAssociations``,
        ``morAssignments`` and a ``diagnostics`` summary (counts plus a
        numeric summary of all zipper snap distances).
    """
    tree = cKDTree(np.vstack([v.unit for v in vertices]))
    kinematic = scene.get("kinematicLinks") or {}

    zipper_fields = ("id", "groupId", "morFeatureId", "morPathId", "ageMa")
    endpoint_keys = ("fromIsochron", "toMor", "oppositeIsochron")
    endpoint_fields = ("featureId", "pathId", "vertexIndex", "groupId", "ageMa")
    association_fields = (
        "id",
        "morFeatureId",
        "morPathId",
        "leftGroupId",
        "rightGroupId",
        "pairType",
        "createdAtMa",
    )
    assignment_fields = (
        "id",
        "groupId",
        "morFeatureId",
        "morPathId",
        "side",
        "role",
        "pairType",
        "groupAssociationId",
        "createdAtMa",
    )

    snap_distances: list[float] = []
    zipper_anchors: list[dict[str, Any]] = []
    for zipper in kinematic.get("zipperControls") or []:
        if not isinstance(zipper, dict):
            continue
        anchor: dict[str, Any] = {name: zipper.get(name) for name in zipper_fields}
        anchor["constraint"] = zipper.get("constraint", "zipTo")
        for key in endpoint_keys:
            endpoint = zipper.get(key)
            if not isinstance(endpoint, dict):
                continue
            snap = snap_lonlat_to_vertex(endpoint.get("lonLat"), tree, vertices, radius_km)
            if not snap:
                continue
            snapped = {name: endpoint.get(name) for name in endpoint_fields}
            snapped["nearestMeshVertexId"] = snap["vertexId"]
            snapped["snapDistanceKm"] = snap["distanceKm"]
            anchor[key] = snapped
            snap_distances.append(float(snap["distanceKm"]))
        zipper_anchors.append(anchor)

    group_associations = [
        {name: assoc.get(name) for name in association_fields}
        for assoc in kinematic.get("groupAssociations") or []
        if isinstance(assoc, dict)
    ]

    mor_assignments = [
        {name: assignment.get(name) for name in assignment_fields}
        for assignment in kinematic.get("morAssignments") or []
        if isinstance(assignment, dict)
    ]

    return {
        "zipperControls": zipper_anchors,
        "groupAssociations": group_associations,
        "morAssignments": mor_assignments,
        "diagnostics": {
            "zipperControlCount": len(zipper_anchors),
            "zipperSnapDistanceKm": summarize_numeric(snap_distances),
            "groupAssociationCount": len(group_associations),
            "morAssignmentCount": len(mor_assignments),
        },
    }


# -----------------------------------------------------------------------------
# Preview GeoJSON
# -----------------------------------------------------------------------------


def make_preview_geojson(
    mesh: dict[str, Any],
    *,
    max_triangles: int,
    include_seed_points: bool,
    max_seed_points: int,
) -> dict[str, Any]:
    """Build a lightweight GeoJSON FeatureCollection for inspecting a mesh.

    Each previewed triangle becomes a closed ``Polygon`` ring of lon/lat
    pairs; optionally, non-background vertices are added as ``Point``
    features. When a cap applies, items are sampled at evenly spaced
    positions through the input list so the selection is deterministic.

    Args:
        mesh: Mesh dict with ``vertices`` (each having ``id``, ``lon``,
            ``lat``) and ``triangles`` (each having a ``v`` list of three
            vertex ids). ``meshId`` is used for naming when present.
        max_triangles: Maximum number of triangles to preview; ``<= 0``
            includes every triangle.
        include_seed_points: Whether to append ``meshVertex`` point features
            for vertices whose ``material`` is not ``"background"``.
        max_seed_points: Maximum number of point features; ``<= 0`` disables
            the cap. A value of 1 keeps only the first non-background vertex.

    Returns:
        A GeoJSON ``FeatureCollection`` dict. ``trianglePreviewCount`` in its
        ``properties`` is the number of triangle features actually emitted
        (triangles that do not resolve to exactly three known vertices are
        skipped and therefore not counted).
    """
    vertices = mesh.get("vertices") or []
    triangles = mesh.get("triangles") or []
    by_id = {v["id"]: v for v in vertices}

    features: list[dict[str, Any]] = []
    tri_total = len(triangles)
    tri_limit = tri_total if max_triangles <= 0 else min(max_triangles, tri_total)
    # Keep deterministic sample across the whole sphere if we limit previews.
    if tri_limit < tri_total:
        indices = sorted(
            {int(round(i * (tri_total - 1) / max(1, tri_limit - 1))) for i in range(tri_limit)}
        )
    else:
        indices = list(range(tri_total))

    for idx in indices:
        tri = triangles[idx]
        corners = [by_id.get(vid) for vid in tri.get("v") or []]
        # Skip malformed triangles instead of emitting a broken polygon.
        if len(corners) != 3 or not all(corners):
            continue
        ring = [[float(c["lon"]), float(c["lat"])] for c in corners]
        ring.append(ring[0])  # GeoJSON linear rings must be explicitly closed.
        features.append(
            {
                "type": "Feature",
                "properties": {
                    "kind": "meshTriangle",
                    "id": tri.get("id"),
                    "cellType": tri.get("cellType"),
                    "ageMa": tri.get("ageMa"),
                    "groupId": tri.get("groupId"),
                    "areaKm2AtR0": tri.get("areaKm2AtR0"),
                    "youngContinentalWeight": tri.get("youngContinentalWeight"),
                },
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [ring],
                },
            }
        )
    # Only triangle features exist at this point, so this is the true
    # emitted count even when some sampled triangles were skipped.
    triangle_feature_count = len(features)

    if include_seed_points:
        non_background = [v for v in vertices if v.get("material") != "background"]
        if max_seed_points > 0 and len(non_background) > max_seed_points:
            # max(1, ...) avoids a division by zero when exactly one seed
            # point is requested; that case selects index 0 only.
            step = (len(non_background) - 1) / max(1, max_seed_points - 1)
            non_background = [non_background[int(round(i * step))] for i in range(max_seed_points)]
        for v in non_background:
            features.append(
                {
                    "type": "Feature",
                    "properties": {
                        "kind": "meshVertex",
                        "id": v.get("id"),
                        "material": v.get("material"),
                        "ageMa": v.get("ageMa"),
                        "groupId": v.get("groupId"),
                        "seedKinds": v.get("seedKinds"),
                        "youngContinentalWeight": v.get("youngContinentalWeight"),
                    },
                    "geometry": {
                        "type": "Point",
                        "coordinates": [float(v["lon"]), float(v["lat"])],
                    },
                }
            )

    return {
        "type": "FeatureCollection",
        "name": f"{mesh.get('meshId', 'ee_mesh')}_preview",
        "properties": {
            "schemaVersion": "ee-mesh-preview-v1",
            "meshId": mesh.get("meshId"),
            "trianglePreviewCount": triangle_feature_count,
            "triangleTotalCount": tri_total,
            "vertexTotalCount": len(vertices),
            "note": "Preview is for inspection only. Use the mesh JSON for solver work.",
        },
        "features": features,
    }


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the mesh-builder command line.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Paired on/off flags
        (young-crust seeds, quality refinement, Voronoi) share one
        destination each and default to enabled.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Build a spherical triangle mesh sidecar from an EEAE ee-project-v1 JSON file.",
    )
    add = parser.add_argument

    # Input and output locations.
    add("project_json", type=Path, help="Input build .eeproject.json file")
    add("--out", type=Path, default=None, help="Output mesh JSON path")
    add("--preview-out", type=Path, default=None, help="Output mesh preview GeoJSON path")
    add("--diag-out", type=Path, default=None, help="Output diagnostics JSON path")

    # Scene selection and seed generation.
    add("--scene-ma", type=float, default=0.0, help="Scene age to mesh")
    add("--radius-km", type=float, default=DEFAULT_RADIUS_KM, help="0 Ma Earth radius used for km metrics")
    add("--target-spacing-km", type=float, default=180.0, help="Approximate linework resampling interval")
    add("--background-points", type=int, default=1500, help="Fibonacci background points for global coverage")
    add("--max-source-points", type=int, default=6000, help="Cap source line seeds before adding zippers/background; <=0 disables cap")
    add("--source-cap-strategy", choices=["balanced", "legacy"], default="legacy", help="How to cap source line seeds when max-source-points is reached")
    add("--keep-source-vertices", action="store_true", help="Keep original source vertices in addition to spacing samples")
    add("--include-hidden-source", action="store_true", help="Include source features hidden in the selected scene")

    # Young-crust seeding and soft weights.
    add("--include-young-crust-seeds", dest="include_young_crust_seeds", action="store_true", default=True, help="Use young-crust cloud points as mesh seed vertices")
    add("--no-young-crust-seeds", dest="include_young_crust_seeds", action="store_false", help="Use young-crust points only for soft weights, not as mesh seed vertices")
    add("--max-young-seed-points", type=int, default=800, help="Cap young-crust seed points; <=0 keeps all")
    add("--young-crust-sigma-km", type=float, default=250.0, help="Gaussian sigma for young continental soft-field weights")
    add("--young-crust-cutoff-km", type=float, default=750.0, help="Cutoff distance for young continental soft-field weights")

    # Deduplication and attribute lookup.
    add("--dedupe-km", type=float, default=0.5, help="Merge seed points within this chord-equivalent surface distance")
    add("--max-source-refs-per-vertex", type=int, default=10, help="Limit stored provenance refs per merged vertex")
    add("--attribute-search-km", type=float, default=350.0, help="Nearest source search radius for triangle cell attributes")

    # Quality refinement.
    add("--quality-refine", dest="quality_refine", action="store_true", default=True, help="Add deterministic Steiner points to reduce long edges and skinny triangles")
    add("--no-quality-refine", dest="quality_refine", action="store_false", help="Disable quality-refinement Steiner point pass")
    add("--refine-iterations", type=int, default=1, help="Maximum spherical Delaunay quality-refinement passes")
    add("--refine-long-edge-km", type=float, default=0.0, help="Long-edge refinement threshold; <=0 auto-derives from target spacing")
    add("--refine-small-angle-deg", type=float, default=8.0, help="Triangles below this min angle become refinement candidates")
    add("--refine-aspect-ratio", type=float, default=8.0, help="Triangles above this edge aspect ratio become refinement candidates")
    add("--refine-max-new-points-per-iteration", type=int, default=1000, help="Cap Steiner points added per refinement pass")
    add("--refine-min-candidate-distance-km", type=float, default=35.0, help="Do not add refinement points closer than this to an existing/selected point")

    # Hull and Voronoi construction.
    add("--qhull-options", type=str, default="QJ Qc Pp", help="Options passed to scipy.spatial.ConvexHull")
    add("--include-voronoi", action="store_true", default=True, help="Build scipy.spatial.SphericalVoronoi dual cells")
    add("--no-voronoi", dest="include_voronoi", action="store_false", help="Skip SphericalVoronoi generation")
    add("--voronoi-threshold", type=float, default=1e-6, help="Duplicate/sphere mismatch threshold for SphericalVoronoi")

    # Preview and output formatting.
    add("--preview-max-triangles", type=int, default=2000, help="Max triangles in preview GeoJSON; <=0 includes all")
    add("--preview-include-seed-points", action="store_true", help="Include non-background seed vertices in preview GeoJSON")
    add("--preview-max-seed-points", type=int, default=2000, help="Max seed point features in preview")
    add("--pretty", action="store_true", help="Pretty-print JSON outputs")

    return parser.parse_args(argv)


def default_output_paths(project_path: Path, args: argparse.Namespace) -> tuple[Path, Path, Path]:
    """Resolve the mesh, preview and diagnostics output paths.

    Explicit ``args.out`` / ``args.preview_out`` / ``args.diag_out`` values
    win. Otherwise each path is placed beside ``project_path`` and named from
    the project file name with a trailing ``.eeproject.json`` (or, failing
    that, ``.json``) removed.

    Returns:
        ``(mesh_path, preview_path, diagnostics_path)``.
    """
    base_name = project_path.name
    # Longest suffix first, and only the first match is stripped.
    matched_suffix = next(
        (suffix for suffix in (".eeproject.json", ".json") if base_name.endswith(suffix)),
        None,
    )
    if matched_suffix is not None:
        base_name = base_name[: -len(matched_suffix)]

    def beside_project(extension: str) -> Path:
        return project_path.with_name(f"{base_name}{extension}")

    mesh_path = args.out or beside_project(".mesh.v1.json")
    preview_path = args.preview_out or beside_project(".mesh.preview.geojson")
    diagnostics_path = args.diag_out or beside_project(".mesh.diagnostics.json")
    return mesh_path, preview_path, diagnostics_path


def main(argv: Optional[list[str]] = None) -> int:
    """Run the mesh build end to end and write the three output files.

    Steps, in order: parse arguments, load and schema-check the project,
    select the scene, generate and dedupe seed points, run the quality
    refinement pass, assign young-crust weights, build the mesh, derive the
    preview GeoJSON, write mesh / preview / diagnostics JSON, and print a
    JSON run summary to stdout.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If the project's ``schemaVersion`` is not
            ``"ee-project-v1"``.
    """
    args = parse_args(argv)
    source_path = args.project_json
    mesh_path, preview_path, diagnostics_path = default_output_paths(source_path, args)

    project = read_json(source_path)
    schema_version = project.get("schemaVersion")
    if schema_version != "ee-project-v1":
        raise ValueError(f"Expected schemaVersion ee-project-v1, got {schema_version!r}")

    scene = get_scene(project, args.scene_ma)

    # Seed pipeline: raw seeds -> merged vertices -> refined vertices.
    seeds, seed_stats, young_points = build_seed_points(project, scene, args)
    vertices, dedupe_stats = dedupe_seed_points(
        seeds,
        dedupe_km=args.dedupe_km,
        radius_km=args.radius_km,
        max_source_refs_per_vertex=args.max_source_refs_per_vertex,
    )
    vertices, refinement_stats = refine_vertices_for_quality(vertices, args)

    # Weights are assigned after refinement and before the mesh is built.
    young_weight_stats = assign_young_crust_weights(
        vertices,
        young_points,
        radius_km=args.radius_km,
        sigma_km=args.young_crust_sigma_km,
        cutoff_km=args.young_crust_cutoff_km,
    )
    mesh, diagnostics = build_mesh(project, scene, vertices, args)

    # Attach provenance and per-stage statistics to the diagnostics sidecar.
    diagnostics.update(
        {
            "inputProject": {
                "path": str(source_path),
                "sha256": sha256_file(source_path),
                "schemaVersion": project.get("schemaVersion"),
                "project": project.get("project"),
            },
            "seedGeneration": seed_stats,
            "dedupe": dedupe_stats,
            "qualityRefinement": refinement_stats,
            "youngContinentalWeights": young_weight_stats,
        }
    )

    preview = make_preview_geojson(
        mesh,
        max_triangles=args.preview_max_triangles,
        include_seed_points=args.preview_include_seed_points,
        max_seed_points=args.preview_max_seed_points,
    )

    # Diagnostics are always pretty-printed; the other two follow --pretty.
    write_json(mesh_path, mesh, pretty=args.pretty)
    write_json(preview_path, preview, pretty=args.pretty)
    write_json(diagnostics_path, diagnostics, pretty=True)

    counts = mesh.get("counts", {})
    run_summary = {
        "meshOut": str(mesh_path),
        "previewOut": str(preview_path),
        "diagnosticsOut": str(diagnostics_path),
        "meshId": mesh.get("meshId"),
        "vertices": counts.get("vertices"),
        "triangles": counts.get("triangles"),
        "voronoiStatus": mesh.get("voronoi", {}).get("status"),
        "warnings": diagnostics.get("warnings", []),
    }
    print(json.dumps(run_summary, indent=2))
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
