#!/usr/bin/env python3
"""
Margin of Safety Engine - Phase 1
Fetches 13F holdings from SEC EDGAR for super investors,
enriches with current/45-day-ago prices via yfinance,
and saves to holdings-latest.json
"""

import json
import time
import datetime
import requests
import yfinance as yf
from pathlib import Path

# Resolve all paths relative to this script so it runs from any working directory.
BASE_DIR = Path(__file__).parent
INVESTORS_FILE = BASE_DIR / "investors.json"      # input: tracked investors with CIKs
OUTPUT_FILE = BASE_DIR / "holdings-latest.json"   # output: enriched holdings snapshot

# SEC EDGAR requires a descriptive User-Agent with contact info;
# requests without one are throttled/blocked.
SEC_HEADERS = {
    "User-Agent": "MarginOfSafetyEngine joe@example.com",
    "Accept-Encoding": "gzip, deflate",
}

# Fallback sample data for Berkshire if 13F fetch fails.
# value_usd is 0 here; pct_portfolio will therefore be 0 for sample rows.
# NOTE(review): share counts are a static snapshot — confirm/refresh periodically.
BERKSHIRE_SAMPLE = [
    {"ticker": "AAPL", "company": "Apple Inc", "shares": 300000000, "value_usd": 0},
    {"ticker": "BAC",  "company": "Bank of America Corp", "shares": 1032852006, "value_usd": 0},
    {"ticker": "AXP",  "company": "American Express Co", "shares": 151610700, "value_usd": 0},
    {"ticker": "KO",   "company": "Coca-Cola Co", "shares": 400000000, "value_usd": 0},
    {"ticker": "CVX",  "company": "Chevron Corp", "shares": 118610534, "value_usd": 0},
    {"ticker": "OXY",  "company": "Occidental Petroleum Corp", "shares": 255280404, "value_usd": 0},
    {"ticker": "MCO",  "company": "Moody's Corp", "shares": 24669778, "value_usd": 0},
    {"ticker": "KHC",  "company": "Kraft Heinz Co", "shares": 325634818, "value_usd": 0},
]

def get_latest_13f_accession(cik: str) -> str | None:
    """Return the accession number of the newest 13F-HR filing, or None.

    Queries the EDGAR submissions API for the given CIK and scans the
    recent filings list (newest first) for the first 13F-HR or 13F-HR/A
    form. Any network/parse failure is logged and reported as None.
    """
    # EDGAR expects the CIK zero-padded to 10 digits.
    padded_cik = cik.lstrip("0").zfill(10)
    url = f"https://data.sec.gov/submissions/CIK{padded_cik}.json"
    try:
        response = requests.get(url, headers=SEC_HEADERS, timeout=15)
        response.raise_for_status()
        recent = response.json().get("filings", {}).get("recent", {})
        pairs = zip(recent.get("form", []), recent.get("accessionNumber", []))
        return next(
            (acc for form, acc in pairs if form in ("13F-HR", "13F-HR/A")),
            None,
        )
    except Exception as e:
        print(f"  [WARN] Could not fetch submissions for CIK {cik}: {e}")
        return None

def parse_13f_holdings(cik: str, accession: str) -> list[dict]:
    """Download and parse the 13F information-table XML from EDGAR.

    Args:
        cik: Filer's CIK, with or without leading zeros.
        accession: Accession number as returned by EDGAR (with dashes).

    Returns:
        A list of holding dicts (company/ticker/shares/value_usd), or an
        empty list on any fetch/parse failure (errors are logged, not raised).
    """
    acc_nodash = accession.replace("-", "")
    # Archive directory paths use the CIK *without* leading zeros.
    filing_url = f"https://www.sec.gov/Archives/edgar/data/{cik.lstrip('0')}/{acc_nodash}/"
    try:
        resp = requests.get(filing_url + "index.json", headers=SEC_HEADERS, timeout=15)
        resp.raise_for_status()
        files = resp.json().get("directory", {}).get("item", [])

        # Prefer an XML explicitly named as the infotable; otherwise fall
        # back to the first XML that is not the cover page (primary_doc.xml).
        xml_file = None
        for f in files:
            name = f.get("name", "")
            if name.endswith(".xml") and "infotable" in name.lower():
                xml_file = name
                break
        if not xml_file:
            for f in files:
                name = f.get("name", "")
                if name.endswith(".xml") and name != "primary_doc.xml":
                    xml_file = name
                    break
        if not xml_file:
            print(f"  [WARN] No infotable XML found for {cik}")
            return []

        xml_resp = requests.get(filing_url + xml_file, headers=SEC_HEADERS, timeout=30)
        xml_resp.raise_for_status()
        return parse_infotable_xml(xml_resp.text)
    except Exception as e:
        print(f"  [WARN] Could not parse 13F for CIK {cik}: {e}")
        return []

def parse_infotable_xml(xml_text: str) -> list[dict]:
    """Parse holdings from 13F infotable XML. Handles namespaced and plain XML.

    Returns a list of dicts with keys: company, ticker, shares, value_usd.
    Individual malformed entries are skipped; a top-level XML parse error
    is logged and yields an empty list.
    """
    import re
    from xml.etree import ElementTree as ET

    holdings = []
    try:
        # Strip ALL namespace declarations and prefixes for simple iteration
        xml_clean = re.sub(r'\s+xmlns[^=]*="[^"]*"', '', xml_text)
        # Remove namespace prefixes from element tags
        xml_clean = re.sub(r'<([/]?)[\w]+:([\w]+)', r'<\1\2', xml_clean)
        # Remove namespace-prefixed attributes (e.g., xsi:schemaLocation="...")
        xml_clean = re.sub(r'\s+[\w]+:[\w]+=(?:"[^"]*"|\'[^\']*\')', '', xml_clean)

        root = ET.fromstring(xml_clean)

        def find_text(el, *tags):
            # BUG FIX: search descendants (".//tag"), not just direct children.
            # In real 13F filings <sshPrnamt> is nested inside <shrsOrPrnAmt>,
            # so a plain el.find(tag) never found it and shares were always 0.
            for tag in tags:
                found = el.find(f".//{tag}")
                if found is not None and found.text:
                    return found.text.strip()
            return None

        # Find all infoTable entries (handles any nesting)
        for entry in root.iter("infoTable"):
            try:
                name = find_text(entry, "nameOfIssuer") or "Unknown"
                value_str = find_text(entry, "value")
                # NOTE(review): 13F <value> was historically reported in
                # thousands of USD; filings from late 2022 onward report whole
                # dollars — confirm the x1000 scaling against filing vintage.
                value = int(value_str) * 1000 if value_str else 0
                shares_str = find_text(entry, "sshPrnamt")
                shares = int(shares_str) if shares_str else 0
                # 13F infotables rarely carry a ticker; "" triggers the
                # name-based guess in the caller.
                ticker = find_text(entry, "ticker") or ""

                holdings.append({
                    "company": name,
                    "ticker": ticker,
                    "shares": shares,
                    "value_usd": value,
                })
            except Exception:
                continue
    except ET.ParseError as e:
        print(f"  [WARN] XML parse error: {e}")
    return holdings

def cusip_to_ticker_guess(company_name: str) -> str:
    """Very rough guess — real mapping needs a CUSIP lookup service.

    Substring-matches known issuer names, longest key first so that a
    specific name like "ALPHA METALLURGICAL" cannot be shadowed by the
    shorter generic key "META" (the previous insertion-order scan returned
    "META" for Alpha Metallurgical). Returns "" when nothing matches.
    """
    # Common mappings from Berkshire/super investor holdings.
    # (Duplicate keys such as MASTERCARD/VISA/MOODY, which Python silently
    # collapsed, have been removed.)
    mapping = {
        "APPLE": "AAPL", "BANK OF AMERICA": "BAC", "AMERICAN EXPRESS": "AXP",
        "COCA-COLA": "KO", "CHEVRON": "CVX", "OCCIDENTAL": "OXY",
        "MOODY": "MCO", "KRAFT HEINZ": "KHC", "VERISIGN": "VRSN",
        "DAVITA": "DVA", "HP INC": "HPQ", "CITIGROUP": "C",
        "VISA": "V", "MASTERCARD": "MA", "AMAZON": "AMZN",
        "MICROSOFT": "MSFT", "ALPHABET": "GOOGL", "GOOGLE": "GOOGL",
        "META": "META", "BERKSHIRE": "BRK-B",
        "CONSTELLATION": "STZ", "AUTOZONE": "AZO",
        "FAIR ISAAC": "FICO", "ROPER": "ROP",
        # Pabrai / energy / materials holdings
        "ALPHA METALLURGICAL": "AMR", "ALPHA MET": "AMR",
        "TRANSOCEAN": "RIG", "VALARIS": "VAL",
        "WARRIOR MET": "HCC", "WARRIOR MET COAL": "HCC",
        # Akre holdings
        "S&P GLOBAL": "SPGI",
        "CARMAX": "KMX", "CHARLES SCHWAB": "SCHW",
        "DOLLAR TREE": "DLTR", "ROSS STORES": "ROST",
        "O REILLY": "ORLY", "OREILLY": "ORLY",
        "FERRARI": "RACE", "CONSTELLATION BRANDS": "STZ",
        "BOOKING": "BKNG", "PRICELINE": "BKNG",
        # Guy Spier / Aquamarine
        "WELLS FARGO": "WFC", "BANK OF NEW YORK": "BK",
        "BIGLARI": "BH", "BRISTOL": "BMY",
    }
    name_upper = company_name.upper()
    # Longest keys first: specific names win over generic substrings.
    for key in sorted(mapping, key=len, reverse=True):
        if key in name_upper:
            return mapping[key]
    return ""

def enrich_with_prices(holdings: list[dict]) -> list[dict]:
    """Add current price and 45-day-ago price for each holding.

    Mutates each holding in place, adding:
        current_price   — latest close, rounded to 2 dp, or None
        entry_price_45d — close nearest to 45 days ago, or None
        pct_change      — percent move from entry to current, or None

    Returns the same list for convenience. Holdings without a ticker, or
    whose prices could not be fetched, get None for all three fields.
    """
    today = datetime.date.today()
    day45_ago = today - datetime.timedelta(days=45)

    tickers = list({h["ticker"] for h in holdings if h.get("ticker")})
    if not tickers:
        return holdings

    print(f"  Fetching prices for {len(tickers)} tickers...")

    # Batch download for efficiency
    price_map = {}
    chunk_size = 20
    for i in range(0, len(tickers), chunk_size):
        chunk = tickers[i:i+chunk_size]
        try:
            # Pad the window by a few days so a weekend/holiday at the
            # 45-day mark still yields a nearby trading-day close.
            data = yf.download(
                chunk,
                start=(day45_ago - datetime.timedelta(days=5)).isoformat(),
                end=(today + datetime.timedelta(days=1)).isoformat(),
                progress=False,
                auto_adjust=True,
                group_by="ticker" if len(chunk) > 1 else "column",
            )

            for ticker in chunk:
                try:
                    # Extract the Close series for this ticker.
                    # yfinance with group_by='ticker' produces MultiIndex
                    # columns keyed (ticker, field).
                    if hasattr(data.columns, 'levels'):
                        if (ticker, "Close") in data.columns:
                            series = data[(ticker, "Close")]
                        else:
                            continue
                    else:
                        # Flat columns (single-ticker, group_by='column')
                        if "Close" in data.columns:
                            series = data["Close"]
                        else:
                            continue

                    series = series.dropna()
                    if series.empty:
                        continue

                    current_price = float(series.iloc[-1])

                    # Find the trading day closest to 45 days ago.
                    # (lambda param renamed from `i`, which shadowed the
                    # enclosing chunk index.)
                    idx_dates = [d.date() if hasattr(d, 'date') else d for d in series.index]
                    target = day45_ago
                    closest_i = min(
                        range(len(idx_dates)),
                        key=lambda j: abs((idx_dates[j] - target).days),
                    )
                    entry_price = float(series.iloc[closest_i])

                    price_map[ticker] = {
                        "current_price": round(current_price, 2),
                        "entry_price_45d": round(entry_price, 2),
                        "pct_change": round(((current_price - entry_price) / entry_price) * 100, 2) if entry_price else None,
                    }
                except Exception as e:
                    print(f"    [WARN] Price fetch failed for {ticker}: {e}")
        except Exception as e:
            print(f"  [WARN] Batch price fetch failed: {e}")
        # Gentle pacing between batches to avoid rate limiting.
        time.sleep(0.5)

    for h in holdings:
        ticker = h.get("ticker", "")
        if ticker in price_map:
            h.update(price_map[ticker])
        else:
            h["current_price"] = None
            h["entry_price_45d"] = None
            h["pct_change"] = None

    return holdings

def compute_convergence(holdings_by_ticker: dict) -> dict:
    """
    Convergence score: how many of our tracked investors hold this stock.
    Higher = more conviction across the group.
    """
    return {symbol: len(holders) for symbol, holders in holdings_by_ticker.items()}

def main():
    """Fetch 13F holdings for each configured investor, enrich with prices,
    compute convergence scores, and write holdings-latest.json.

    Reads INVESTORS_FILE for the list of investors (name/fund/cik); falls
    back to BERKSHIRE_SAMPLE when the Berkshire 13F cannot be fetched.
    """
    print("=== Margin of Safety Engine — Phase 1 ===")
    print(f"Run date: {datetime.datetime.now().isoformat()}\n")

    with open(INVESTORS_FILE) as f:
        config = json.load(f)

    all_holdings = []
    holdings_by_ticker = {}  # ticker -> list of investor names holding it

    for investor in config["investors"]:
        name = investor["name"]
        fund = investor["fund"]
        cik = investor["cik"]
        print(f"Processing: {name} ({fund}) CIK={cik}")

        accession = get_latest_13f_accession(cik)
        holdings = []
        is_sample = False

        if accession:
            print(f"  Latest 13F: {accession}")
            holdings = parse_13f_holdings(cik, accession)
            # Fill in missing tickers from company name guesses
            for h in holdings:
                if not h.get("ticker"):
                    h["ticker"] = cusip_to_ticker_guess(h.get("company", ""))
            # Remove entries with no ticker (un-mappable names)
            holdings = [h for h in holdings if h.get("ticker")]
            print(f"  Parsed {len(holdings)} holdings with tickers")

        if not holdings and fund == "Berkshire Hathaway":
            print(f"  Using sample data for {fund}")
            # Copy so price enrichment doesn't mutate the module constant.
            holdings = [dict(h) for h in BERKSHIRE_SAMPLE]
            is_sample = True

        if not holdings:
            print(f"  [SKIP] No holdings found for {name}\n")
            continue

        # Deduplicate: merge same ticker entries (different share classes etc.)
        deduped = {}
        for h in holdings:
            t = h["ticker"]
            if t not in deduped:
                deduped[t] = dict(h)
            else:
                deduped[t]["shares"] = deduped[t].get("shares", 0) + h.get("shares", 0)
                deduped[t]["value_usd"] = deduped[t].get("value_usd", 0) + h.get("value_usd", 0)
        holdings = list(deduped.values())
        print(f"  After dedup: {len(holdings)} unique tickers")

        # Portfolio total value, for per-holding percentage weights.
        total_value = sum(h.get("value_usd", 0) for h in holdings)

        holdings = enrich_with_prices(holdings)

        for h in holdings:
            ticker = h["ticker"]
            pct_portfolio = round((h.get("value_usd", 0) / total_value * 100), 2) if total_value > 0 else 0

            record = {
                "ticker": ticker,
                "company": h["company"],
                "investor": name,
                "fund": fund,
                "shares": h["shares"],
                "value_usd": h.get("value_usd", 0),
                "pct_portfolio": pct_portfolio,
                "current_price": h.get("current_price"),
                "entry_price_45d": h.get("entry_price_45d"),
                "pct_change_since_entry": h.get("pct_change"),
                "is_sample": is_sample,
            }
            all_holdings.append(record)

            holdings_by_ticker.setdefault(ticker, []).append(name)

        print(f"  Done: {len(holdings)} holdings added\n")

    # Add convergence scores (number of tracked investors per ticker)
    convergence = compute_convergence(holdings_by_ticker)
    for h in all_holdings:
        h["convergence"] = convergence.get(h["ticker"], 1)

    # Sort by convergence desc, then pct_change asc (best buys first).
    # BUG FIX: `x[...] or 999` treated a legitimate 0.0 change as missing
    # and misfiled flat positions at the end; only None means "no data".
    all_holdings.sort(
        key=lambda x: (
            -x["convergence"],
            x["pct_change_since_entry"] if x["pct_change_since_entry"] is not None else 999,
        )
    )

    output = {
        "generated_at": datetime.datetime.now().isoformat(),
        "total_holdings": len(all_holdings),
        "holdings": all_holdings,
    }

    with open(OUTPUT_FILE, "w") as f:
        json.dump(output, f, indent=2)

    print(f"✅ Saved {len(all_holdings)} holdings to {OUTPUT_FILE}")

    # Summary
    buy_zone = [h for h in all_holdings if h.get("pct_change_since_entry") and h["pct_change_since_entry"] < 0]
    multi_inv = [h for h in all_holdings if h["convergence"] > 1]
    tickers_unique = len({h["ticker"] for h in all_holdings})
    print(f"\n📊 Summary:")
    print(f"   Total holdings records: {len(all_holdings)}")
    print(f"   Unique tickers: {tickers_unique}")
    print(f"   Buy zone (below entry): {len(buy_zone)}")
    print(f"   Multi-investor: {len(multi_inv)}")

# Script entry point — run the full fetch/enrich/save pipeline.
if __name__ == "__main__":
    main()
