# Congressional Representation Project: metric calculator
#
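# This script turns raw House election results into the JSON dataset used by the
# website. main() writes it to electoral_health_output.json.
# NOTE(review): the website's copy was described as electoral-dataset.json;
# rename it on publish or align the two names.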
#
# HOW TO FIND EACH METRIC IN THIS FILE
# Search for these labels:
#
#   STEP: Redistricting flags
#   STEP: Clean raw district votes
#   STEP: Adjusted vote estimates (uncontested races)
#   METRIC: Competitiveness (district level)
#   METRIC: Efficiency Gap
#   METRIC: Mean-Median
#   METRIC: Proportionality
#   METRIC: Fairness (composite)
#   METRIC: Contestation
#   METRIC: Incumbency Insulation
#   METRIC: Structural Stability
#   METRIC: Electoral Health Index (EHI)
#   METRIC: Metric Confidence

import os
import json
import re
import pandas as pd
import numpy as np

# Lookup tables: translate between state names, postal abbreviations (AL, AZ, …),
# and FIPS codes used in map files.
# Upper-case state name -> postal abbreviation ("po") and FIPS code ("fips").
# FIPS codes are stored as unpadded strings ("1", not "01"), so any code compared
# against these values must use the same form. Only the 50 states are listed (no
# DC or territories), so rows for other codes have no entry here.
STATE_MAPPING = {
    "ALABAMA": {"po": "AL", "fips": "1"}, "ALASKA": {"po": "AK", "fips": "2"},
    "ARIZONA": {"po": "AZ", "fips": "4"}, "ARKANSAS": {"po": "AR", "fips": "5"},
    "CALIFORNIA": {"po": "CA", "fips": "6"}, "COLORADO": {"po": "CO", "fips": "8"},
    "CONNECTICUT": {"po": "CT", "fips": "9"}, "DELAWARE": {"po": "DE", "fips": "10"},
    "FLORIDA": {"po": "FL", "fips": "12"}, "GEORGIA": {"po": "GA", "fips": "13"},
    "HAWAII": {"po": "HI", "fips": "15"}, "IDAHO": {"po": "ID", "fips": "16"},
    "ILLINOIS": {"po": "IL", "fips": "17"}, "INDIANA": {"po": "IN", "fips": "18"},
    "IOWA": {"po": "IA", "fips": "19"}, "KANSAS": {"po": "KS", "fips": "20"},
    "KENTUCKY": {"po": "KY", "fips": "21"}, "LOUISIANA": {"po": "LA", "fips": "22"},
    "MAINE": {"po": "ME", "fips": "23"}, "MARYLAND": {"po": "MD", "fips": "24"},
    "MASSACHUSETTS": {"po": "MA", "fips": "25"}, "MICHIGAN": {"po": "MI", "fips": "26"},
    "MINNESOTA": {"po": "MN", "fips": "27"}, "MISSISSIPPI": {"po": "MS", "fips": "28"},
    "MISSOURI": {"po": "MO", "fips": "29"}, "MONTANA": {"po": "MT", "fips": "30"},
    "NEBRASKA": {"po": "NE", "fips": "31"}, "NEVADA": {"po": "NV", "fips": "32"},
    "NEW HAMPSHIRE": {"po": "NH", "fips": "33"}, "NEW JERSEY": {"po": "NJ", "fips": "34"},
    "NEW MEXICO": {"po": "NM", "fips": "35"}, "NEW YORK": {"po": "NY", "fips": "36"},
    "NORTH CAROLINA": {"po": "NC", "fips": "37"}, "NORTH DAKOTA": {"po": "ND", "fips": "38"},
    "OHIO": {"po": "OH", "fips": "39"}, "OKLAHOMA": {"po": "OK", "fips": "40"},
    "OREGON": {"po": "OR", "fips": "41"}, "PENNSYLVANIA": {"po": "PA", "fips": "42"},
    "RHODE ISLAND": {"po": "RI", "fips": "44"}, "SOUTH CAROLINA": {"po": "SC", "fips": "45"},
    "SOUTH DAKOTA": {"po": "SD", "fips": "46"}, "TENNESSEE": {"po": "TN", "fips": "47"},
    "TEXAS": {"po": "TX", "fips": "48"}, "UTAH": {"po": "UT", "fips": "49"},
    "VERMONT": {"po": "VT", "fips": "50"}, "VIRGINIA": {"po": "VA", "fips": "51"},
    "WASHINGTON": {"po": "WA", "fips": "53"}, "WEST VIRGINIA": {"po": "WV", "fips": "54"},
    "WISCONSIN": {"po": "WI", "fips": "55"}, "WYOMING": {"po": "WY", "fips": "56"}
}

# Reverse lookup: postal abbreviation -> upper-case state name (a STATE_MAPPING key).
PO_TO_STATE = {v["po"]: k for k, v in STATE_MAPPING.items()}


# STEP: Redistricting flags
# Reads map filenames to learn which election years follow a new district map.
# Those years get a lower confidence score because fresh maps can make comparisons noisier.
def detect_redistricting_years(geojson_dir="./geojson"):
    """Return {state FIPS code: set of years in which a new district map took effect}.

    Map files are expected to be named ``<fips>_<start year>_<n>.geojson``. Only
    the FIPS code and the start year are used; the third number is ignored here.
    FIPS codes are normalized to the unpadded form used in STATE_MAPPING, so
    ``01_2012_2021.geojson`` and ``1_2012_2021.geojson`` both map to key ``"1"``.
    Files that do not match the pattern are skipped.

    Args:
        geojson_dir: directory containing the map files.

    Returns:
        Dict of FIPS string -> set of int start years. Empty when ``geojson_dir``
        is not an existing directory.
    """
    redistricting_map = {}
    # isdir (not exists): a plain file at this path would make listdir raise.
    if not os.path.isdir(geojson_dir):
        return redistricting_map

    for filename in os.listdir(geojson_dir):
        match = re.fullmatch(r"(\d+)_(\d+)_(\d+)\.geojson", filename)
        if match is None:
            continue
        fips, start_yr, _ = match.groups()
        # Strip zero-padding ("06" -> "6") so keys line up with STATE_MAPPING's
        # unpadded FIPS strings; otherwise the lookup in main() silently misses.
        redistricting_map.setdefault(str(int(fips)), set()).add(int(start_yr))
    return redistricting_map


# STEP: Clean raw district votes
# Reads one state's rows from the CSV and builds a simple record per district:
# Democratic votes, Republican votes, the two-party total (Dem + Rep), and the
# Democratic share of that two-party total.
def clean_districts(state_group):
    """Turn one state's result rows into plain per-district vote records.

    Args:
        state_group: the rows for one state (and year) from the results CSV, with
            columns ``District``, ``Dem Votes``, ``GOP Votes``, ``Incumbent`` and
            ``Party``.

    Returns:
        A list with one dict per row, in row order. Each dict has these keys:
        - ``district``: the row's district number, or 0 when every row carries
          the same district number (treated as at-large).
        - ``raw_demvotes`` / ``raw_repvotes``: vote counts as ints; missing cells
          count as 0.
        - ``raw_totalvotes``: Dem + Rep only.
        - ``raw_dem_share``: Dem share of that two-party total, rounded to 4
          places; 0.5 when there are no votes.
        - ``incumbent`` and ``party``: the cell text, whitespace-stripped.
    """
    def _votes(cell):
        # A blank/NaN vote cell means "no votes recorded" for that party.
        return float(cell) if pd.notna(cell) else 0.0

    single_district = state_group['District'].nunique() == 1

    records = []
    for _, row in state_group.iterrows():
        number = int(row['District'])
        dem = _votes(row['Dem Votes'])
        rep = _votes(row['GOP Votes'])
        two_party = dem + rep

        if two_party > 0:
            share = dem / two_party
        else:
            share = 0.5

        records.append(dict(
            district=0 if single_district else number,
            raw_totalvotes=int(two_party),
            raw_demvotes=int(dem),
            raw_repvotes=int(rep),
            raw_dem_share=round(share, 4),
            incumbent=str(row['Incumbent']).strip(),
            party=str(row['Party']).strip(),
        ))
    return records


# STEP: Adjusted vote estimates (uncontested races)
# When one party got almost no votes (or none), the raw result looks like 100%–0%.
# That can distort fairness and competition metrics. This step estimates a more
# realistic two-party split by blending statewide House results, presidential
# results, and the average share of the state's other contested districts (other
# districts in the same state, not geographic neighbors) — but only for metric
# calculations. Original votes are still saved separately as "raw" values in the
# output JSON.
def impute_uncontested(districts, statewide_house_dem_share, statewide_pres_dem_share):
    """Add adjusted (imputed) vote estimates and district competitiveness.

    Args:
        districts: records as built by clean_districts. Each needs the keys
            ``district``, ``raw_demvotes``, ``raw_repvotes``, ``raw_totalvotes``
            and ``raw_dem_share``.
        statewide_house_dem_share: statewide two-party Dem share of the House vote.
        statewide_pres_dem_share: statewide two-party Dem share of the presidential vote.

    Returns:
        A new list of copies of the input records (inputs are not mutated). Each
        copy is extended with ``contested``, ``adjusted_dem_share``,
        ``adjusted_totalvotes``, ``adjusted_demvotes``, ``adjusted_repvotes``,
        ``winner`` and ``competitiveness``.

    How uncontested districts are adjusted:
    - A district is uncontested when either party has zero votes, the race has
      no votes, or the smaller party has under 2% of the two-party total.
    - Its Dem share becomes a 40/40/20 blend of the statewide House share, the
      statewide presidential share, and the mean share of the state's other
      *contested* districts. The blend is clamped to [0.05, 0.95].
    - If that estimate falls on the side of the party that lost the recorded
      vote, it is mirrored around 0.5. This keeps the adjusted split from
      contradicting the recorded winner.
    - Districts with no recorded votes get the average turnout of districts that
      have votes, or 150000 if none do.
    """
    turnouts = [d["raw_totalvotes"] for d in districts if d["raw_totalvotes"] > 0]
    avg_turnout = sum(turnouts) / len(turnouts) if turnouts else 150000

    def _is_uncontested(rec):
        # One major party is effectively missing from this race.
        dv, rv, tot = rec["raw_demvotes"], rec["raw_repvotes"], rec["raw_totalvotes"]
        return dv == 0 or rv == 0 or tot == 0 or min(dv, rv) / max(tot, 1) < 0.02

    # Reference pool for imputation. It applies the same uncontested test, so a
    # lopsided 99%-1% race is never used to estimate another uncontested race.
    contested_pool = [d for d in districts if not _is_uncontested(d)]

    imputed_districts = []
    for d in districts:
        dv = d["raw_demvotes"]
        rv = d["raw_repvotes"]
        tot = d["raw_totalvotes"]
        contested = not _is_uncontested(d)

        if contested:
            adj_dem_share = d["raw_dem_share"]
            adj_tot, adj_dv, adj_rv = tot, dv, rv
        else:
            reference_shares = [
                other["raw_dem_share"]
                for other in contested_pool
                if other["district"] != d["district"]
            ]
            reference_avg = (
                np.mean(reference_shares) if reference_shares else statewide_house_dem_share
            )

            # Blend statewide House, presidential, and in-state reference context.
            adj_dem_share = (
                0.40 * statewide_house_dem_share
                + 0.40 * statewide_pres_dem_share
                + 0.20 * reference_avg
            )
            adj_dem_share = max(0.05, min(0.95, adj_dem_share))

            # The statewide blend ignores who actually won. When it lands on the
            # losing party's side, mirror it around 0.5. Without this, the adjusted
            # votes contradict "winner", and calculate_metrics would count the seat
            # for one party (from "winner") while scoring its wasted votes as a win
            # for the other party (from the adjusted totals).
            if (dv > rv and adj_dem_share < 0.5) or (rv > dv and adj_dem_share > 0.5):
                adj_dem_share = 1.0 - adj_dem_share

            adj_tot = tot if tot > 0 else avg_turnout
            adj_dv = adj_tot * adj_dem_share
            adj_rv = adj_tot * (1.0 - adj_dem_share)

        if dv != rv:
            winner = "DEM" if dv > rv else "REP"
        else:
            # Exact tie or no recorded votes: decide by the (adjusted) share, with
            # DEM on exactly 0.5. This is the same tie-break main() uses for the
            # raw path, instead of always defaulting to REP.
            winner = "DEM" if adj_dem_share >= 0.5 else "REP"

        # METRIC: Competitiveness (district level) — adjusted path
        # Closer to 50/50 means more competitive. Formula: 1 - 2|share - 0.5|
        comp = max(0.0, min(1.0, 1.0 - (2.0 * abs(adj_dem_share - 0.5))))

        d_mod = d.copy()
        d_mod.update({
            "contested": contested,
            "adjusted_dem_share": round(adj_dem_share, 4),
            "adjusted_totalvotes": int(adj_tot),
            "adjusted_demvotes": int(adj_dv),
            "adjusted_repvotes": int(adj_rv),
            "winner": winner,
            "competitiveness": round(comp, 4)
        })
        imputed_districts.append(d_mod)
    return imputed_districts


# All statewide metrics are calculated here. The function runs twice per state:
# once on raw vote records and once on adjusted (imputed) vote records.
def calculate_metrics(districts, was_redistricted):
    """Compute the statewide metric bundle for one state-year.

    main() calls this twice per state: once on raw records and once on adjusted
    (imputed) records. Each record must carry ``raw_totalvotes``,
    ``adjusted_demvotes``, ``adjusted_repvotes``, ``adjusted_totalvotes``,
    ``adjusted_dem_share``, ``winner``, ``contested``, ``competitiveness`` and
    ``incumbent``.

    Which districts feed which metrics:
    - The seat/vote fairness metrics (efficiency gap, mean-median,
      proportionality) use only districts with more than one recorded vote, and
      take both numerator and denominator from that same set.
    - Competitiveness, contestation and incumbency use every district. A race
      with no recorded votes therefore counts as uncontested instead of dropping
      out of the denominator.

    Args:
        districts: list of district records for one state.
        was_redistricted: True when this year follows a new district map; it
            lowers the confidence score.

    Returns:
        Dict of metric name -> value, rounded to 4 decimals. The four fairness
        entries are None when at most one district has recorded votes.
    """
    # Seat/vote-based fairness math only uses districts with recorded votes.
    valid_districts = [d for d in districts if d["raw_totalvotes"] > 1]
    total_districts = len(valid_districts)

    # Contestation and incumbency are judged over every district.
    contested_count = sum(1 for d in districts if d["contested"])
    incumbents = [d for d in districts if d["incumbent"] in ['D', 'R']]

    # Single-district states (at-large), or states where at most one district has
    # recorded votes, cannot use seat-based fairness math. EHI uses a simplified
    # blend of competition, contestation, and stability only.
    if total_districts <= 1:
        comp_score = np.mean([d["competitiveness"] for d in districts]) if districts else 0.5
        contestation_score = contested_count / len(districts) if districts else 1.0

        # Per incumbent: 0.5 when unopposed, otherwise 0.5 * win margin; averaged
        # over incumbents (a single incumbent gives exactly that value).
        inc_insulation = 0.0
        if incumbents:
            inc_insulation = np.mean([
                0.5 * (abs(d["adjusted_dem_share"] - 0.5) * 2 if d["contested"] else 1.0)
                for d in incumbents
            ])

        # METRIC: Structural Stability
        stability_score = max(0.0, min(1.0, 1.0 - inc_insulation))

        # METRIC: Electoral Health Index (EHI) — single-district variant
        weights = {"competitiveness": 0.50, "contestation": 0.25, "stability": 0.25}
        ehi = (
            weights["competitiveness"] * comp_score
            + weights["contestation"] * contestation_score
            + weights["stability"] * stability_score
        )

        # METRIC: Metric Confidence — single-district variant
        confidence = 0.75
        if was_redistricted:
            confidence -= 0.10
        if contestation_score < 1.0:
            confidence -= 0.20
        confidence = max(0.0, min(1.0, confidence))

        return {
            "efficiency_gap_score": None, "mean_median_score": None, "proportionality_score": None, "fairness_score": None,
            "competitiveness_score": round(comp_score, 4),
            "contestation_score": round(contestation_score, 4),
            "incumbency_insulation": round(inc_insulation, 4),
            "structural_stability_score": round(stability_score, 4),
            "electoral_health_index": round(ehi, 4),
            "metric_confidence": round(confidence, 4)
        }

    # --- Multi-district states (at least two districts with recorded votes) ---
    total_adj_votes = 0
    total_adj_dem_votes = 0
    total_adj_dem_wasted = 0
    total_adj_rep_wasted = 0
    adj_dem_shares = []
    dem_seats_won = 0

    for d in valid_districts:
        adv = d["adjusted_demvotes"]
        arv = d["adjusted_repvotes"]
        tot = d["adjusted_totalvotes"]

        total_adj_votes += tot
        total_adj_dem_votes += adv
        adj_dem_shares.append(d["adjusted_dem_share"])

        if d["winner"] == "DEM":
            dem_seats_won += 1

        # Wasted votes feed the Efficiency Gap (below).
        # Losing votes are wasted; winning votes above the victory threshold are wasted too.
        threshold = (tot / 2.0) + 1
        if adv > arv:
            total_adj_dem_wasted += (adv - threshold)
            total_adj_rep_wasted += arv
        else:
            total_adj_rep_wasted += (arv - threshold)
            total_adj_dem_wasted += adv

    # METRIC: Efficiency Gap
    # Raw gap: (Dem wasted - Rep wasted) / total votes.
    # Score shown on the map: 1 - min(|gap| / 0.12, 1) so higher = more balanced.
    eg = (total_adj_dem_wasted - total_adj_rep_wasted) / total_adj_votes if total_adj_votes > 0 else 0.0
    eg_score = 1.0 - min(abs(eg) / 0.12, 1.0)

    # METRIC: Mean-Median
    # Compare average district Dem share to the middle (median) district.
    # Score shown on the map: 1 - min(|mean - median| / 0.06, 1).
    mean_share = np.mean(adj_dem_shares)
    median_share = np.median(adj_dem_shares)
    mm_gap = abs(mean_share - median_share)
    mm_score = 1.0 - min(mm_gap / 0.06, 1.0)

    # METRIC: Proportionality
    # Compare Democratic seat share to Democratic vote share, both taken over the
    # same districts with recorded votes. (Summing Dem votes over all districts
    # while dividing by the valid districts' total would inflate the vote share.)
    # Score shown on the map: 1 - min(|seat share - vote share| / 0.25, 1).
    dem_seat_share = dem_seats_won / total_districts
    dem_vote_share = total_adj_dem_votes / total_adj_votes if total_adj_votes > 0 else 0.5
    seat_vote_gap = abs(dem_seat_share - dem_vote_share)
    prop_score = 1.0 - min(seat_vote_gap / 0.25, 1.0)

    # METRIC: Fairness (composite)
    # Average of the three normalized fairness scores above (each 0–1, higher = better).
    fairness_score = (eg_score + mm_score + prop_score) / 3.0

    # METRIC: Competitiveness (statewide)
    # Average of each district's competitiveness score. Per-district formulas are in
    # impute_uncontested (adjusted path) and in the raw-path records main() builds.
    comp_score = np.mean([d["competitiveness"] for d in districts])

    # METRIC: Contestation
    # Share of all districts where both major parties received meaningful votes.
    contestation_score = contested_count / len(districts)

    # METRIC: Incumbency Insulation
    # Combines how often incumbents run unopposed with how large their win margins are.
    if incumbents:
        inc_uncontested_rate = sum(1 for d in incumbents if not d["contested"]) / len(incumbents)
        contested_incumbent_margins = [
            abs(d["adjusted_dem_share"] - 0.5) * 2 for d in incumbents if d["contested"]
        ]
        avg_margin = np.mean(contested_incumbent_margins) if contested_incumbent_margins else 0.0
        inc_insulation = (
            0.75 * inc_uncontested_rate
            + 0.25 * avg_margin
        )
    else:
        inc_insulation = 0.0

    # METRIC: Structural Stability
    stability_score = max(0.0, min(1.0, 1.0 - inc_insulation))

    # METRIC: Electoral Health Index (EHI)
    # Weighted combination of fairness, competition, contestation, and stability.
    ehi = (0.30 * fairness_score) + (0.30 * comp_score) + (0.20 * contestation_score) + (0.20 * stability_score)

    # METRIC: Metric Confidence
    # Lower when many districts are uncontested or the state just redistricted.
    uncontested_rate = 1.0 - contestation_score
    confidence = 1.0 - (uncontested_rate * 0.5)
    if was_redistricted:
        confidence -= 0.10
    confidence = max(0.0, min(1.0, confidence))

    return {
        "efficiency_gap_score": round(eg_score, 4),
        "mean_median_score": round(mm_score, 4),
        "proportionality_score": round(prop_score, 4),
        "fairness_score": round(fairness_score, 4),
        "competitiveness_score": round(comp_score, 4),
        "contestation_score": round(contestation_score, 4),
        "incumbency_insulation": round(inc_insulation, 4),
        "structural_stability_score": round(stability_score, 4),
        "electoral_health_index": round(ehi, 4),
        "metric_confidence": round(confidence, 4)
    }


def _presidential_context(year_ge, state_name):
    """Return (two-party Dem share, Dem %, Rep %, swing) for one state from GE rows.

    If the state has no GE row, returns a 0.5 share and None for the other three
    fields. NaN cells are reported as None rather than NaN. json.dump would
    otherwise write the bare token NaN, which is not valid JSON and fails in a
    browser's JSON.parse. When either percentage is missing, the share falls back
    to 0.5.
    """
    ge_match = year_ge[year_ge['State'] == state_name]
    if ge_match.empty:
        return 0.5, None, None, None

    d_val = ge_match['Dem%'].values[0]
    r_val = ge_match['Rep%'].values[0]
    swing_val = ge_match['Swing'].values[0]

    d_pct = float(d_val) if pd.notna(d_val) else None
    r_pct = float(r_val) if pd.notna(r_val) else None
    ge_swing = str(swing_val) if pd.notna(swing_val) else None

    if d_pct is not None and r_pct is not None and (d_pct + r_pct) > 0:
        statewide_pres_dem_share = d_pct / (d_pct + r_pct)
    else:
        statewide_pres_dem_share = 0.5
    return statewide_pres_dem_share, d_pct, r_pct, ge_swing


def _raw_metric_districts(raw_districts):
    """Return copies of raw records carrying the fields calculate_metrics reads.

    The adjusted_* fields repeat the recorded votes (no imputation), so the raw
    data can be scored by the same calculate_metrics code path.
    """
    raw_version = []
    for d in raw_districts:
        share = d["raw_dem_share"]
        dv = d["raw_demvotes"]
        rv = d["raw_repvotes"]
        tot = d["raw_totalvotes"]
        # Raw path: "contested" only requires both parties to have votes (no 2% threshold).
        contested = not (dv == 0 or rv == 0 or tot == 0)
        if dv != rv:
            winner = "DEM" if dv > rv else "REP"
        else:
            # Tie or no votes: decide by share, DEM on exactly 0.5.
            winner = "DEM" if share >= 0.5 else "REP"
        margin = abs(share - 0.5) * 2

        # METRIC: Competitiveness (district level) — raw path
        # Uses a squared margin: 1 - (2|share - 0.5|)²
        # NOTE(review): the adjusted path (impute_uncontested) uses the linear
        # 1 - 2|share - 0.5|, so a contested district with identical shares scores
        # differently in raw_metrics vs adjusted_metrics. Confirm this is intended.
        comp = max(0.0, 1.0 - (margin ** 2))

        d_raw = d.copy()
        d_raw.update({
            "contested": contested,
            "adjusted_dem_share": share,
            "adjusted_totalvotes": tot,
            "adjusted_demvotes": dv,
            "adjusted_repvotes": rv,
            "winner": winner,
            "competitiveness": round(comp, 4)
        })
        raw_version.append(d_raw)
    return raw_version


def main(ce_path='CE_results.csv', ge_path='GE_results.csv',
         output_path="electoral_health_output.json", geojson_dir="./geojson"):
    """Build the per-year, per-state metrics dataset and write it as JSON.

    Args:
        ce_path: House results CSV. Needs the columns State (postal code), Year,
            District, Dem Votes, GOP Votes, Incumbent and Party.
        ge_path: statewide presidential results CSV. Needs the columns State (full
            state name), Year, Dem%, Rep% and Swing.
        output_path: where the JSON dataset is written.
        geojson_dir: directory of district-map files used for redistricting flags.

    Output shape: {year: {"states": {State Name: {...state fields, "raw_metrics",
    "adjusted_metrics", "districts": {district number: {...}}}}}}.
    """
    print("Loading datasets...")
    ce_df = pd.read_csv(ce_path)
    ge_df = pd.read_csv(ge_path)

    # Normalize labels: CE holds postal codes, GE holds full names; both are
    # matched against the upper-case lookup tables.
    ce_df['State'] = ce_df['State'].str.strip().str.upper()
    ge_df['State'] = ge_df['State'].str.strip().str.upper()

    redistricting_map = detect_redistricting_years(geojson_dir)
    output_json = {}

    for year in sorted(ce_df['Year'].unique()):
        output_json[str(year)] = {"states": {}}
        year_ce = ce_df[ce_df['Year'] == year]
        year_ge = ge_df[ge_df['Year'] == year]

        for state_po, state_group in year_ce.groupby('State'):
            state_name = PO_TO_STATE.get(state_po)
            if state_name is None:
                # Codes without a lookup entry (e.g. DC, territories) are skipped.
                continue

            fips = STATE_MAPPING[state_name]["fips"]
            statewide_pres_dem_share, d_pct, r_pct, ge_swing = _presidential_context(year_ge, state_name)
            was_redistricted = int(year) in redistricting_map.get(fips, set())

            raw_districts = clean_districts(state_group)

            total_raw_dem = sum(d["raw_demvotes"] for d in raw_districts)
            total_raw_tot = sum(d["raw_totalvotes"] for d in raw_districts)
            statewide_house_dem_share = total_raw_dem / total_raw_tot if total_raw_tot > 0 else 0.5

            processed_districts = impute_uncontested(raw_districts, statewide_house_dem_share, statewide_pres_dem_share)

            raw_metrics = calculate_metrics(_raw_metric_districts(raw_districts), was_redistricted)
            adjusted_metrics = calculate_metrics(processed_districts, was_redistricted)

            dist_output_json = {
                str(d["district"]): {
                    "totalvotes": d["raw_totalvotes"],
                    "demvotes": d["raw_demvotes"],
                    "repvotes": d["raw_repvotes"],
                    "adjusted_totalvotes": d["adjusted_totalvotes"],
                    "adjusted_demvotes": d["adjusted_demvotes"],
                    "adjusted_repvotes": d["adjusted_repvotes"],
                    "raw_dem_share": d["raw_dem_share"],
                    "adjusted_dem_share": d["adjusted_dem_share"],
                    "winner": d["winner"],
                    "competitiveness": d["competitiveness"],
                    "contested": d["contested"],
                    "incumbent": d["incumbent"],
                    "party": d["party"]
                }
                for d in processed_districts
            }

            output_json[str(year)]["states"][state_name.title()] = {
                "state": state_name.title(),
                "state_po": state_po,
                "state_fips": fips,
                "ge_pres_dem_pct": d_pct,
                "ge_pres_rep_pct": r_pct,
                "ge_pres_swing": ge_swing,
                "post_redistricting": was_redistricted,
                "raw_metrics": raw_metrics,
                "adjusted_metrics": adjusted_metrics,
                "districts": dist_output_json
            }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_json, f, indent=2)
    print(f"Success! Data successfully dumped to {output_path}")


# Run the pipeline only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
