import os

import pyodbc
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ASGI application object; the route handlers in this module register on it.
app: FastAPI = FastAPI()

# ODBC connection string for the "Home" SQL Server database.
# Each component may be overridden with an environment variable. The fallbacks are
# the values that were previously hard-coded, so the resulting string is unchanged
# when none of the variables are set.
# SECURITY(review): a plaintext password is committed to source control here.
# Supply it via HOME_DB_PWD (or a secret store), rotate the current password, and
# then remove the fallback literal below.
# NOTE: TrustServerCertificate=yes keeps traffic encrypted but skips validation of
# the server's certificate; set HOME_DB_TRUST_CERT=no once the server presents a
# trusted certificate.
CONN_STR = (
    f"Driver={{{os.environ.get('HOME_DB_DRIVER', 'ODBC Driver 18 for SQL Server')}}};"
    f"Server={os.environ.get('HOME_DB_SERVER', '192.168.1.13')};"
    f"Database={os.environ.get('HOME_DB_NAME', 'Home')};"
    f"UID={os.environ.get('HOME_DB_UID', 'unabomber')};"
    f"PWD={os.environ.get('HOME_DB_PWD', 'n116533891W2k')};"
    "Encrypt=yes;"
    f"TrustServerCertificate={os.environ.get('HOME_DB_TRUST_CERT', 'yes')};"
)

def get_conn():
    """Open and return a new pyodbc connection built from CONN_STR.

    Autocommit is disabled, so writes persist only after an explicit
    ``conn.commit()``. The caller owns the returned connection and must close it.
    """
    manual_commit = False  # writes require an explicit conn.commit()
    return pyodbc.connect(CONN_STR, autocommit=manual_commit)

class MarkRequest(BaseModel):
    """JSON body accepted by the POST //dup-groups/{group_id}/mark endpoint."""

    # FileId of the copy the caller intends to keep; the endpoint rejects the
    # request if this id also appears in deleteFileIds.
    keepFileId: int
    # FileIds whose todelete bit the endpoint updates; must be non-empty.
    deleteFileIds: list[int]
    # Value written to the todelete bit: True -> 1, False -> 0.
    toDelete: bool

@app.get("//dup-groups")
def list_dup_groups():
    """
    Return every duplicate group, largest reclaimable space first.

    A group is the set of rows in ``files`` sharing the same (FullHash, sizeBytes),
    kept only when it has more than one row. Rows with a NULL FullHash are ignored.
    Each returned entry has:

    - groupId:     1-based rank of the group in this ordering, as a string. The
                   detail and mark endpoints re-derive the group from this rank, so
                   it only stays valid while the table is unchanged.
    - FullHash:    the shared hash, as a string.
    - sizeBytes:   size of one copy.
    - fileCount:   number of rows in the group.
    - wastedBytes: (fileCount - 1) * sizeBytes, the space held by redundant copies.

    NOTE(review): the route path starts with "//"; a request for "/dup-groups"
    does not match it. Confirm this is intentional (e.g. a proxy prefix).
    """
    conn = get_conn()
    try:
        cur = conn.cursor()
        # FullHash and sizeBytes break ties between groups with equal wasted space.
        # Without them SQL Server may return tied groups in any order, so a groupId
        # (a rank) could refer to a different group on the next request. Keep this
        # ordering identical to the detail and mark endpoints' ROW_NUMBER ordering.
        cur.execute("""
            SELECT
                FullHash,
                sizeBytes = MAX(sizeBytes),
                fileCount = COUNT(*)
            FROM files
            WHERE FullHash IS NOT NULL
            GROUP BY FullHash, sizeBytes
            HAVING COUNT(*) > 1
            ORDER BY (COUNT(*) - 1) * MAX(sizeBytes) DESC, FullHash, MAX(sizeBytes)
        """)
        rows = cur.fetchall()

        groups = []
        for group_id, r in enumerate(rows, start=1):
            size_bytes = int(r.sizeBytes)
            file_count = int(r.fileCount)
            groups.append({
                "groupId": str(group_id),     # UI expects groupId
                "FullHash": str(r.FullHash),
                "sizeBytes": size_bytes,
                "fileCount": file_count,
                "wastedBytes": (file_count - 1) * size_bytes,
            })

        return groups
    finally:
        conn.close()

@app.get("//dup-groups/{group_id}")
def get_group(group_id: str):
    """
    Return one duplicate group and the files it contains.

    group_id is the 1-based rank produced by the list endpoint (groups ordered by
    wasted space, largest first). Groups are not stored, so the rank is resolved
    again here. For production, store groups in a table with a stable groupId.

    Returns a dict with groupId, FullHash, sizeBytes and files. Each file carries
    FileId, path, sizeBytes, MTimeUtc, LastSeenUtc and todelete, ordered by path.

    Raises HTTPException 400 if group_id is not a positive integer string, and
    404 if no group has that rank.
    """
    try:
        idx = int(group_id)
        if idx < 1:
            raise ValueError()
    except ValueError:
        raise HTTPException(status_code=400, detail="group_id must be a positive integer string") from None

    conn = get_conn()
    try:
        cur = conn.cursor()

        # Resolve rank -> (FullHash, sizeBytes). Keep this ordering identical to the
        # list endpoint's. FullHash and sizeBytes break ties so groups with equal
        # wasted space always receive the same rank.
        cur.execute("""
            WITH dup AS (
                SELECT
                    FullHash,
                    sizeBytes,
                    rn = ROW_NUMBER() OVER (
                        ORDER BY (COUNT(*) - 1) * MAX(sizeBytes) DESC, FullHash, sizeBytes
                    )
                FROM files
                WHERE FullHash IS NOT NULL
                GROUP BY FullHash, sizeBytes
                HAVING COUNT(*) > 1
            )
            SELECT FullHash, sizeBytes
            FROM dup
            WHERE rn = ?
        """, idx)

        g = cur.fetchone()
        if not g:
            raise HTTPException(status_code=404, detail="Group not found")

        the_hash = g.FullHash
        size_bytes = int(g.sizeBytes)

        # Fetch every file in that group.
        cur.execute("""
            SELECT FileId, path, sizeBytes, MTimeUtc, LastSeenUtc, todelete
            FROM files
            WHERE FullHash = ? AND sizeBytes = ?
            ORDER BY path
        """, the_hash, size_bytes)

        files = [
            {
                "FileId": int(r.FileId),
                "path": str(r.path),
                "sizeBytes": int(r.sizeBytes),
                "MTimeUtc": r.MTimeUtc,
                "LastSeenUtc": r.LastSeenUtc,
                "todelete": bool(r.todelete),
            }
            for r in cur.fetchall()
        ]

        return {
            "groupId": str(idx),
            "FullHash": str(the_hash),
            "sizeBytes": size_bytes,
            "files": files,
        }
    finally:
        conn.close()

@app.post("//dup-groups/{group_id}/mark")
def mark_group(group_id: str, req: MarkRequest):
    """
    Set or clear the todelete bit on files that belong to one duplicate group.

    group_id is the 1-based rank used by the list endpoint. Only rows whose
    FullHash and sizeBytes match that group are updated, so FileIds from other
    groups in deleteFileIds are silently ignored.

    Safety checks:
    - keepFileId may not appear in deleteFileIds.
    - deleteFileIds must be non-empty.
    - When marking (toDelete=True), keepFileId must itself belong to the group.
      This request therefore always leaves one copy in the group that it does not
      mark. It does not check whether that copy was marked by an earlier request.

    Returns {"updated": cur.rowcount}, the number of rows the UPDATE affected.
    Raises HTTPException 400 for invalid input and 404 if no group has that rank.
    """
    if req.keepFileId in req.deleteFileIds:
        raise HTTPException(status_code=400, detail="keepFileId cannot be in deleteFileIds")

    if not req.deleteFileIds:
        raise HTTPException(status_code=400, detail="No deleteFileIds supplied")

    # Validate group_id before touching the database.
    try:
        idx = int(group_id)
        if idx < 1:
            raise ValueError()
    except ValueError:
        raise HTTPException(status_code=400, detail="group_id must be a positive integer string") from None

    # Duplicate ids do not change which rows IN (...) matches; drop them to keep
    # the parameter list short. dict.fromkeys preserves first-seen order.
    delete_ids = list(dict.fromkeys(req.deleteFileIds))

    conn = get_conn()
    try:
        cur = conn.cursor()

        # Resolve rank -> (FullHash, sizeBytes). Keep this ordering identical to the
        # list endpoint's. FullHash and sizeBytes break ties so the rank always
        # names the same group.
        cur.execute("""
            WITH dup AS (
                SELECT
                    FullHash,
                    sizeBytes,
                    rn = ROW_NUMBER() OVER (
                        ORDER BY (COUNT(*) - 1) * MAX(sizeBytes) DESC, FullHash, sizeBytes
                    )
                FROM files
                WHERE FullHash IS NOT NULL
                GROUP BY FullHash, sizeBytes
                HAVING COUNT(*) > 1
            )
            SELECT FullHash, sizeBytes
            FROM dup
            WHERE rn = ?
        """, idx)

        g = cur.fetchone()
        if not g:
            raise HTTPException(status_code=404, detail="Group not found")

        the_hash = g.FullHash
        size_bytes = int(g.sizeBytes)

        if req.toDelete:
            # Without this check, a keepFileId from another group would let the
            # caller mark every copy in this group for deletion.
            cur.execute("""
                SELECT 1
                FROM files
                WHERE FileId = ? AND FullHash = ? AND sizeBytes = ?
            """, req.keepFileId, the_hash, size_bytes)
            if cur.fetchone() is None:
                raise HTTPException(status_code=400, detail="keepFileId does not belong to this group")

        # The f-string only interpolates "?" placeholders; every value is bound as a
        # parameter, so no request data is spliced into the SQL text.
        placeholders = ",".join("?" * len(delete_ids))
        sql = f"""
            UPDATE files
            SET todelete = ?
            WHERE FullHash = ? AND sizeBytes = ?
              AND FileId IN ({placeholders})
        """

        params = [1 if req.toDelete else 0, the_hash, size_bytes, *delete_ids]
        cur.execute(sql, params)
        conn.commit()

        return {"updated": cur.rowcount}
    finally:
        conn.close()
