From aa7822d6cebacbdb03f626bc5cefbe6a87f764fc Mon Sep 17 00:00:00 2001 From: Brandon Walter <51866976+echoparkbaby@users.noreply.github.com> Date: Sun, 3 May 2026 09:29:53 -0400 Subject: [PATCH] feat: rate limiter + mypy + lifecycle tests + routes/ split (1.0.0-33/-34) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the four remaining items from the post-Codex hardening list. #1 Rate-limit unlock + change-password endpoints (1.0.0-33) * Generalised the existing login limiter into a reusable `_RateLimiter` class in app/auth.py. Atomic check-then-increment in synchronous code so a parallel asyncio burst can't slip past the threshold. * `unlock_limiter` (5 attempts in 10 min → 10 min lockout) gates POST /api/v1/drives/{id}/unlock per-drive AND per-source-IP. * `pwchange_limiter` (5 in 10 min → 15 min lockout) gates POST /api/v1/auth/change-password per-user AND per-IP. * Both clear on successful operation. The login limiter keeps its existing `register_login_attempt` / `clear_login_failures` facade names so external callers don't change. #3 mypy in security-scan (1.0.0-33) * Added a 4th tool to the daily scan + forge workflow. Runs in a throwaway python:3.12-slim container against the deploy dir, exit code is informational only (NOT included in the `TOTAL_EXIT` failure sum). Findings land in ~/security-scans/scan-YYYY-MM-DD/mypy.txt for ratchet-down work over time. * Forge job uses `continue-on-error: true` so it doesn't fail the workflow until the type-debt baseline is annotated down. #4 Lifecycle test coverage (1.0.0-33) * New tests/test_lifecycle.py with 15 cases: - TestCommonHelpers (7 tests): _start_stage, _finish_stage success/failure/error-preservation, _recalculate_progress weighted math, _is_cancelled, _append_stage_log. - TestStartCancelJob (4 tests): start_job inserts queued row + correct stage list, duplicate-active rejection, cancel marks state, cancel returns False on terminal-state jobs. - TestRateLimiter (4 tests): under-threshold ok, trips at threshold, clear removes both counter + lockout, separate keys don't interfere. * Total goes from 44 to 59 tests; closes the orchestration-path coverage gap Codex flagged. #2 Partial routes.py split (1.0.0-34) * routes.py → routes/ package. Same staged-extraction pattern as the burnin.py split. * routes/auth.py — login/logout/setup/change-password (170 LoC). * routes/system.py — /health, /ws/terminal, /api/v1/updates/check (136 LoC). * routes/_helpers.py — shared utilities used by both extracted modules and the still-monolithic remainder: client_ip, operator_for, is_stale, stale_context, secret_status, SECRET_FIELDS (97 LoC). * routes/__init__.py shrank from 1568 LoC to 1261. Future slices can extract drives, burnin, history, settings the same way. * GOTCHA recorded in commit body: `from app import auth` at the top of __init__.py binds `auth` as an attribute on the package namespace, so `from . import auth as _auth_routes` finds the OUTER module and yields `app.auth` instead of the submodule. Fix is `import app.routes.auth as _auth_routes` (absolute). This bit me once at deploy time; container failed to start with `module 'app.auth' has no attribute 'router'`. Verification: 59/59 tests pass (44 existing + 15 new); container boots clean at 1.0.0-34; /health 200 with all checks green; security scan still clean (mypy informational findings ignored from totals). Co-Authored-By: Claude Opus 4.7 (1M context) --- .forgejo/workflows/security-scan.yml | 15 ++ app/auth.py | 156 ++++++----- app/config.py | 2 +- app/{routes.py => routes/__init__.py} | 368 ++++---------------------- app/routes/_helpers.py | 97 +++++++ app/routes/auth.py | 170 ++++++++++++ app/routes/system.py | 136 ++++++++++ scripts/security-scan.sh | 14 + tests/test_lifecycle.py | 328 +++++++++++++++++++++++ 9 files changed, 895 insertions(+), 391 deletions(-) rename app/{routes.py => routes/__init__.py} (76%) create mode 100644 app/routes/_helpers.py create mode 100644 app/routes/auth.py create mode 100644 app/routes/system.py create mode 100644 tests/test_lifecycle.py diff --git a/.forgejo/workflows/security-scan.yml b/.forgejo/workflows/security-scan.yml index 2052fcc..53fe483 100644 --- a/.forgejo/workflows/security-scan.yml +++ b/.forgejo/workflows/security-scan.yml @@ -59,3 +59,18 @@ jobs: chmod +x gitleaks - name: Scan git history for secrets run: ./gitleaks detect --source . --no-banner --redact --verbose + + mypy: + runs-on: ubuntu-latest + # Informational — does not fail the workflow. Use `continue-on-error` + # so the build stays green while we work down the type-debt baseline. + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install mypy + run: pip install --upgrade mypy + - name: Type check + run: mypy --ignore-missing-imports --no-strict-optional app diff --git a/app/auth.py b/app/auth.py index 5c70d8c..0cee009 100644 --- a/app/auth.py +++ b/app/auth.py @@ -213,89 +213,111 @@ async def change_password(user_id: int, current_password: str, # --------------------------------------------------------------------------- -# Login rate limiting (in-memory, per-username + per-source-IP) +# Generic rate limiting (in-memory, multi-key per category) # --------------------------------------------------------------------------- +# +# Each instance is a self-contained limiter for one category (login, +# unlock, password change). The atomicity guarantee is "no awaits between +# check and increment" — CPython's asyncio loop is single-threaded so +# concurrent requests cannot interleave the synchronous register() call. import time as _time -LOGIN_FAILURE_WINDOW_SECONDS = 600 # 10 min -LOGIN_FAILURE_THRESHOLD = 10 # this many failures within the window -LOGIN_LOCKOUT_SECONDS = 900 # then block for 15 min -# {(key,): [(timestamp, ...), ...]} key = (kind, value), kind in {"user","ip"} -_login_failures: dict = {} -_login_lockouts: dict = {} # key -> unix expiry +class _RateLimiter: + def __init__(self, name: str, threshold: int, window_s: int, lockout_s: int): + self.name = name + self.threshold = threshold + self.window_s = window_s + self.lockout_s = lockout_s + self._failures: dict = {} # key -> [unix timestamps within window] + self._lockouts: dict = {} # key -> unix expiry + + def _gc(self, key) -> None: + cutoff = _time.time() - self.window_s + arr = self._failures.get(key, []) + fresh = [t for t in arr if t >= cutoff] + if fresh: + self._failures[key] = fresh + elif key in self._failures: + del self._failures[key] + + def locked_until(self, *keys) -> float | None: + """Soonest active lockout expiry across `keys`, or None.""" + now = _time.time() + soonest = None + for k in keys: + exp = self._lockouts.get(k) + if exp is None: + continue + if now >= exp: + del self._lockouts[k] + continue + soonest = exp if soonest is None else min(soonest, exp) + return soonest + + def register(self, *keys) -> str: + """Returns "ok" | "locked_out" | "now_locked_out".""" + now = _time.time() + for k in keys: + exp = self._lockouts.get(k) + if exp is None: + continue + if now >= exp: + del self._lockouts[k] + continue + return "locked_out" + tripped = False + for k in keys: + self._gc(k) + self._failures.setdefault(k, []).append(now) + if len(self._failures[k]) >= self.threshold: + self._lockouts[k] = now + self.lockout_s + self._failures[k] = [] + tripped = True + return "now_locked_out" if tripped else "ok" + + def clear(self, *keys) -> None: + for k in keys: + self._failures.pop(k, None) + self._lockouts.pop(k, None) -def _gc_failures(key) -> None: - """Drop failure timestamps older than the window.""" - arr = _login_failures.get(key, []) - cutoff = _time.time() - LOGIN_FAILURE_WINDOW_SECONDS - fresh = [t for t in arr if t >= cutoff] - if fresh: - _login_failures[key] = fresh - elif key in _login_failures: - del _login_failures[key] +# Login: 10 failures in 10 min → 15 min lockout. +LOGIN_FAILURE_WINDOW_SECONDS = 600 +LOGIN_FAILURE_THRESHOLD = 10 +LOGIN_LOCKOUT_SECONDS = 900 + +# Unlock + password change: tighter caps; both are post-auth so a +# legitimate operator typoing a token shouldn't be locked out for long. +UNLOCK_FAILURE_THRESHOLD = 5 +UNLOCK_LOCKOUT_SECONDS = 600 +PWCHANGE_FAILURE_THRESHOLD = 5 +PWCHANGE_LOCKOUT_SECONDS = 900 + +login_limiter = _RateLimiter( + "login", LOGIN_FAILURE_THRESHOLD, LOGIN_FAILURE_WINDOW_SECONDS, + LOGIN_LOCKOUT_SECONDS, +) +unlock_limiter = _RateLimiter( + "unlock", UNLOCK_FAILURE_THRESHOLD, 600, UNLOCK_LOCKOUT_SECONDS, +) +pwchange_limiter = _RateLimiter( + "pwchange", PWCHANGE_FAILURE_THRESHOLD, 600, PWCHANGE_LOCKOUT_SECONDS, +) +# Backward-compat facades — preserve the names existing routes.py reaches for. def login_locked_until(username: str, ip: str) -> float | None: - """Returns the lockout expiry (unix ts) if either dimension is locked, - else None. Lazily reaps expired lockouts.""" - now = _time.time() - soonest = None - for key in (("user", username.lower()), ("ip", ip)): - exp = _login_lockouts.get(key) - if exp is None: - continue - if now >= exp: - del _login_lockouts[key] - continue - soonest = exp if soonest is None else min(soonest, exp) - return soonest + return login_limiter.locked_until(("user", username.lower()), ("ip", ip)) def register_login_attempt(username: str, ip: str) -> str: - """Atomic check-then-increment for a login attempt. - - Returns: - "ok" — allowed, counter incremented - "locked_out" — already locked from a prior attempt - "now_locked_out" — THIS attempt is what tripped the lockout - - The increment runs synchronously (no awaits) so concurrent requests - can't slip past the threshold in CPython's single-threaded asyncio - loop. Caller must invoke clear_login_failures() on successful auth - to roll back this attempt's contribution. - """ - now = _time.time() - # Check existing lockouts first; if already locked, don't even - # increment — the lockout window absorbs everything. - for key in (("user", username.lower()), ("ip", ip)): - exp = _login_lockouts.get(key) - if exp is None: - continue - if now >= exp: - del _login_lockouts[key] - continue - return "locked_out" - # Increment + arm lockout if this push crosses the threshold. - tripped = False - for key in (("user", username.lower()), ("ip", ip)): - _gc_failures(key) - _login_failures.setdefault(key, []).append(now) - if len(_login_failures[key]) >= LOGIN_FAILURE_THRESHOLD: - _login_lockouts[key] = now + LOGIN_LOCKOUT_SECONDS - _login_failures[key] = [] - tripped = True - return "now_locked_out" if tripped else "ok" + return login_limiter.register(("user", username.lower()), ("ip", ip)) def clear_login_failures(username: str, ip: str) -> None: - """Erase counters AND any lockout for a successful auth — caller - proved they have credentials, so the brute-force ladder resets.""" - for key in (("user", username.lower()), ("ip", ip)): - _login_failures.pop(key, None) - _login_lockouts.pop(key, None) + login_limiter.clear(("user", username.lower()), ("ip", ip)) # --------------------------------------------------------------------------- diff --git a/app/config.py b/app/config.py index e6fc657..b485318 100644 --- a/app/config.py +++ b/app/config.py @@ -83,7 +83,7 @@ class Settings(BaseSettings): ssh_key: str = "" # PEM private key content (paste full key including headers) # Application version — used by the /api/v1/updates/check endpoint - app_version: str = "1.0.0-32" + app_version: str = "1.0.0-34" # ---- Authentication (1.0.0-22) ---- # session_secret: HMAC key for signing session cookies. Empty = generate diff --git a/app/routes.py b/app/routes/__init__.py similarity index 76% rename from app/routes.py rename to app/routes/__init__.py index 7a0e427..f2c38da 100644 --- a/app/routes.py +++ b/app/routes/__init__.py @@ -2,12 +2,11 @@ import asyncio import csv import io import json -import time as _time from datetime import datetime, timezone import aiosqlite -from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket -from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi.responses import HTMLResponse, StreamingResponse from sse_starlette.sse import EventSourceResponse from app import auth, burnin, mailer, poller, settings_store @@ -21,8 +20,35 @@ from app.models import ( ) from app.renderer import templates +# Helpers shared with the extracted sub-routers — keep the underscore- +# prefixed local names that existing in-file callers reach for. +from ._helpers import ( + client_ip as _client_ip, + is_stale as _is_stale, + operator_for as _operator_for, + secret_status as _secret_status, + stale_context as _stale_context, + SECRET_FIELDS as _SECRET_FIELDS, +) + router = APIRouter() +# Sub-routers extracted as part of the routes/ package split (1.0.0-34). +# Their endpoints get registered against the same APIRouter, so the +# external `from app.routes import router` import in app/main.py keeps +# working unchanged. Future slices can extract more — drives, burnin, +# settings, history — using the same pattern. +# +# Absolute imports (vs `from . import auth`) because the line-12 +# `from app import auth` binds `auth` as an attribute on this package's +# namespace, which would shadow the relative-submodule lookup and yield +# `app.auth` instead of `app.routes.auth`. +import app.routes.auth as _auth_routes # noqa: E402 +import app.routes.system as _system_routes # noqa: E402 + +router.include_router(_auth_routes.router) +router.include_router(_system_routes.router) + # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- @@ -40,14 +66,7 @@ def _eta_seconds(eta_at: str | None) -> int | None: return None -def _is_stale(last_polled_at: str) -> bool: - try: - last = datetime.fromisoformat(last_polled_at) - if last.tzinfo is None: - last = last.replace(tzinfo=timezone.utc) - return (datetime.now(timezone.utc) - last).total_seconds() > settings.stale_threshold_seconds - except Exception: - return True +# _is_stale is now imported from ._helpers above. def _compute_eta_seconds(started_at: str | None, percent: int) -> int | None: @@ -219,162 +238,9 @@ async def _fetch_drives_for_template(db: aiosqlite.Connection) -> list[dict]: return drives -def _stale_context(poller_state: dict) -> dict: - last = poller_state.get("last_poll_at") - if not last: - return {"stale": False, "stale_seconds": 0} - try: - dt = datetime.fromisoformat(last) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - elapsed = int((datetime.now(timezone.utc) - dt).total_seconds()) - stale = elapsed > settings.stale_threshold_seconds - return {"stale": stale, "stale_seconds": elapsed} - except Exception: - return {"stale": False, "stale_seconds": 0} +# _stale_context is now imported from ._helpers above. -# --------------------------------------------------------------------------- -# Auth — login / logout / first-user setup -# --------------------------------------------------------------------------- - -@router.get("/login", response_class=HTMLResponse) -async def login_page(request: Request, next: str = "/", error: str | None = None): - needs_setup = (await auth.user_count()) == 0 - return templates.TemplateResponse(request, "login.html", { - "request": request, - "needs_setup": needs_setup, - "error": error, - "next": next if next.startswith("/") else "/", - }) - - -def _client_ip(request: Request) -> str: - fwd = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip() - return fwd or (request.client.host if request.client else "unknown") - - -@router.post("/login") -async def login_submit(request: Request): - form = await request.form() - username = (form.get("username") or "").strip() - password = form.get("password") or "" - next_url = form.get("next") or "/" - if not next_url.startswith("/"): - next_url = "/" - ip = _client_ip(request) - - # Atomic register-and-check: increments the counter NOW (before any - # await), so a parallel burst of guesses can't all slip past the - # threshold. Cleared on successful auth via clear_login_failures. - attempt = auth.register_login_attempt(username, ip) - if attempt != "ok": - if attempt == "now_locked_out": - await auth.audit_auth_event( - "user_login_locked_out", username, - f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min", - ) - locked_until = auth.login_locked_until(username, ip) - remaining = int((locked_until or _time.time()) - _time.time()) - return templates.TemplateResponse(request, "login.html", { - "request": request, - "needs_setup": False, - "error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.", - "next": next_url, - }, status_code=429) - - found = await auth.get_user_by_username(username) - if not found or not auth.verify_password(password, found[1]): - # Constant-ish-time: still call verify on a junk hash if user missing - # so the timing of "user not found" matches "wrong password." - if not found: - auth.verify_password(password, "$2b$12$" + "x" * 53) - await auth.audit_auth_event( - "user_login_failed", username, f"Failed login from {ip}", - ) - return templates.TemplateResponse(request, "login.html", { - "request": request, - "needs_setup": False, - "error": "Invalid username or password.", - "next": next_url, - }, status_code=401) - - user = found[0] - auth.clear_login_failures(username, ip) - # Clear any pre-login session keys before populating the new identity. - # Closes session-fixation: if an attacker had somehow seeded the - # browser with a session cookie, this discards everything in it - # before issuing the new authenticated payload. - request.session.clear() - request.session["user_id"] = user.id - request.session["username"] = user.username - await auth.touch_last_login(user.id) - await auth.audit_auth_event( - "user_login", user.username, f"Signed in from {ip}" - ) - return RedirectResponse(url=next_url, status_code=303) - - -@router.post("/api/v1/auth/setup") -async def auth_first_user_setup(request: Request): - """Create the first admin from the login page when the users table is - empty. Public endpoint — but only does anything when zero users exist.""" - if (await auth.user_count()) > 0: - raise HTTPException(status_code=409, detail="Users already exist.") - form = await request.form() - username = (form.get("username") or "").strip() - password = form.get("password") or "" - full_name = (form.get("full_name") or "").strip() or None - try: - # bootstrap_only=True wraps the existence check + insert in an - # IMMEDIATE transaction so two concurrent setup requests can't - # both create admin accounts during the bootstrap window. - user = await auth.create_user( - username, password, full_name, is_admin=True, bootstrap_only=True - ) - except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) - # Same fixation defense as the login flow — discard any pre-existing - # session payload before issuing the authenticated identity. - request.session.clear() - request.session["user_id"] = user.id - request.session["username"] = user.username - await auth.touch_last_login(user.id) - return RedirectResponse(url="/", status_code=303) - - -@router.get("/logout") -@router.post("/logout") -async def logout(request: Request): - user = request.state.current_user if hasattr(request.state, "current_user") else None - if user: - await auth.audit_auth_event( - "user_logout", user.username, f"Signed out from {_client_ip(request)}" - ) - request.session.clear() - return RedirectResponse(url="/login", status_code=303) - - -@router.post("/api/v1/auth/change-password") -async def change_password(request: Request): - user = request.state.current_user if hasattr(request.state, "current_user") else None - if not user: - raise HTTPException(status_code=401, detail="Authentication required") - form = await request.form() - current = form.get("current_password") or "" - new_pw = form.get("new_password") or "" - confirm = form.get("confirm_password") or "" - if new_pw != confirm: - raise HTTPException(status_code=400, detail="New passwords do not match.") - try: - await auth.change_password(user.id, current, new_pw) - except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) - await auth.audit_auth_event( - "user_password_changed", user.username, - f"Password changed from {_client_ip(request)}", - ) - return {"ok": True} # --------------------------------------------------------------------------- @@ -459,83 +325,6 @@ async def sse_drives(request: Request): # JSON API # --------------------------------------------------------------------------- -@router.get("/health") -async def health(db: aiosqlite.Connection = Depends(get_db)): - """Real readiness check, not just process-is-running. - - Verifies (a) DB writable, (b) poller has succeeded recently relative - to the configured stale_threshold_seconds, (c) SSH reachable when - configured. Returns 503 when any check fails so a proxy/orchestrator - health probe can take the container out of rotation. - """ - from datetime import datetime, timezone - from fastapi.responses import JSONResponse - from app import ssh_client as _ssh - - checks: dict[str, dict] = {} - - # DB probe — actually exercise the write path (read-only mounts, - # full disks, broken WAL all silently pass a journal_mode read). - # Uses a temp table that lives only inside the connection so the - # round-trip touches the writer without polluting real data. - try: - await db.execute( - "CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)" - ) - await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)", - (datetime.now(timezone.utc).isoformat(),)) - cur = await db.execute("SELECT v FROM _hc WHERE k=1") - row = await cur.fetchone() - await db.commit() - checks["db"] = {"ok": bool(row)} - except Exception as exc: - checks["db"] = {"ok": False, "error": str(exc)} - - ps = poller.get_state() - last = ps.get("last_poll_at") - poll_age = None - if last: - try: - t = datetime.fromisoformat(last) - if t.tzinfo is None: - t = t.replace(tzinfo=timezone.utc) - poll_age = (datetime.now(timezone.utc) - t).total_seconds() - except Exception: - poll_age = None - poll_ok = ps.get("healthy") and ( - poll_age is None or poll_age <= settings.stale_threshold_seconds * 3 - ) - checks["poller"] = { - "ok": bool(poll_ok), - "last_error": ps.get("last_error"), - "last_poll_at": last, - "age_seconds": int(poll_age) if poll_age is not None else None, - } - - # SSH probe — only when configured. Cheap (single sensors -j). - if _ssh.is_configured(): - try: - r = await _ssh.test_connection() - checks["ssh"] = {"ok": bool(r.get("ok")), - "error": r.get("error")} - except Exception as exc: - checks["ssh"] = {"ok": False, "error": str(exc)} - else: - checks["ssh"] = {"ok": True, "skipped": True} - - cur = await db.execute("SELECT COUNT(*) FROM drives") - row = await cur.fetchone() - drives_tracked = row[0] if row else 0 - - status_ok = all(c["ok"] for c in checks.values()) - body = { - "status": "ok" if status_ok else "degraded", - "checks": checks, - "drives_tracked": drives_tracked, - "poll_interval_s": settings.poll_interval_seconds, - "version": settings.app_version, - } - return JSONResponse(body, status_code=200 if status_ok else 503) @router.get("/api/v1/drives", response_model=list[DriveResponse]) @@ -797,14 +586,7 @@ def _row_to_burnin(row: aiosqlite.Row, stages: list[aiosqlite.Row]) -> BurninJob ) -def _operator_for(request: Request, _ignored_body_value: str | None = None) -> str: - """Always return the logged-in user's name for audit attribution. - The request body's `operator` field (if any) is ignored — clients - can't spoof the operator identity any more.""" - user = getattr(request.state, "current_user", None) - if not user: - raise HTTPException(status_code=401, detail="Authentication required") - return user.full_name or user.username +# _operator_for is now imported from ._helpers above. @router.post("/api/v1/burnin/start") @@ -838,12 +620,24 @@ async def burnin_start(request: Request, req: StartBurninRequest): @router.post("/api/v1/drives/{drive_id}/unlock") async def unlock_pool_drive(drive_id: int, request: Request, req: UnlockPoolDriveRequest): operator = _operator_for(request, req.operator) + ip = _client_ip(request) + # Rate-limit by drive AND by source IP. A typo on the confirm token + # is the common case so the threshold is loose, but a brute-force + # attempt to guess the token still hits the IP cap. + keys = (("drive", drive_id), ("ip", ip)) + attempt = auth.unlock_limiter.register(*keys) + if attempt != "ok": + raise HTTPException( + status_code=429, + detail="Too many unlock attempts on this drive. Try again later.", + ) try: expiry = await burnin.grant_pool_unlock( drive_id, req.confirm_token, operator, req.reason, ) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) + auth.unlock_limiter.clear(*keys) return {"unlocked": True, "expires_at": expiry, # Read from the submodule, not the package-root snapshot # alias — keeps tests that monkey-patch UNLOCK_TTL_SECONDS @@ -1333,42 +1127,7 @@ async def settings_page( }) -# Field names that hold secrets and must never be rendered to the UI or -# included in the redacted-settings dump verbatim. -_SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key") - - -def _secret_status() -> dict[str, str]: - """Per-secret display string for the settings page so the operator can - see whether each secret is configured (and how) without ever rendering - the value. Distinguishes env-var, mounted-file, and DB-stored sources - for ssh_key — the others can only come from the live settings object.""" - import os as _os - from app.ssh_client import _MOUNTED_KEY_PATH - - def _has(field: str) -> bool: - v = getattr(settings, field, "") - return bool(v) - - # ssh_key gets the most granular treatment because we actively prefer - # the mounted file path in production but the textarea is still wired. - if _os.environ.get("SSH_KEY"): - ssh_key_status = "set (environment variable)" - elif _has("ssh_key"): - ssh_key_status = "set (stored in settings DB — prefer a mounted secret in production)" - elif _os.path.exists( - _os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH) - ): - ssh_key_status = "set (mounted secret)" - else: - ssh_key_status = "unset" - - return { - "smtp_password": "set" if _has("smtp_password") else "unset", - "ssh_password": "set" if _has("ssh_password") else "unset", - "ssh_key": ssh_key_status, - "truenas_api_key": "set" if _has("truenas_api_key") else "unset", - } +# _SECRET_FIELDS and _secret_status are now imported from ._helpers above. @router.get("/api/v1/settings/redacted") @@ -1445,43 +1204,6 @@ async def test_ssh(request: Request): return {"ok": True} -@router.websocket("/ws/terminal") -async def terminal_ws(websocket: WebSocket): - """WebSocket endpoint bridging the browser xterm.js terminal to an SSH PTY.""" - from app import terminal as _term - await _term.handle(websocket) - - -@router.get("/api/v1/updates/check") -async def check_updates(): - """Check for a newer release on Forgejo.""" - import httpx - current = settings.app_version - try: - async with httpx.AsyncClient(timeout=8.0) as client: - r = await client.get( - "https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest", - headers={"Accept": "application/json"}, - ) - if r.status_code == 200: - data = r.json() - latest = data.get("tag_name", "").lstrip("v") - up_to_date = not latest or latest == current - return { - "current": current, - "latest": latest or None, - "update_available": not up_to_date, - "message": None, - } - elif r.status_code == 404: - return {"current": current, "latest": None, "update_available": False, - "message": "No releases published yet"} - else: - return {"current": current, "latest": None, "update_available": False, - "message": f"Forgejo API returned {r.status_code}"} - except Exception as exc: - return {"current": current, "latest": None, "update_available": False, - "message": f"Could not reach update server: {exc}"} # --------------------------------------------------------------------------- diff --git a/app/routes/_helpers.py b/app/routes/_helpers.py new file mode 100644 index 0000000..08787be --- /dev/null +++ b/app/routes/_helpers.py @@ -0,0 +1,97 @@ +"""Shared helpers used across multiple route modules. + +Anything more than one route file needs lives here. Single-use helpers +stay in their owning route module. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from fastapi import HTTPException, Request + +from app.config import settings + + +def client_ip(request: Request) -> str: + """Best-effort source IP. Trusts X-Forwarded-For when present (we + sit behind nginx-proxy-manager) but falls back to the direct peer.""" + fwd = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip() + return fwd or (request.client.host if request.client else "unknown") + + +def operator_for(request: Request, _ignored_body_value: str | None = None) -> str: + """Always return the logged-in user's name for audit attribution. + The request body's `operator` field (if any) is ignored — clients + can't spoof the operator identity any more.""" + user = getattr(request.state, "current_user", None) + if not user: + raise HTTPException(status_code=401, detail="Authentication required") + return user.full_name or user.username + + +def is_stale(last_polled_at: str) -> bool: + """True if the most recent poll is older than the stale threshold.""" + try: + last = datetime.fromisoformat(last_polled_at) + if last.tzinfo is None: + last = last.replace(tzinfo=timezone.utc) + return (datetime.now(timezone.utc) - last).total_seconds() > settings.stale_threshold_seconds + except Exception: + return True + + +def stale_context(ps: dict) -> dict: + """Returns the {stale, stale_seconds} dict every HTML page passes + to the layout for the warning banner.""" + last = ps.get("last_poll_at") + if not last: + return {"stale": False, "stale_seconds": 0} + try: + t = datetime.fromisoformat(last) + if t.tzinfo is None: + t = t.replace(tzinfo=timezone.utc) + age = (datetime.now(timezone.utc) - t).total_seconds() + return { + "stale": age > settings.stale_threshold_seconds, + "stale_seconds": int(age), + } + except Exception: + return {"stale": False, "stale_seconds": 0} + + +# Field names that hold secrets and must never be rendered to the UI +# verbatim or included in the redacted-settings dump. +SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key") + + +def secret_status() -> dict[str, str]: + """Per-secret display string for the settings page so the operator + can see whether each secret is configured (and how) without ever + rendering the value. Distinguishes env-var, mounted-file, and + DB-stored sources for ssh_key — the others can only come from the + live settings object.""" + import os as _os + from app.ssh_client import _MOUNTED_KEY_PATH + + def _has(field: str) -> bool: + v = getattr(settings, field, "") + return bool(v) + + if _os.environ.get("SSH_KEY"): + ssh_key_status = "set (environment variable)" + elif _has("ssh_key"): + ssh_key_status = "set (stored in settings DB — prefer a mounted secret in production)" + elif _os.path.exists( + _os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH) + ): + ssh_key_status = "set (mounted secret)" + else: + ssh_key_status = "unset" + + return { + "smtp_password": "set" if _has("smtp_password") else "unset", + "ssh_password": "set" if _has("ssh_password") else "unset", + "ssh_key": ssh_key_status, + "truenas_api_key": "set" if _has("truenas_api_key") else "unset", + } diff --git a/app/routes/auth.py b/app/routes/auth.py new file mode 100644 index 0000000..8a4d6c6 --- /dev/null +++ b/app/routes/auth.py @@ -0,0 +1,170 @@ +"""Login / logout / first-user setup / password change routes. + +Public path mounting: + GET /login — render login or first-user setup form + POST /login — credential check + session bootstrap + POST /api/v1/auth/setup — first-user creation (only when zero users) + GET /logout — clear session, redirect + POST /logout — same, for explicit POST clients + POST /api/v1/auth/change-password — rotate password + audit +""" + +from __future__ import annotations + +import time as _time + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse, RedirectResponse + +from app import auth +from app.renderer import templates + +from ._helpers import client_ip + +router = APIRouter() + + +@router.get("/login", response_class=HTMLResponse) +async def login_page(request: Request, next: str = "/", error: str | None = None): + needs_setup = (await auth.user_count()) == 0 + return templates.TemplateResponse(request, "login.html", { + "request": request, + "needs_setup": needs_setup, + "error": error, + "next": next if next.startswith("/") else "/", + }) + + +@router.post("/login") +async def login_submit(request: Request): + form = await request.form() + username = (form.get("username") or "").strip() + password = form.get("password") or "" + next_url = form.get("next") or "/" + if not next_url.startswith("/"): + next_url = "/" + ip = client_ip(request) + + # Atomic register-and-check: increments the counter NOW (before any + # await), so a parallel burst of guesses can't all slip past the + # threshold. Cleared on successful auth via clear_login_failures. + attempt = auth.register_login_attempt(username, ip) + if attempt != "ok": + if attempt == "now_locked_out": + await auth.audit_auth_event( + "user_login_locked_out", username, + f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min", + ) + locked_until = auth.login_locked_until(username, ip) + remaining = int((locked_until or _time.time()) - _time.time()) + return templates.TemplateResponse(request, "login.html", { + "request": request, + "needs_setup": False, + "error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.", + "next": next_url, + }, status_code=429) + + found = await auth.get_user_by_username(username) + if not found or not auth.verify_password(password, found[1]): + # Constant-ish-time: still call verify on a junk hash if user missing + # so the timing of "user not found" matches "wrong password." + if not found: + auth.verify_password(password, "$2b$12$" + "x" * 53) + await auth.audit_auth_event( + "user_login_failed", username, f"Failed login from {ip}", + ) + return templates.TemplateResponse(request, "login.html", { + "request": request, + "needs_setup": False, + "error": "Invalid username or password.", + "next": next_url, + }, status_code=401) + + user = found[0] + auth.clear_login_failures(username, ip) + # Clear any pre-login session keys before populating the new identity. + # Closes session-fixation: if an attacker had somehow seeded the + # browser with a session cookie, this discards everything in it + # before issuing the new authenticated payload. + request.session.clear() + request.session["user_id"] = user.id + request.session["username"] = user.username + await auth.touch_last_login(user.id) + await auth.audit_auth_event( + "user_login", user.username, f"Signed in from {ip}", + ) + return RedirectResponse(url=next_url, status_code=303) + + +@router.post("/api/v1/auth/setup") +async def auth_first_user_setup(request: Request): + """Create the first admin from the login page when the users table is + empty. Public endpoint — but only does anything when zero users exist.""" + if (await auth.user_count()) > 0: + raise HTTPException(status_code=409, detail="Users already exist.") + form = await request.form() + username = (form.get("username") or "").strip() + password = form.get("password") or "" + full_name = (form.get("full_name") or "").strip() or None + try: + # bootstrap_only=True wraps the existence check + insert in an + # IMMEDIATE transaction so two concurrent setup requests can't + # both create admin accounts during the bootstrap window. + user = await auth.create_user( + username, password, full_name, is_admin=True, bootstrap_only=True + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + # Same fixation defense as the login flow — discard any pre-existing + # session payload before issuing the authenticated identity. + request.session.clear() + request.session["user_id"] = user.id + request.session["username"] = user.username + await auth.touch_last_login(user.id) + return RedirectResponse(url="/", status_code=303) + + +@router.get("/logout") +@router.post("/logout") +async def logout(request: Request): + user = request.state.current_user if hasattr(request.state, "current_user") else None + if user: + await auth.audit_auth_event( + "user_logout", user.username, f"Signed out from {client_ip(request)}", + ) + request.session.clear() + return RedirectResponse(url="/login", status_code=303) + + +@router.post("/api/v1/auth/change-password") +async def change_password(request: Request): + user = request.state.current_user if hasattr(request.state, "current_user") else None + if not user: + raise HTTPException(status_code=401, detail="Authentication required") + ip = client_ip(request) + # Rate-limit before bcrypt to keep an attacker-controlled session + # from burning CPU brute-forcing the current_password field. + keys = (("user", user.username.lower()), ("ip", ip)) + attempt = auth.pwchange_limiter.register(*keys) + if attempt != "ok": + raise HTTPException( + status_code=429, + detail="Too many password-change attempts. Try again later.", + ) + + form = await request.form() + current = form.get("current_password") or "" + new_pw = form.get("new_password") or "" + confirm = form.get("confirm_password") or "" + if new_pw != confirm: + raise HTTPException(status_code=400, detail="New passwords do not match.") + try: + await auth.change_password(user.id, current, new_pw) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + auth.pwchange_limiter.clear(*keys) + await auth.audit_auth_event( + "user_password_changed", user.username, + f"Password changed from {ip}", + ) + return {"ok": True} diff --git a/app/routes/system.py b/app/routes/system.py new file mode 100644 index 0000000..f3d3c0b --- /dev/null +++ b/app/routes/system.py @@ -0,0 +1,136 @@ +"""System-level endpoints with no business-logic dependencies. + + GET /health — readiness probe (DB write + poller + SSH) + GET /api/v1/updates/check — check Forgejo for newer release + WS /ws/terminal — xterm.js bridge to TrueNAS SSH PTY +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import aiosqlite +from fastapi import APIRouter, Depends, WebSocket +from fastapi.responses import JSONResponse + +from app import poller +from app.config import settings +from app.database import get_db + +router = APIRouter() + + +@router.get("/health") +async def health(db: aiosqlite.Connection = Depends(get_db)): + """Real readiness check, not just process-is-running. + + Verifies (a) DB writable, (b) poller has succeeded recently relative + to the configured stale_threshold_seconds, (c) SSH reachable when + configured. Returns 503 when any check fails so a proxy/orchestrator + health probe can take the container out of rotation. + """ + from app import ssh_client as _ssh + + checks: dict[str, dict] = {} + + # DB probe — actually exercise the write path (read-only mounts, + # full disks, broken WAL all silently pass a journal_mode read). + # Uses a temp table that lives only inside the connection so the + # round-trip touches the writer without polluting real data. + try: + await db.execute( + "CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)" + ) + await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)", + (datetime.now(timezone.utc).isoformat(),)) + cur = await db.execute("SELECT v FROM _hc WHERE k=1") + row = await cur.fetchone() + await db.commit() + checks["db"] = {"ok": bool(row)} + except Exception as exc: + checks["db"] = {"ok": False, "error": str(exc)} + + ps = poller.get_state() + last = ps.get("last_poll_at") + poll_age = None + if last: + try: + t = datetime.fromisoformat(last) + if t.tzinfo is None: + t = t.replace(tzinfo=timezone.utc) + poll_age = (datetime.now(timezone.utc) - t).total_seconds() + except Exception: + poll_age = None + poll_ok = ps.get("healthy") and ( + poll_age is None or poll_age <= settings.stale_threshold_seconds * 3 + ) + checks["poller"] = { + "ok": bool(poll_ok), + "last_error": ps.get("last_error"), + "last_poll_at": last, + "age_seconds": int(poll_age) if poll_age is not None else None, + } + + # SSH probe — only when configured. Cheap (single sensors -j). + if _ssh.is_configured(): + try: + r = await _ssh.test_connection() + checks["ssh"] = {"ok": bool(r.get("ok")), + "error": r.get("error")} + except Exception as exc: + checks["ssh"] = {"ok": False, "error": str(exc)} + else: + checks["ssh"] = {"ok": True, "skipped": True} + + cur = await db.execute("SELECT COUNT(*) FROM drives") + row = await cur.fetchone() + drives_tracked = row[0] if row else 0 + + status_ok = all(c["ok"] for c in checks.values()) + body = { + "status": "ok" if status_ok else "degraded", + "checks": checks, + "drives_tracked": drives_tracked, + "poll_interval_s": settings.poll_interval_seconds, + "version": settings.app_version, + } + return JSONResponse(body, status_code=200 if status_ok else 503) + + +@router.websocket("/ws/terminal") +async def terminal_ws(websocket: WebSocket): + """WebSocket endpoint bridging the browser xterm.js terminal to an SSH PTY.""" + from app import terminal as _term + await _term.handle(websocket) + + +@router.get("/api/v1/updates/check") +async def check_updates(): + """Check for a newer release on Forgejo.""" + import httpx + current = settings.app_version + try: + async with httpx.AsyncClient(timeout=8.0) as client: + r = await client.get( + "https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest", + headers={"Accept": "application/json"}, + ) + if r.status_code == 200: + data = r.json() + latest = data.get("tag_name", "").lstrip("v") + up_to_date = not latest or latest == current + return { + "current": current, + "latest": latest or None, + "update_available": not up_to_date, + "message": None, + } + elif r.status_code == 404: + return {"current": current, "latest": None, "update_available": False, + "message": "No releases published yet"} + else: + return {"current": current, "latest": None, "update_available": False, + "message": f"Forgejo API returned {r.status_code}"} + except Exception as exc: + return {"current": current, "latest": None, "update_available": False, + "message": f"Could not reach update server: {exc}"} diff --git a/scripts/security-scan.sh b/scripts/security-scan.sh index 2185aab..52073ea 100644 --- a/scripts/security-scan.sh +++ b/scripts/security-scan.sh @@ -87,6 +87,20 @@ docker run --rm \ BANDITS=$? echo " exit=$BANDITS ($OUT_DIR/bandit.txt)" | tee -a "$OUT_DIR/summary.txt" +# --- mypy against the deploy dir (informational only) ------------------- +# Type checker — surfaces None-handling bugs and missing-attribute errors +# the runtime would have caught at the worst possible moment. Doesn't +# count toward the failure exit-code sum until the codebase is annotated +# enough to make findings actionable. +echo "--- mypy (informational) ---" | tee -a "$OUT_DIR/summary.txt" +docker run --rm \ + -v "$DEPLOY_DIR/app:/src:ro" \ + python:3.12-slim sh -c \ + "pip install --quiet --no-cache-dir --disable-pip-version-check mypy 2>&1 | tail -3 && mypy --ignore-missing-imports --no-strict-optional /src" \ + > "$OUT_DIR/mypy.txt" 2>&1 +MYPY=$? +echo " exit=$MYPY ($OUT_DIR/mypy.txt) — informational only" | tee -a "$OUT_DIR/summary.txt" + # --- gitleaks against the full git history ------------------------------ echo "--- gitleaks ---" | tee -a "$OUT_DIR/summary.txt" docker run --rm \ diff --git a/tests/test_lifecycle.py b/tests/test_lifecycle.py new file mode 100644 index 0000000..c915875 --- /dev/null +++ b/tests/test_lifecycle.py @@ -0,0 +1,328 @@ +"""Burn-in lifecycle tests covering the DB helpers in burnin._common, +plus the public surface of start_job + cancel_job that doesn't require +spinning up _run_job (which would need a mocked TrueNASClient + SSH). + +These are the safety net Codex flagged was missing — the orchestration +paths were entirely untested before this. Run inside the container +image so app deps (aiosqlite, pydantic-settings, bcrypt) are present. +""" + +from __future__ import annotations + +import os +import tempfile +import unittest + +import aiosqlite + + +async def _setup_temp_db() -> str: + """Same pattern as test_unlock_flow.py — temp DB + init_db, returning + the path. Caller must unlink in tearDown.""" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + from app.config import settings + settings.db_path = path + + from app.database import init_db + await init_db() + + # Seed two drives so start_job has something to attach to. + async with aiosqlite.connect(path) as db: + await db.execute(""" + INSERT INTO drives + (truenas_disk_id, devname, serial, model, size_bytes, + temperature_c, smart_health, last_seen_at, last_polled_at) + VALUES ('id-1', 'sda', 'SER1', 'TestModel', 1000, 30, 'PASSED', + '2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00') + """) + await db.execute(""" + INSERT INTO drives + (truenas_disk_id, devname, serial, model, size_bytes, + temperature_c, smart_health, last_seen_at, last_polled_at) + VALUES ('id-2', 'sdb', 'SER2', 'TestModel', 1000, 30, 'PASSED', + '2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00') + """) + await db.commit() + return path + + +class TestCommonHelpers(unittest.IsolatedAsyncioTestCase): + """The per-stage DB mutators in app.burnin._common — pure SQLite + writes, no asyncssh, no orchestration. Trivially regression-testable.""" + + async def asyncSetUp(self): + self.db_path = await _setup_temp_db() + # Insert a queued job + 2 stages we can mutate. + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + """INSERT INTO burnin_jobs + (drive_id, profile, state, percent, operator, created_at) + VALUES (?,?,?,?,?,?) RETURNING id""", + (1, "full", "running", 0, "test", "2026-05-03T00:00:00+00:00"), + ) + self.job_id = (await cur.fetchone())[0] + for stage_name in ("precheck", "surface_validate", "final_check"): + await db.execute( + "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)", + (self.job_id, stage_name, "pending"), + ) + await db.commit() + + async def asyncTearDown(self): + try: + os.unlink(self.db_path) + except OSError: + pass + + async def test_start_stage_marks_running(self): + from app.burnin import _common + await _common._start_stage(self.job_id, "precheck") + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cur = await db.execute( + "SELECT state, started_at FROM burnin_stages " + "WHERE burnin_job_id=? AND stage_name='precheck'", + (self.job_id,), + ) + row = await cur.fetchone() + self.assertEqual(row["state"], "running") + self.assertIsNotNone(row["started_at"]) + + async def test_finish_stage_success_records_duration(self): + from app.burnin import _common + await _common._start_stage(self.job_id, "precheck") + await _common._finish_stage(self.job_id, "precheck", success=True) + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cur = await db.execute( + "SELECT state, percent, duration_seconds FROM burnin_stages " + "WHERE burnin_job_id=? AND stage_name='precheck'", + (self.job_id,), + ) + row = await cur.fetchone() + self.assertEqual(row["state"], "passed") + self.assertEqual(row["percent"], 100) + # Duration is float seconds since started_at — should be tiny but >0. + self.assertIsNotNone(row["duration_seconds"]) + self.assertGreaterEqual(row["duration_seconds"], 0) + + async def test_finish_stage_failure_carries_error_text(self): + from app.burnin import _common + await _common._start_stage(self.job_id, "surface_validate") + await _common._finish_stage( + self.job_id, "surface_validate", + success=False, error_text="mock failure", + ) + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cur = await db.execute( + "SELECT state, percent, error_text FROM burnin_stages " + "WHERE burnin_job_id=? AND stage_name='surface_validate'", + (self.job_id,), + ) + row = await cur.fetchone() + self.assertEqual(row["state"], "failed") + self.assertIsNone(row["percent"]) + self.assertEqual(row["error_text"], "mock failure") + + async def test_finish_stage_preserves_existing_error(self): + """When called with error_text=None, the existing column value + from _set_stage_error must be preserved (not overwritten with NULL). + This is the bug that 1.0.0-12-ish fixed.""" + from app.burnin import _common + await _common._start_stage(self.job_id, "surface_validate") + await _common._set_stage_error( + self.job_id, "surface_validate", "set by stage", + ) + await _common._finish_stage( + self.job_id, "surface_validate", success=False, error_text=None, + ) + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + "SELECT error_text FROM burnin_stages " + "WHERE burnin_job_id=? AND stage_name='surface_validate'", + (self.job_id,), + ) + row = await cur.fetchone() + self.assertEqual(row[0], "set by stage") + + async def test_recalculate_progress_weights_correctly(self): + from app.burnin import _common + # Mark precheck passed, surface_validate at 50% running. + await _common._start_stage(self.job_id, "precheck") + await _common._finish_stage(self.job_id, "precheck", success=True) + await _common._start_stage(self.job_id, "surface_validate") + await _common._update_stage_percent(self.job_id, "surface_validate", 50) + await _common._recalculate_progress(self.job_id) + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cur = await db.execute( + "SELECT percent, stage_name FROM burnin_jobs WHERE id=?", + (self.job_id,), + ) + row = await cur.fetchone() + # Weights: precheck=5, surface=65, final=5. Total = 75 across these + # 3 stages. Completed = 5 (precheck) + 32.5 (half of 65) = 37.5. + # 37.5 / 75 = 50%. + self.assertEqual(row["percent"], 50) + self.assertEqual(row["stage_name"], "surface_validate") + + async def test_is_cancelled_reads_job_state(self): + from app.burnin import _common + self.assertFalse(await _common._is_cancelled(self.job_id)) + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "UPDATE burnin_jobs SET state='cancelled' WHERE id=?", + (self.job_id,), + ) + await db.commit() + self.assertTrue(await _common._is_cancelled(self.job_id)) + + async def test_append_stage_log_concatenates(self): + from app.burnin import _common + await _common._append_stage_log(self.job_id, "precheck", "alpha\n") + await _common._append_stage_log(self.job_id, "precheck", "beta\n") + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + "SELECT log_text FROM burnin_stages " + "WHERE burnin_job_id=? AND stage_name='precheck'", + (self.job_id,), + ) + row = await cur.fetchone() + self.assertEqual(row[0], "alpha\nbeta\n") + + +class TestStartCancelJob(unittest.IsolatedAsyncioTestCase): + """start_job + cancel_job touch the burnin orchestrator state. We + spawn _run_job tasks that try to acquire the semaphore — we cancel + immediately after to avoid running real burn-in stages. The real + test value here is "did start_job create the right DB rows" and + "does cancel_job mark them correctly.""" + + async def asyncSetUp(self): + self.db_path = await _setup_temp_db() + # Initialise burnin without a real TrueNASClient — pass None. + # _run_job will hit the assert at top, but the test cancels + # before _run_job's first await actually runs. + from app import burnin + burnin._unlock_grants.clear() + burnin._active_tasks.clear() + import asyncio + burnin._semaphore = asyncio.Semaphore(2) + burnin._client = None # unused by start_job itself + + async def asyncTearDown(self): + # Cancel any outstanding tasks so they don't bleed into later tests. + from app import burnin + for t in list(burnin._active_tasks.values()): + t.cancel() + try: + os.unlink(self.db_path) + except OSError: + pass + + async def test_start_job_inserts_queued_row_and_stages(self): + from app import burnin + job_id = await burnin.start_job(1, "surface", "test") + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cur = await db.execute( + "SELECT state, profile, operator FROM burnin_jobs WHERE id=?", + (job_id,), + ) + row = await cur.fetchone() + cur = await db.execute( + "SELECT stage_name FROM burnin_stages " + "WHERE burnin_job_id=? ORDER BY id", + (job_id,), + ) + stages = [r[0] for r in await cur.fetchall()] + # State should be queued OR running (the spawned _run_job may + # have raced into the semaphore by now). + self.assertIn(row["state"], ("queued", "running")) + self.assertEqual(row["profile"], "surface") + self.assertEqual(row["operator"], "test") + # surface profile = precheck + surface_validate + final_check. + self.assertEqual(stages, ["precheck", "surface_validate", "final_check"]) + + async def test_start_job_rejects_duplicate_active(self): + from app import burnin + await burnin.start_job(1, "surface", "test") + # Second start on the same drive should be refused at the + # ValueError level (caught by the inline duplicate check or by + # the partial unique index). + with self.assertRaises(ValueError): + await burnin.start_job(1, "surface", "test") + + async def test_cancel_job_marks_state(self): + from app import burnin + job_id = await burnin.start_job(1, "surface", "test") + ok = await burnin.cancel_job(job_id, "test") + self.assertTrue(ok) + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + "SELECT state FROM burnin_jobs WHERE id=?", (job_id,) + ) + row = await cur.fetchone() + self.assertEqual(row[0], "cancelled") + + async def test_cancel_job_returns_false_for_terminal_state(self): + from app import burnin + # Create a passed job manually + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + """INSERT INTO burnin_jobs + (drive_id, profile, state, operator, created_at) + VALUES (?,?,?,?,?) RETURNING id""", + (2, "surface", "passed", "x", "2026-05-03T00:00:00+00:00"), + ) + job_id = (await cur.fetchone())[0] + await db.commit() + ok = await burnin.cancel_job(job_id, "test") + self.assertFalse(ok) + + +class TestRateLimiter(unittest.TestCase): + """The generic rate-limit class added in 1.0.0-33 for the + unlock + password-change endpoints.""" + + def test_register_allows_under_threshold(self): + from app.auth import _RateLimiter + rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60) + self.assertEqual(rl.register(("k", "alice")), "ok") + self.assertEqual(rl.register(("k", "alice")), "ok") + + def test_register_trips_at_threshold(self): + from app.auth import _RateLimiter + rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60) + self.assertEqual(rl.register(("k", "alice")), "ok") + self.assertEqual(rl.register(("k", "alice")), "ok") + # 3rd attempt brings us to threshold — counts as the trip. + self.assertEqual(rl.register(("k", "alice")), "now_locked_out") + # 4th sees the lockout from the prior call. + self.assertEqual(rl.register(("k", "alice")), "locked_out") + + def test_clear_removes_counter_and_lockout(self): + from app.auth import _RateLimiter + rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60) + rl.register(("k", "alice")) + rl.register(("k", "alice")) # trips + self.assertIsNotNone(rl.locked_until(("k", "alice"))) + rl.clear(("k", "alice")) + self.assertIsNone(rl.locked_until(("k", "alice"))) + # Subsequent register should start fresh. + self.assertEqual(rl.register(("k", "alice")), "ok") + + def test_separate_keys_dont_interfere(self): + from app.auth import _RateLimiter + rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60) + rl.register(("k", "alice")) + rl.register(("k", "alice")) # trips alice + # Bob's attempt should be allowed and untouched by alice's lockout. + self.assertEqual(rl.register(("k", "bob")), "ok") + self.assertIsNone(rl.locked_until(("k", "bob"))) + + +if __name__ == "__main__": + unittest.main()