"""Script for patching transcript assignments for use with Xenium Ranger v2.0.1

Xenium Ranger v2.0.1 assumes that the segmentation was generated using only Q20
transcripts and that each of the generated cells is non-empty. If either
assumption is violated, erroneous outputs are generated.

This script removes any cells that are empty or do not contain any Q20 transcripts.
It also unassigns low-quality (non-Q20) transcripts and renumbers the remaining
cells contiguously in both the transcript assignments and the visualization
polygons.
"""

from pathlib import Path
import re
import json
import argparse
import logging
import sys

try:
    import polars as pl
    import numpy as np
except ImportError:
    print("Please install numpy and polars before running this script (pip install numpy; pip install polars)")
    sys.exit(1)
    
LOGGER = logging.getLogger(__name__)

# Matches assigned cell IDs of the form "<prefix>-<index>" (e.g. "cell-42") and
# captures the trailing non-negative integer index in group 1.
# NOTE(review): inside a character class `|` is a literal, so `[\w|\d]` means
# "word character or '|'" (and `\d` is already covered by `\w`). If alternation
# was intended, `\w+` alone would suffice — confirm before changing, since it
# alters which IDs are accepted.
CELL_ID_REGEX = re.compile(r"^[\w|\d]+-(\d+)$")


def _build_segmentation_df(segmentation_csv: Path, xenium_bundle: Path) -> "pl.DataFrame":
    """Join the transcript-assignment CSV with per-transcript quality values.

    Args:
        segmentation_csv: Transcript assignment CSV; must contain a
            ``transcript_id`` column (downstream steps also read ``cell``).
        xenium_bundle: Xenium output bundle directory containing
            ``transcripts.parquet``.

    Returns:
        The segmentation rows with an added ``qv`` column. This is an inner
        join, so segmentation rows whose ``transcript_id`` does not appear in
        ``transcripts.parquet`` are dropped.

    Raises:
        FileNotFoundError: If the CSV or ``transcripts.parquet`` is missing.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not segmentation_csv.exists():
        raise FileNotFoundError("Couldn't find segmentation CSV.")

    transcripts_pq = xenium_bundle.joinpath("transcripts.parquet")
    if not transcripts_pq.exists():
        raise FileNotFoundError("Couldn't find transcripts.parquet in xenium bundle.")

    LOGGER.info("reading transcripts.parquet")
    # Only the join key and the quality score are used, so avoid loading the
    # remaining (potentially large) transcript columns.
    transcripts = pl.read_parquet(transcripts_pq, columns=["transcript_id", "qv"])
    LOGGER.info("reading segmentation CSV")
    segmentation = pl.read_csv(segmentation_csv)

    return segmentation.join(
        transcripts.select(
            # Presumably cast so the key dtype matches the CSV's inferred
            # integer column — TODO confirm the CSV column is read as Int64.
            pl.col("transcript_id").cast(pl.Int64),
            pl.col("qv"),
        ),
        on="transcript_id",
    )
    
def _read_polygons(viz_polygons_geojson: Path) -> dict:
    """Load the visualization polygons GeoJSON file.

    Args:
        viz_polygons_geojson: Path to the GeoJSON file.

    Returns:
        The parsed JSON document. Downstream code expects a ``"type"`` key and
        a ``"geometries"`` list whose items carry an integer ``"cell"`` key.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not viz_polygons_geojson.exists():
        raise FileNotFoundError("Couldn't find visualization polygons")

    LOGGER.info("reading visualization polygons")
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
    with open(viz_polygons_geojson, encoding="utf-8") as file:
        return json.load(file)

def _extract_segmentation_cell_ids_with_q20_transcripts(segmentation_df: pl.DataFrame) -> list:
    """Return the sorted, distinct indices of cells that have at least one Q20 transcript.

    Args:
        segmentation_df: Frame with a string ``cell`` column holding IDs of the
            form ``<prefix>-<index>`` (empty or null when unassigned) and a
            numeric ``qv`` column.

    Returns:
        Sorted list of distinct integer indices parsed from the ``cell`` IDs of
        transcripts with ``qv >= 20``.

    Raises:
        ValueError: If an assigned cell ID does not match ``CELL_ID_REGEX``.
    """
    q20_cell_ids = (
        segmentation_df.select(pl.col("cell"), pl.col("qv"))
        .filter(
            pl.col("cell").is_not_null()
            & pl.col("cell").ne("")
            & pl.col("qv").ge(20.0)
        )
        .get_column("cell")
        # Deduplicate before the Python-level regex loop: there are typically
        # many transcripts per cell, so this shrinks the loop dramatically.
        .unique()
    )

    indices = set()
    for cell_id in q20_cell_ids:
        m = CELL_ID_REGEX.match(cell_id)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if m is None:
            raise ValueError(f"Unrecognized cell ID in segmentation: {cell_id!r}")
        # `(\d+)` only matches digits, so the parsed index is always >= 0.
        indices.add(int(m.group(1)))

    return sorted(indices)
    
def _extract_polygon_cell_ids(polygons: dict):
    polygon_cell_ids = [] 
    for p in polygons["geometries"]:
        assert p["cell"] >= 0
        polygon_cell_ids.append(p["cell"])
        
    return polygon_cell_ids

def _patch_transcript_assignment_outputs(segmentation_df: pl.DataFrame, polygons: dict):
    """Drop cells without Q20 transcripts and renumber the remaining cells.

    Args:
        segmentation_df: Transcript assignments with string ``cell`` IDs of the
            form ``<prefix>-<index>`` (empty or null when unassigned) and a
            numeric ``qv`` column.
        polygons: Parsed visualization GeoJSON with a ``"type"`` key and a
            ``"geometries"`` list whose items carry an integer ``"cell"`` key.

    Returns:
        ``(new_segmentation_df, new_polygons)``:
        - the segmentation without ``qv``, with ``cell`` (kept in its original
          column position) rewritten to ``cell-<new index>`` for transcripts
          with ``qv >= 20`` and ``""`` for everything else;
        - the polygons restricted to kept cells, each with ``cell`` set to its
          new contiguous 0-based index.

    Raises:
        KeyError: If a Q20 transcript is assigned to a cell with no polygon.
        ValueError: If a ``cell`` value seen in the row loop does not match
            ``CELL_ID_REGEX`` (the ID-extraction helpers may reject such IDs
            first with their own error).
    """
    LOGGER.info("extracting cell IDs for cells with Q20 transcripts")
    seg_cell_ids = _extract_segmentation_cell_ids_with_q20_transcripts(segmentation_df)
    LOGGER.info("extracting cell IDs from visualization polygons")
    polygon_cell_ids = _extract_polygon_cell_ids(polygons)

    to_remove = set(polygon_cell_ids).difference(seg_cell_ids)
    LOGGER.info(f"Cell IDs that need to be removed: {to_remove}")

    # Map each kept polygon cell ID to a contiguous 0-based index in polygon
    # order. dict.fromkeys dedupes while preserving order, so a repeated
    # polygon ID cannot leave a gap in the new numbering.
    kept_ids = dict.fromkeys(pid for pid in polygon_cell_ids if pid not in to_remove)
    id_map = {pid: new_index for new_index, pid in enumerate(kept_ids)}

    LOGGER.info("creating new cell IDS")
    new_cell_ids = []
    for cell_id, qv in segmentation_df.select("cell", "qv").iter_rows():
        # polars reads empty CSV fields as null by default, so treat None like "".
        if not cell_id:
            new_cell_ids.append("")
            continue

        m = CELL_ID_REGEX.match(cell_id)
        if m is None:
            raise ValueError(f"Unrecognized cell ID in segmentation: {cell_id!r}")

        # Same `>= 20` threshold used to decide which cells are kept. A strict
        # `>` would unassign qv == 20 transcripts from cells kept because of
        # them, leaving those cells empty.
        if qv >= 20:
            new_index = id_map.get(int(m.group(1)))
            if new_index is None:
                raise KeyError(
                    f"Cell {cell_id!r} has Q20 transcripts but no visualization polygon"
                )
            new_cell_ids.append(f"cell-{new_index}")
        else:
            new_cell_ids.append("")

    LOGGER.info("patching segmentation")
    # with_columns replaces "cell" in place, preserving the input column order.
    new_df_seg = segmentation_df.with_columns(
        pl.Series("cell", new_cell_ids, dtype=pl.Utf8)
    ).drop("qv")

    LOGGER.info("patching polygons")
    new_geometries = []
    for geometry in polygons["geometries"]:
        new_id = id_map.get(geometry["cell"])
        if new_id is not None:
            # Shallow copy with only "cell" replaced; the input is not mutated.
            new_geometries.append({**geometry, "cell": new_id})

    new_polygons = {
        "type": polygons["type"],
        "geometries": new_geometries,
    }

    return (new_df_seg, new_polygons)

def _write_new_transcript_assignment_outputs(
    new_segmentation_df: pl.DataFrame,
    new_polygons: dict,
    output_segmentation_csv: Path,
    output_polygons_geojson: Path,
):
    """Write the patched polygons as JSON and the patched assignments as CSV.

    Args:
        new_segmentation_df: Patched transcript assignments.
        new_polygons: Patched GeoJSON-style dict.
        output_segmentation_csv: Destination for the CSV.
        output_polygons_geojson: Destination for the GeoJSON.
    """
    # Use the module logger like the rest of the script (the original used the root logger).
    LOGGER.info(f"writing output polygons to '{output_polygons_geojson}'")
    with open(output_polygons_geojson, "w", encoding="utf-8") as f:
        json.dump(new_polygons, f)

    LOGGER.info(f"writing output segmentation to '{output_segmentation_csv}'")
    new_segmentation_df.write_csv(output_segmentation_csv)
    
    
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        prog="patch_transcript_assignments",
        description="Patch transcript assignment outputs for use with Xenium Ranger 2.0.1",
    )

    parser.add_argument("--xenium-bundle", type=Path, help="Path to the Xenium bundle", required=True)
    parser.add_argument("--transcript-assignment", type=Path, help="Path to the transcript assignments CSV file", required=True)
    parser.add_argument("--viz-polygons", type=Path, help="Path to the GeoJSON segmentation polygons", required=True)
    parser.add_argument("--output-transcript-assignment", type=Path, help="Path where new transcript assignments CSV should be written", required=True)
    parser.add_argument("--output-viz-polygons", type=Path, help="Path where new GeoJSON segmentation polygons should be written", required=True)

    args = parser.parse_args()

    xenium_bundle = args.xenium_bundle
    transcript_assignment = args.transcript_assignment
    viz_polygons_geojson = args.viz_polygons
    output_transcript_assignment = args.output_transcript_assignment
    output_viz_polygons = args.output_viz_polygons

    # Validate with parser.error (prints usage, exits with status 2) rather
    # than `assert`: asserts are stripped under `python -O`, which would skip
    # the "output already exists" guards and silently overwrite files.
    if not xenium_bundle.exists():
        parser.error("Path to Xenium bundle doesn't exist")
    if not transcript_assignment.exists():
        parser.error("Path to transcript assignment CSV doesn't exist")
    if not viz_polygons_geojson.exists():
        parser.error("Path to visualization polygons doesn't exist")

    # Refuse to clobber existing outputs; checked before the slow data loading.
    if output_transcript_assignment.exists():
        parser.error("Path to output transcript assignment CSV already exist. Please delete or move before running.")
    if output_viz_polygons.exists():
        parser.error("Path to output visualization polygons already exist. Please delete or move before running.")

    segmentation_df = _build_segmentation_df(
        segmentation_csv=transcript_assignment,
        xenium_bundle=xenium_bundle,
    )
    polygons = _read_polygons(
        viz_polygons_geojson=viz_polygons_geojson,
    )

    new_segmentation_df, new_polygons = _patch_transcript_assignment_outputs(segmentation_df=segmentation_df, polygons=polygons)

    _write_new_transcript_assignment_outputs(
        new_segmentation_df=new_segmentation_df,
        new_polygons=new_polygons,
        output_segmentation_csv=output_transcript_assignment,
        output_polygons_geojson=output_viz_polygons,
    )