import asyncio
import ipaddress
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse

from app import auth, burnin, mailer, poller, retention, settings_store
from app.config import settings
from app.database import init_db
from app.logging_config import configure as configure_logging
from app.renderer import templates  # noqa: F401 — registers filters as side-effect
from app.routes import router
from app.truenas import TrueNASClient

# Configure structured JSON logging before anything else logs
configure_logging()
log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# IP allowlist middleware
# ---------------------------------------------------------------------------

class _IPAllowlistMiddleware(BaseHTTPMiddleware):
    """
    Block requests from IPs not in ALLOWED_IPS.

    When ALLOWED_IPS is empty (or contains no valid entries) the middleware
    is a no-op.

    Client IP resolution: reverse proxies *append* the peer address to
    X-Forwarded-For, so the leftmost entry is whatever the client sent and is
    trivially spoofable — it must not be used for access control. Instead the
    entry added by our own proxy chain is used: ``trusted_proxy_hops`` entries
    in from the right (default 1, i.e. the rightmost entry, for a single
    reverse proxy in front of the app). With ``trusted_proxy_hops=0``, or when
    the header is absent, the direct peer address is used.

    NOTE: this is only sound if the app is reachable exclusively through the
    trusted proxy; a client that can connect directly can still forge the
    header.
    """

    def __init__(self, app, allowed_ips: str, trusted_proxy_hops: int = 1) -> None:
        super().__init__(app)
        self._trusted_proxy_hops = max(trusted_proxy_hops, 0)
        self._networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
        for entry in (s.strip() for s in allowed_ips.split(",") if s.strip()):
            try:
                self._networks.append(ipaddress.ip_network(entry, strict=False))
            except ValueError:
                log.warning("Invalid ALLOWED_IPS entry ignored: %r", entry)
        if allowed_ips.strip() and not self._networks:
            # Configured but unusable: the allowlist is effectively disabled.
            # Say so loudly rather than failing open silently.
            log.error("ALLOWED_IPS contains no valid entries — IP allowlist is NOT enforced")

    def _is_allowed(self, ip_str: str) -> bool:
        """Return True if *ip_str* parses and falls inside an allowed network.

        IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``, as reported by
        dual-stack listeners) are also checked in their IPv4 form so that
        IPv4 allowlist entries still match them.
        """
        try:
            addr = ipaddress.ip_address(ip_str)
        except ValueError:
            return False
        candidates = [addr]
        if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
            candidates.append(addr.ipv4_mapped)
        return any(a in net for a in candidates for net in self._networks)

    def _client_ip(self, request: Request) -> str:
        """Resolve the client address to check against the allowlist."""
        direct = request.client.host if request.client else ""
        if self._trusted_proxy_hops == 0:
            return direct
        # Collect entries from every X-Forwarded-For header line, in order.
        hops = [
            part.strip()
            for header in request.headers.getlist("x-forwarded-for")
            for part in header.split(",")
            if part.strip()
        ]
        if not hops:
            return direct
        # Fewer entries than configured hops: fall back to the leftmost entry.
        return hops[max(len(hops) - self._trusted_proxy_hops, 0)]

    async def dispatch(self, request: Request, call_next):
        if not self._networks:
            return await call_next(request)

        client_ip = self._client_ip(request)
        if self._is_allowed(client_ip):
            return await call_next(request)

        log.warning("Request blocked by IP allowlist", extra={"client_ip": client_ip})
        return PlainTextResponse("Forbidden", status_code=403)


# ---------------------------------------------------------------------------
# Poller supervisor — restarts run() if it ever exits unexpectedly
# ---------------------------------------------------------------------------

async def _supervised_poller(client: TrueNASClient) -> None:
    """Run ``poller.run(client)`` forever, restarting it 5s after any exit.

    Cancellation propagates so shutdown works. A normal return is treated as
    unexpected, too, and also waits before restarting. Without the delay, a
    ``run()`` that returns immediately would spin this loop and could starve
    the event loop.
    """
    while True:
        try:
            await poller.run(client)
        except asyncio.CancelledError:
            raise  # Propagate shutdown signal cleanly
        except Exception as exc:
            log.critical(
                "Poller crashed unexpectedly — restarting in 5s: %s", exc, exc_info=True
            )
        else:
            log.critical("Poller exited unexpectedly — restarting in 5s")
        await asyncio.sleep(5)


# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------

_client: TrueNASClient | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start background workers on startup; stop them on shutdown.

    Startup initialises the database, the settings store, the bootstrap admin
    and the shared TrueNAS client. It then launches the supervised poller and
    the mailer and retention tasks. Cleanup sits in ``finally`` so that the
    client is closed even if startup fails after the client was created. Any
    background task that crashed is logged instead of silently discarded.
    """
    global _client
    log.info("Starting up")
    await init_db()
    settings_store.init()
    await auth.bootstrap_admin_if_empty()
    _client = TrueNASClient()

    tasks: list[asyncio.Task] = []
    try:
        await burnin.init(_client)
        tasks = [
            asyncio.create_task(_supervised_poller(_client), name="poller"),
            asyncio.create_task(mailer.run(), name="mailer"),
            asyncio.create_task(retention.run(), name="retention"),
        ]
        yield
    finally:
        log.info("Shutting down")
        for task in tasks:
            task.cancel()
        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            results = []  # shutdown itself was cancelled; still close the client
        for task, result in zip(tasks, results):
            # CancelledError derives from BaseException, so this only reports
            # genuine crashes that return_exceptions=True would otherwise hide.
            if isinstance(result, Exception):
                log.error("Background task %r failed", task.get_name(), exc_info=result)
        await _client.close()


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

app = FastAPI(title="TrueNAS Burn-In Dashboard", lifespan=lifespan)

# ---------------------------------------------------------------------------
# Auth gate. Middleware wraps the whole router regardless of whether it is
# registered before or after include_router.
# The path/prefix allowlist below covers anything reachable without a
# session cookie. BaseHTTPMiddleware only handles "http" scope, so WebSocket
# connections pass through untouched and their handlers must enforce auth
# themselves. Unauthenticated API calls and SSE streams get a JSON 401 here.
# ---------------------------------------------------------------------------

_PUBLIC_PATHS = {"/login", "/logout", "/health", "/auth/setup"}
_PUBLIC_PREFIXES = ("/static/", "/api/v1/auth/")


class _AuthGateMiddleware(BaseHTTPMiddleware):
    """Require a logged-in session for every non-public HTTP path.

    Requires SessionMiddleware to wrap this middleware, because it reads
    ``request.session``.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Always populate request.state.current_user from the session so
        # templates and route handlers can both rely on it. None when
        # unauthenticated.
        user_id = request.session.get("user_id")
        request.state.current_user = (
            await auth.get_user_by_id(int(user_id)) if user_id else None
        )

        if path in _PUBLIC_PATHS or path.startswith(_PUBLIC_PREFIXES):
            return await call_next(request)

        if request.state.current_user is not None:
            return await call_next(request)

        # Unauthenticated. HTML GETs bounce to /login with a `next` query
        # arg so the user lands back where they tried to go after logging
        # in. Anything else (API calls, SSE, POSTs) gets a 401.
        accept = request.headers.get("accept", "")
        if request.method == "GET" and "text/html" in accept:
            return auth.login_redirect(path)
        return JSONResponse(
            {"detail": "Authentication required"}, status_code=401
        )


app.add_middleware(_AuthGateMiddleware)

# Each add_middleware() call wraps the stack built so far, so middleware added
# LATER runs EARLIER (outermost). SessionMiddleware is therefore added AFTER
# _AuthGateMiddleware, so that request.session is decoded before AuthGate reads it.
app.add_middleware(
    SessionMiddleware,
    secret_key=auth.get_session_secret(),
    session_cookie="burnin_session",
    max_age=settings.session_max_age_seconds,
    https_only=False,  # we sit behind nginx-proxy-manager; trust upstream
    # SameSite=strict is the primary CSRF mitigation: the browser never
    # sends the session cookie on cross-site requests, so an attacker
    # page can't trigger any state-changing endpoint even if it knows
    # the URL. Trade-off: an external link (email, chat) into the app
    # won't carry the session — user has to re-auth via /login. For an
    # internal-only tool that's the right default.
    same_site="strict",
)

# Added last, so outermost: disallowed IPs are rejected before any session
# decoding or user lookup happens.
if settings.allowed_ips:
    app.add_middleware(_IPAllowlistMiddleware, allowed_ips=settings.allowed_ips)
    log.info("IP allowlist active: %s", settings.allowed_ips)

app.mount("/static", StaticFiles(directory="app/static"), name="static")
app.include_router(router)