#! /usr/bin/env python3
"""align_tops — align S1 TOPS-mode SLCs at the fractional-pixel level.

Python port of csh align_tops.csh (X. Xu, D. Sandwell 2015).

Workflow:
  1. Build PRM + LED files for master and/or aligned (make_s1a_tops with no SLC).
  2. Extract precise orbits (ext_orb_s1a), compute Doppler centroid + earth
     radius (calc_dop_orb).
  3. Geometric back-projection through the DEM to determine alignment grids
     r.grd and a.grd (range and azimuth offsets per pixel).
  4. Second make_s1a_tops pass with r.grd + a.grd → actual aligned SLC.
  5. resamp + fitoffset to apply sub-pixel alignment.
  6. Re-extract LED, final calc_dop_orb.

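Usage:  align_tops master_prefix master_orb aligned_prefix aligned_orb dem.grd
                  [run_with_a/r_ready]
        Any sixth argument skips the DEM back-projection (step 3) and reuses the
        r.grd, a.grd and offset*.dat files left in place by a previous run.
Output: S1_<date>_<time>_F<subswath>.{PRM,LED,SLC} for both master and aligned.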

skip_master semantics (driven by argv positions 2 and 4):
  - master_orb=="0", aligned_orb!="0": skip_master=1 (process aligned only)
  - master_orb!="0", aligned_orb=="0": skip_master=2 (process master only)
  - both non-zero:                     skip_master=0 (process both)
"""
import os
import re
import subprocess
import sys

from gmtsar_lib import run, check_file_report, replace_strings


def _grep_field3(path, key):
    """Return field 3 of the first line containing `key`, like `grep K f | awk '{print $3}'`."""
    with open(path) as f:
        for line in f:
            if key in line:
                parts = line.split()
                if len(parts) >= 3:
                    return parts[2]
    return ""


def _capture(cmd):
    """Run a shell pipeline and return its trimmed stdout as a string."""
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE,
                          check=False).stdout.decode('utf-8').strip()


def _frame_prefix(input_name):
    """Build S1 output prefix from input filename via awk-style substring.
    Replicates: awk '{print "S1_"substr($1,16,8)"_"substr($1,25,6)"_F"substr($1,7,1)}'
    csh uses 1-based indexing; substr(s,16,8) → Python s[15:23], etc."""
    return f"S1_{input_name[15:23]}_{input_name[24:30]}_F{input_name[6:7]}"


def _run_parallel(*commands):
    """Start the shell `commands` concurrently and wait for every one of them.

    A non-zero exit status is reported on stderr instead of being ignored, so
    a failed SAT_llt2rat / gmt surface run is visible rather than silently
    leaving empty or partial files for later steps. Processing continues
    either way.
    """
    procs = [subprocess.Popen(cmd, shell=True) for cmd in commands]
    for cmd, proc in zip(commands, procs):
        status = proc.wait()
        if status != 0:
            print(f"WARNING: exit status {status} from: {cmd}", file=sys.stderr)


def _require_value(value, what):
    """Return `value`, or exit with a clear message when it is empty.

    Values scraped from PRM files or tool output come back as "" when the
    key is missing. Interpolating "" into a command line would silently shift
    that command's positional arguments, and float("") fails with an opaque
    ValueError.
    """
    if not value:
        sys.exit(f"****** could not determine {what}")
    return value


def _field3_of(text, key):
    """Return field 3 of the first line of `text` containing `key` ("" if none)."""
    for line in text.splitlines():
        if key in line:
            parts = line.split()
            if len(parts) >= 3:
                return parts[2]
    return ""


def _append_doppler(prm, tmp_name, earth_radius="0"):
    """Run calc_dop_orb on `prm` and append its output to that PRM file."""
    run(f"calc_dop_orb {prm} {tmp_name} {earth_radius} 0")
    run(f"cat {tmp_name} >> {prm}")
    run(f"rm -f {tmp_name}")


def align_tops():
    """Command-line entry point: build and align S1 TOPS PRM/LED/SLC files.

    Arguments are read from sys.argv:
        master_prefix master_orb aligned_prefix aligned_orb dem.grd
        [run_with_a/r_ready]
    master_orb == "0" processes the aligned image only (skip_master=1) and
    reuses an existing master PRM. aligned_orb == "0" processes the master
    only (skip_master=2). Both may not be "0". Any sixth argument skips the
    DEM back-projection and reuses r.grd / a.grd / offset*.dat from a
    previous run.

    Exits via sys.exit() with a message on bad usage, missing input files,
    or when a required value cannot be read from a PRM file or tool output.
    """
    if len(sys.argv) not in (6, 7):
        sys.exit(
            "Usage: align_tops master_prefix master_orb aligned_prefix "
            "aligned_orb dem.grd [run_with_a/r_ready]\n"
            "  Use '0' for master_orb to skip master, '0' for aligned_orb "
            "to process master only."
        )
    mpre_in, morb, spre_in, sorb, dem = sys.argv[1:6]
    reuse_grids = len(sys.argv) == 7

    # With both "0" the dispatch below would pick skip_master=1, skip the
    # aligned file checks, and hand the literal "0" to ext_orb_s1a.
    if morb == "0" and sorb == "0":
        sys.exit("****** master_orb and aligned_orb cannot both be '0'")

    # File-existence guards (mirror legacy `! -f X && exit`).
    if morb != "0":
        if not os.path.isfile(f"{mpre_in}.xml"):
            sys.exit(f"****** missing file: {mpre_in}.xml")
        if not os.path.isfile(morb):
            sys.exit(f"****** missing file: {morb}")
    if sorb != "0":
        if not os.path.isfile(f"{spre_in}.xml"):
            sys.exit(f"****** missing file: {spre_in}.xml")
        if not os.path.isfile(sorb):
            sys.exit(f"****** missing file: {sorb}")
    if not os.path.isfile(dem):
        sys.exit(f"****** missing file: {dem}")

    # skip_master dispatch
    if morb == "0":
        skip_master = 1
    elif sorb == "0":
        skip_master = 2
    else:
        skip_master = 0

    mtiff, mxml = f"{mpre_in}.tiff", f"{mpre_in}.xml"
    stiff, sxml = f"{spre_in}.tiff", f"{spre_in}.xml"
    mpre = _frame_prefix(mpre_in)
    spre = _frame_prefix(spre_in)
    if skip_master in (0, 2):
        print(mpre)
    if skip_master in (0, 1):
        print(spre)

    # Aligned-only runs read earth_radius from, and resamp against, a master
    # PRM that only an earlier run can have produced. Fail before any work.
    if skip_master == 1 and not os.path.isfile(f"{mpre}.PRM"):
        sys.exit(f"****** missing file: {mpre}.PRM")

    # ===== Stage 1: PRM + LED only (no SLC) =====
    if skip_master == 2:
        run(f"make_s1a_tops {mxml} {mtiff} {mpre} 1")
        run(f"ext_orb_s1a {mpre}.PRM {morb} {mpre}")
        _append_doppler(f"{mpre}.PRM", "tmp")
        return

    # skip_master in {0, 1}: process aligned (+ optionally master)
    if skip_master == 0:
        run(f"make_s1a_tops {mxml} {mtiff} {mpre} 0")
    run(f"make_s1a_tops {sxml} {stiff} {spre} 0")

    if skip_master == 0:
        run(f"ext_orb_s1a {mpre}.PRM {morb} {mpre}")
        _append_doppler(f"{mpre}.PRM", "tmp")

    run(f"ext_orb_s1a {spre}.PRM {sorb} {spre}")
    earth_radius = _require_value(_grep_field3(f"{mpre}.PRM", "earth_radius"),
                                  f"earth_radius from {mpre}.PRM")
    _append_doppler(f"{spre}.PRM", "tmp2", earth_radius)

    # ===== Stage 2: geometric back-projection (skipped when grids are reused) =====
    if not reuse_grids:
        run(f"gmt grdfilter {dem} -D3 -Fg2 -I12s -Ni -Gflt.grd")
        run("gmt grd2xyz --FORMAT_FLOAT_OUT=%lf flt.grd -s > topo.llt")

    # Tie-point alignment estimate, needed on both paths. SAT_baseline is run
    # once and both tie-point fields are parsed from its stdout.
    baseline = _capture(f"SAT_baseline {mpre}.PRM {spre}.PRM")
    lontie = _require_value(_field3_of(baseline, "lon_tie_point"),
                            "lon_tie_point from SAT_baseline")
    lattie = _require_value(_field3_of(baseline, "lat_tie_point"),
                            "lat_tie_point from SAT_baseline")
    tmp_am = _require_value(
        _capture(f"echo {lontie} {lattie} 0 | "
                 f"SAT_llt2rat {mpre}.PRM 1 | awk '{{print $2}}'"),
        "master azimuth of the tie point")
    tmp_as = _require_value(
        _capture(f"echo {lontie} {lattie} 0 | "
                 f"SAT_llt2rat {spre}.PRM 1 | awk '{{print $2}}'"),
        "aligned azimuth of the tie point")
    # int() truncates toward zero, like awk's printf("%d").
    tmp_da = int(float(tmp_as) - float(tmp_am))
    # A shift of 1000+ lines is treated as a burst offset and is handled via a
    # temporarily time-shifted master PRM and an ashift on resamp.
    burst_shift = not (-1000 < tmp_da < 1000)

    if not reuse_grids:
        if burst_shift:
            print(f"Modifying master PRM by {tmp_da} lines...")
            run(f"cp {mpre}.PRM tmp.PRM")
            prf = _require_value(_grep_field3("tmp.PRM", "PRF"), "PRF from tmp.PRM")
            # Shift the four clock fields by tmp_da / prf / 86400 days.
            shift = float(tmp_da) / float(prf) / 86400.0
            for key in ("clock_start", "clock_stop",
                        "SC_clock_start", "SC_clock_stop"):
                # Mirror legacy `grep K | grep -v SC_K` for non-SC variants.
                if key.startswith("SC_"):
                    val = _grep_field3("tmp.PRM", key)
                else:
                    val = _capture(f"grep {key} tmp.PRM | grep -v SC_{key} | awk '{{print $3}}'")
                val = _require_value(val, f"{key} from tmp.PRM")
                run(f"update_PRM tmp.PRM {key} {float(val) - shift:.12f}")
            master_prm = "tmp.PRM"
        else:
            master_prm = f"{mpre}.PRM"
        _run_parallel(f"SAT_llt2rat {master_prm} 1 < topo.llt > master.ratll",
                      f"SAT_llt2rat {spre}.PRM 1 < topo.llt > aligned.ratll")

        # paste master.ratll aligned.ratll -> tmp.dat (range, dr, azimuth, da, snr)
        run("paste master.ratll aligned.ratll | "
            "awk '{printf(\"%.6f %.6f %.6f %.6f %d\\n\", $1, $6-$1, $2, $7-$2, \"100\")}' > tmp.dat")

        rmax = _grep_field3(f"{spre}.PRM", "num_rng_bins")
        amax = _grep_field3(f"{spre}.PRM", "num_lines")
        run(f"awk '{{if($1 > 0 && $1 < {rmax} && $3 > 0 && $3 < {amax}) "
            f"print $0 }}' < tmp.dat > offset.dat")
        if burst_shift:
            run(f"awk '{{if($1 > 0 && $1 < {rmax} && $3 > 0 && $3 < {amax}) "
                f"printf(\"%.6f %.6f %.6f %.6f %d\\n\", $1, $2, $3 - {tmp_da}, $4 + {tmp_da}, \"100\") }}' "
                f"< tmp.dat > offset2.dat")

        # Build r.xyz and a.xyz, blockmedian -> surface -> flip -> r.grd, a.grd
        run("awk '{ printf(\"%.6f %.6f %.6f \\n\", $1,$3,$2) }' < offset.dat > r.xyz")
        run("awk '{ printf(\"%.6f %.6f %.6f \\n\", $1,$3,$4) }' < offset.dat > a.xyz")
        run(f"gmt blockmedian r.xyz -R0/{rmax}/0/{amax} -I16/8 -r -bo3d > rtmp.xyz")
        run(f"gmt blockmedian a.xyz -R0/{rmax}/0/{amax} -I16/8 -r -bo3d > atmp.xyz")
        _run_parallel(
            f"gmt surface rtmp.xyz -bi3d -R0/{rmax}/0/{amax} -I16/8 -T0.3 -Grtmp.grd -N1000 -r",
            f"gmt surface atmp.xyz -bi3d -R0/{rmax}/0/{amax} -I16/8 -T0.3 -Gatmp.grd -N1000 -r")
        run("gmt grdmath rtmp.grd FLIPUD = r.grd")
        run("gmt grdmath atmp.grd FLIPUD = a.grd")

    # ===== Stage 3: actual aligned-SLC creation via second make_s1a_tops pass =====
    if skip_master == 0:
        run(f"make_s1a_tops {mxml} {mtiff} {mpre} 1")
    run(f"make_s1a_tops {sxml} {stiff} {spre} 1 r.grd a.grd")

    # resamp the aligned image; ashift restores the burst shift (0 when none).
    run(f"cp {spre}.PRM {spre}.PRM0")
    if burst_shift:
        run(f"update_PRM {spre}.PRM ashift {tmp_da}")
        print(f"Restoring {tmp_da} lines with resamp...")
    else:
        run(f"update_PRM {spre}.PRM ashift 0")
    run(f"resamp {mpre}.PRM {spre}.PRM {spre}.PRMresamp {spre}.SLCresamp 1")
    run(f"mv -f {spre}.SLCresamp {spre}.SLC")
    run(f"mv -f {spre}.PRMresamp {spre}.PRM")

    offset_file = "offset2.dat" if burst_shift else "offset.dat"
    run(f"fitoffset 3 3 {offset_file} >> {spre}.PRM")

    # Re-extract LED and recompute calc_dop_orb
    if skip_master == 0:
        run(f"ext_orb_s1a {mpre}.PRM {morb} {mpre}")
    run(f"ext_orb_s1a {spre}.PRM {sorb} {spre}")
    if skip_master == 0:
        _append_doppler(f"{mpre}.PRM", "tmp")
    earth_radius = _require_value(_grep_field3(f"{mpre}.PRM", "earth_radius"),
                                  f"earth_radius from {mpre}.PRM")
    _append_doppler(f"{spre}.PRM", "tmp2", earth_radius)

    run("rm -f topo.llt master.ratll aligned.ratll *tmp* flt.grd r.xyz a.xyz *.PRM0")


if __name__ == "__main__":
    # Script entry point; align_tops() reads all of its arguments from sys.argv.
    align_tops()
