feat: rate limiter + mypy + lifecycle tests + routes/ split (1.0.0-33/-34)
Some checks are pending
Security scan / pip-audit (push) Waiting to run
Security scan / bandit (push) Waiting to run
Security scan / gitleaks (push) Waiting to run
Security scan / mypy (push) Waiting to run

Closes the four remaining items from the post-Codex hardening list.

#1 Rate-limit unlock + change-password endpoints (1.0.0-33)
   * Generalised the existing login limiter into a reusable
     `_RateLimiter` class in app/auth.py. Atomic check-then-increment
     in synchronous code so a parallel asyncio burst can't slip past
     the threshold.
   * `unlock_limiter` (5 attempts in 10 min → 10 min lockout) gates
     POST /api/v1/drives/{id}/unlock per-drive AND per-source-IP.
   * `pwchange_limiter` (5 in 10 min → 15 min lockout) gates
     POST /api/v1/auth/change-password per-user AND per-IP.
   * Both clear on successful operation. The login limiter keeps its
     existing `register_login_attempt` / `clear_login_failures`
     facade names so external callers don't change.

#3 mypy in security-scan (1.0.0-33)
   * Added a 4th tool to the daily scan + forge workflow. Runs in a
     throwaway python:3.12-slim container against the deploy dir,
     exit code is informational only (NOT included in the
     `TOTAL_EXIT` failure sum). Findings land in
     ~/security-scans/scan-YYYY-MM-DD/mypy.txt for ratchet-down
     work over time.
   * Forge job uses `continue-on-error: true` so it doesn't fail the
     workflow until the type-debt baseline is annotated down.

#4 Lifecycle test coverage (1.0.0-33)
   * New tests/test_lifecycle.py with 15 cases:
     - TestCommonHelpers (7 tests): _start_stage, _finish_stage
       success/failure/error-preservation, _recalculate_progress
       weighted math, _is_cancelled, _append_stage_log.
     - TestStartCancelJob (4 tests): start_job inserts queued row +
       correct stage list, duplicate-active rejection, cancel marks
       state, cancel returns False on terminal-state jobs.
     - TestRateLimiter (4 tests): under-threshold ok, trips at
       threshold, clear removes both counter + lockout, separate
       keys don't interfere.
   * Total goes from 44 to 59 tests; closes the orchestration-path
     coverage gap Codex flagged.

#2 Partial routes.py split (1.0.0-34)
   * routes.py → routes/ package. Same staged-extraction pattern as
     the burnin.py split.
   * routes/auth.py — login/logout/setup/change-password (170 LoC).
   * routes/system.py — /health, /ws/terminal, /api/v1/updates/check
     (136 LoC).
   * routes/_helpers.py — shared utilities used by both extracted
     modules and the still-monolithic remainder: client_ip,
     operator_for, is_stale, stale_context, secret_status,
     SECRET_FIELDS (97 LoC).
   * routes/__init__.py shrank from 1568 LoC to 1261. Future slices
     can extract drives, burnin, history, settings the same way.
   * GOTCHA recorded in commit body: `from app import auth` at the
     top of __init__.py binds `auth` as an attribute on the package
     namespace, so `from . import auth as _auth_routes` finds the
     OUTER module and yields `app.auth` instead of the submodule.
     Fix is `import app.routes.auth as _auth_routes` (absolute).
     This bit me once at deploy time; container failed to start
     with `module 'app.auth' has no attribute 'router'`.

Verification: 59/59 tests pass (44 existing + 15 new); container
boots clean at 1.0.0-34; /health 200 with all checks green; security
scan still clean (mypy informational findings ignored from totals).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Brandon Walter 2026-05-03 09:29:53 -04:00
parent eb2a964171
commit aa7822d6ce
9 changed files with 895 additions and 391 deletions

View file

@ -59,3 +59,18 @@ jobs:
chmod +x gitleaks chmod +x gitleaks
- name: Scan git history for secrets - name: Scan git history for secrets
run: ./gitleaks detect --source . --no-banner --redact --verbose run: ./gitleaks detect --source . --no-banner --redact --verbose
mypy:
runs-on: ubuntu-latest
# Informational — does not fail the workflow. Use `continue-on-error`
# so the build stays green while we work down the type-debt baseline.
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install mypy
run: pip install --upgrade mypy
- name: Type check
run: mypy --ignore-missing-imports --no-strict-optional app

View file

@ -213,89 +213,111 @@ async def change_password(user_id: int, current_password: str,
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Login rate limiting (in-memory, per-username + per-source-IP) # Generic rate limiting (in-memory, multi-key per category)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
#
# Each instance is a self-contained limiter for one category (login,
# unlock, password change). The atomicity guarantee is "no awaits between
# check and increment" — CPython's asyncio loop is single-threaded so
# concurrent requests cannot interleave the synchronous register() call.
import time as _time import time as _time
LOGIN_FAILURE_WINDOW_SECONDS = 600 # 10 min
LOGIN_FAILURE_THRESHOLD = 10 # this many failures within the window
LOGIN_LOCKOUT_SECONDS = 900 # then block for 15 min
# {(key,): [(timestamp, ...), ...]} key = (kind, value), kind in {"user","ip"} class _RateLimiter:
_login_failures: dict = {} def __init__(self, name: str, threshold: int, window_s: int, lockout_s: int):
_login_lockouts: dict = {} # key -> unix expiry self.name = name
self.threshold = threshold
self.window_s = window_s
self.lockout_s = lockout_s
self._failures: dict = {} # key -> [unix timestamps within window]
self._lockouts: dict = {} # key -> unix expiry
def _gc(self, key) -> None:
cutoff = _time.time() - self.window_s
arr = self._failures.get(key, [])
fresh = [t for t in arr if t >= cutoff]
if fresh:
self._failures[key] = fresh
elif key in self._failures:
del self._failures[key]
def locked_until(self, *keys) -> float | None:
"""Soonest active lockout expiry across `keys`, or None."""
now = _time.time()
soonest = None
for k in keys:
exp = self._lockouts.get(k)
if exp is None:
continue
if now >= exp:
del self._lockouts[k]
continue
soonest = exp if soonest is None else min(soonest, exp)
return soonest
def register(self, *keys) -> str:
"""Returns "ok" | "locked_out" | "now_locked_out"."""
now = _time.time()
for k in keys:
exp = self._lockouts.get(k)
if exp is None:
continue
if now >= exp:
del self._lockouts[k]
continue
return "locked_out"
tripped = False
for k in keys:
self._gc(k)
self._failures.setdefault(k, []).append(now)
if len(self._failures[k]) >= self.threshold:
self._lockouts[k] = now + self.lockout_s
self._failures[k] = []
tripped = True
return "now_locked_out" if tripped else "ok"
def clear(self, *keys) -> None:
for k in keys:
self._failures.pop(k, None)
self._lockouts.pop(k, None)
def _gc_failures(key) -> None: # Login: 10 failures in 10 min → 15 min lockout.
"""Drop failure timestamps older than the window.""" LOGIN_FAILURE_WINDOW_SECONDS = 600
arr = _login_failures.get(key, []) LOGIN_FAILURE_THRESHOLD = 10
cutoff = _time.time() - LOGIN_FAILURE_WINDOW_SECONDS LOGIN_LOCKOUT_SECONDS = 900
fresh = [t for t in arr if t >= cutoff]
if fresh: # Unlock + password change: tighter caps; both are post-auth so a
_login_failures[key] = fresh # legitimate operator typoing a token shouldn't be locked out for long.
elif key in _login_failures: UNLOCK_FAILURE_THRESHOLD = 5
del _login_failures[key] UNLOCK_LOCKOUT_SECONDS = 600
PWCHANGE_FAILURE_THRESHOLD = 5
PWCHANGE_LOCKOUT_SECONDS = 900
login_limiter = _RateLimiter(
"login", LOGIN_FAILURE_THRESHOLD, LOGIN_FAILURE_WINDOW_SECONDS,
LOGIN_LOCKOUT_SECONDS,
)
unlock_limiter = _RateLimiter(
"unlock", UNLOCK_FAILURE_THRESHOLD, 600, UNLOCK_LOCKOUT_SECONDS,
)
pwchange_limiter = _RateLimiter(
"pwchange", PWCHANGE_FAILURE_THRESHOLD, 600, PWCHANGE_LOCKOUT_SECONDS,
)
# Backward-compat facades — preserve the names existing routes.py reaches for.
def login_locked_until(username: str, ip: str) -> float | None: def login_locked_until(username: str, ip: str) -> float | None:
"""Returns the lockout expiry (unix ts) if either dimension is locked, return login_limiter.locked_until(("user", username.lower()), ("ip", ip))
else None. Lazily reaps expired lockouts."""
now = _time.time()
soonest = None
for key in (("user", username.lower()), ("ip", ip)):
exp = _login_lockouts.get(key)
if exp is None:
continue
if now >= exp:
del _login_lockouts[key]
continue
soonest = exp if soonest is None else min(soonest, exp)
return soonest
def register_login_attempt(username: str, ip: str) -> str: def register_login_attempt(username: str, ip: str) -> str:
"""Atomic check-then-increment for a login attempt. return login_limiter.register(("user", username.lower()), ("ip", ip))
Returns:
"ok" allowed, counter incremented
"locked_out" already locked from a prior attempt
"now_locked_out" THIS attempt is what tripped the lockout
The increment runs synchronously (no awaits) so concurrent requests
can't slip past the threshold in CPython's single-threaded asyncio
loop. Caller must invoke clear_login_failures() on successful auth
to roll back this attempt's contribution.
"""
now = _time.time()
# Check existing lockouts first; if already locked, don't even
# increment — the lockout window absorbs everything.
for key in (("user", username.lower()), ("ip", ip)):
exp = _login_lockouts.get(key)
if exp is None:
continue
if now >= exp:
del _login_lockouts[key]
continue
return "locked_out"
# Increment + arm lockout if this push crosses the threshold.
tripped = False
for key in (("user", username.lower()), ("ip", ip)):
_gc_failures(key)
_login_failures.setdefault(key, []).append(now)
if len(_login_failures[key]) >= LOGIN_FAILURE_THRESHOLD:
_login_lockouts[key] = now + LOGIN_LOCKOUT_SECONDS
_login_failures[key] = []
tripped = True
return "now_locked_out" if tripped else "ok"
def clear_login_failures(username: str, ip: str) -> None: def clear_login_failures(username: str, ip: str) -> None:
"""Erase counters AND any lockout for a successful auth — caller login_limiter.clear(("user", username.lower()), ("ip", ip))
proved they have credentials, so the brute-force ladder resets."""
for key in (("user", username.lower()), ("ip", ip)):
_login_failures.pop(key, None)
_login_lockouts.pop(key, None)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -83,7 +83,7 @@ class Settings(BaseSettings):
ssh_key: str = "" # PEM private key content (paste full key including headers) ssh_key: str = "" # PEM private key content (paste full key including headers)
# Application version — used by the /api/v1/updates/check endpoint # Application version — used by the /api/v1/updates/check endpoint
app_version: str = "1.0.0-32" app_version: str = "1.0.0-34"
# ---- Authentication (1.0.0-22) ---- # ---- Authentication (1.0.0-22) ----
# session_secret: HMAC key for signing session cookies. Empty = generate # session_secret: HMAC key for signing session cookies. Empty = generate

View file

@ -2,12 +2,11 @@ import asyncio
import csv import csv
import io import io
import json import json
import time as _time
from datetime import datetime, timezone from datetime import datetime, timezone
import aiosqlite import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse from fastapi.responses import HTMLResponse, StreamingResponse
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app import auth, burnin, mailer, poller, settings_store from app import auth, burnin, mailer, poller, settings_store
@ -21,8 +20,35 @@ from app.models import (
) )
from app.renderer import templates from app.renderer import templates
# Helpers shared with the extracted sub-routers — keep the underscore-
# prefixed local names that existing in-file callers reach for.
from ._helpers import (
client_ip as _client_ip,
is_stale as _is_stale,
operator_for as _operator_for,
secret_status as _secret_status,
stale_context as _stale_context,
SECRET_FIELDS as _SECRET_FIELDS,
)
router = APIRouter() router = APIRouter()
# Sub-routers extracted as part of the routes/ package split (1.0.0-34).
# Their endpoints get registered against the same APIRouter, so the
# external `from app.routes import router` import in app/main.py keeps
# working unchanged. Future slices can extract more — drives, burnin,
# settings, history — using the same pattern.
#
# Absolute imports (vs `from . import auth`) because the line-12
# `from app import auth` binds `auth` as an attribute on this package's
# namespace, which would shadow the relative-submodule lookup and yield
# `app.auth` instead of `app.routes.auth`.
import app.routes.auth as _auth_routes # noqa: E402
import app.routes.system as _system_routes # noqa: E402
router.include_router(_auth_routes.router)
router.include_router(_system_routes.router)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers # Internal helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -40,14 +66,7 @@ def _eta_seconds(eta_at: str | None) -> int | None:
return None return None
def _is_stale(last_polled_at: str) -> bool: # _is_stale is now imported from ._helpers above.
try:
last = datetime.fromisoformat(last_polled_at)
if last.tzinfo is None:
last = last.replace(tzinfo=timezone.utc)
return (datetime.now(timezone.utc) - last).total_seconds() > settings.stale_threshold_seconds
except Exception:
return True
def _compute_eta_seconds(started_at: str | None, percent: int) -> int | None: def _compute_eta_seconds(started_at: str | None, percent: int) -> int | None:
@ -219,162 +238,9 @@ async def _fetch_drives_for_template(db: aiosqlite.Connection) -> list[dict]:
return drives return drives
def _stale_context(poller_state: dict) -> dict: # _stale_context is now imported from ._helpers above.
last = poller_state.get("last_poll_at")
if not last:
return {"stale": False, "stale_seconds": 0}
try:
dt = datetime.fromisoformat(last)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
elapsed = int((datetime.now(timezone.utc) - dt).total_seconds())
stale = elapsed > settings.stale_threshold_seconds
return {"stale": stale, "stale_seconds": elapsed}
except Exception:
return {"stale": False, "stale_seconds": 0}
# ---------------------------------------------------------------------------
# Auth — login / logout / first-user setup
# ---------------------------------------------------------------------------
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request, next: str = "/", error: str | None = None):
needs_setup = (await auth.user_count()) == 0
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": needs_setup,
"error": error,
"next": next if next.startswith("/") else "/",
})
def _client_ip(request: Request) -> str:
fwd = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip()
return fwd or (request.client.host if request.client else "unknown")
@router.post("/login")
async def login_submit(request: Request):
form = await request.form()
username = (form.get("username") or "").strip()
password = form.get("password") or ""
next_url = form.get("next") or "/"
if not next_url.startswith("/"):
next_url = "/"
ip = _client_ip(request)
# Atomic register-and-check: increments the counter NOW (before any
# await), so a parallel burst of guesses can't all slip past the
# threshold. Cleared on successful auth via clear_login_failures.
attempt = auth.register_login_attempt(username, ip)
if attempt != "ok":
if attempt == "now_locked_out":
await auth.audit_auth_event(
"user_login_locked_out", username,
f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min",
)
locked_until = auth.login_locked_until(username, ip)
remaining = int((locked_until or _time.time()) - _time.time())
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": False,
"error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.",
"next": next_url,
}, status_code=429)
found = await auth.get_user_by_username(username)
if not found or not auth.verify_password(password, found[1]):
# Constant-ish-time: still call verify on a junk hash if user missing
# so the timing of "user not found" matches "wrong password."
if not found:
auth.verify_password(password, "$2b$12$" + "x" * 53)
await auth.audit_auth_event(
"user_login_failed", username, f"Failed login from {ip}",
)
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": False,
"error": "Invalid username or password.",
"next": next_url,
}, status_code=401)
user = found[0]
auth.clear_login_failures(username, ip)
# Clear any pre-login session keys before populating the new identity.
# Closes session-fixation: if an attacker had somehow seeded the
# browser with a session cookie, this discards everything in it
# before issuing the new authenticated payload.
request.session.clear()
request.session["user_id"] = user.id
request.session["username"] = user.username
await auth.touch_last_login(user.id)
await auth.audit_auth_event(
"user_login", user.username, f"Signed in from {ip}"
)
return RedirectResponse(url=next_url, status_code=303)
@router.post("/api/v1/auth/setup")
async def auth_first_user_setup(request: Request):
"""Create the first admin from the login page when the users table is
empty. Public endpoint but only does anything when zero users exist."""
if (await auth.user_count()) > 0:
raise HTTPException(status_code=409, detail="Users already exist.")
form = await request.form()
username = (form.get("username") or "").strip()
password = form.get("password") or ""
full_name = (form.get("full_name") or "").strip() or None
try:
# bootstrap_only=True wraps the existence check + insert in an
# IMMEDIATE transaction so two concurrent setup requests can't
# both create admin accounts during the bootstrap window.
user = await auth.create_user(
username, password, full_name, is_admin=True, bootstrap_only=True
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
# Same fixation defense as the login flow — discard any pre-existing
# session payload before issuing the authenticated identity.
request.session.clear()
request.session["user_id"] = user.id
request.session["username"] = user.username
await auth.touch_last_login(user.id)
return RedirectResponse(url="/", status_code=303)
@router.get("/logout")
@router.post("/logout")
async def logout(request: Request):
user = request.state.current_user if hasattr(request.state, "current_user") else None
if user:
await auth.audit_auth_event(
"user_logout", user.username, f"Signed out from {_client_ip(request)}"
)
request.session.clear()
return RedirectResponse(url="/login", status_code=303)
@router.post("/api/v1/auth/change-password")
async def change_password(request: Request):
user = request.state.current_user if hasattr(request.state, "current_user") else None
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
form = await request.form()
current = form.get("current_password") or ""
new_pw = form.get("new_password") or ""
confirm = form.get("confirm_password") or ""
if new_pw != confirm:
raise HTTPException(status_code=400, detail="New passwords do not match.")
try:
await auth.change_password(user.id, current, new_pw)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
await auth.audit_auth_event(
"user_password_changed", user.username,
f"Password changed from {_client_ip(request)}",
)
return {"ok": True}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -459,83 +325,6 @@ async def sse_drives(request: Request):
# JSON API # JSON API
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/health")
async def health(db: aiosqlite.Connection = Depends(get_db)):
"""Real readiness check, not just process-is-running.
Verifies (a) DB writable, (b) poller has succeeded recently relative
to the configured stale_threshold_seconds, (c) SSH reachable when
configured. Returns 503 when any check fails so a proxy/orchestrator
health probe can take the container out of rotation.
"""
from datetime import datetime, timezone
from fastapi.responses import JSONResponse
from app import ssh_client as _ssh
checks: dict[str, dict] = {}
# DB probe — actually exercise the write path (read-only mounts,
# full disks, broken WAL all silently pass a journal_mode read).
# Uses a temp table that lives only inside the connection so the
# round-trip touches the writer without polluting real data.
try:
await db.execute(
"CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)"
)
await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)",
(datetime.now(timezone.utc).isoformat(),))
cur = await db.execute("SELECT v FROM _hc WHERE k=1")
row = await cur.fetchone()
await db.commit()
checks["db"] = {"ok": bool(row)}
except Exception as exc:
checks["db"] = {"ok": False, "error": str(exc)}
ps = poller.get_state()
last = ps.get("last_poll_at")
poll_age = None
if last:
try:
t = datetime.fromisoformat(last)
if t.tzinfo is None:
t = t.replace(tzinfo=timezone.utc)
poll_age = (datetime.now(timezone.utc) - t).total_seconds()
except Exception:
poll_age = None
poll_ok = ps.get("healthy") and (
poll_age is None or poll_age <= settings.stale_threshold_seconds * 3
)
checks["poller"] = {
"ok": bool(poll_ok),
"last_error": ps.get("last_error"),
"last_poll_at": last,
"age_seconds": int(poll_age) if poll_age is not None else None,
}
# SSH probe — only when configured. Cheap (single sensors -j).
if _ssh.is_configured():
try:
r = await _ssh.test_connection()
checks["ssh"] = {"ok": bool(r.get("ok")),
"error": r.get("error")}
except Exception as exc:
checks["ssh"] = {"ok": False, "error": str(exc)}
else:
checks["ssh"] = {"ok": True, "skipped": True}
cur = await db.execute("SELECT COUNT(*) FROM drives")
row = await cur.fetchone()
drives_tracked = row[0] if row else 0
status_ok = all(c["ok"] for c in checks.values())
body = {
"status": "ok" if status_ok else "degraded",
"checks": checks,
"drives_tracked": drives_tracked,
"poll_interval_s": settings.poll_interval_seconds,
"version": settings.app_version,
}
return JSONResponse(body, status_code=200 if status_ok else 503)
@router.get("/api/v1/drives", response_model=list[DriveResponse]) @router.get("/api/v1/drives", response_model=list[DriveResponse])
@ -797,14 +586,7 @@ def _row_to_burnin(row: aiosqlite.Row, stages: list[aiosqlite.Row]) -> BurninJob
) )
def _operator_for(request: Request, _ignored_body_value: str | None = None) -> str: # _operator_for is now imported from ._helpers above.
"""Always return the logged-in user's name for audit attribution.
The request body's `operator` field (if any) is ignored — clients
can't spoof the operator identity any more."""
user = getattr(request.state, "current_user", None)
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
return user.full_name or user.username
@router.post("/api/v1/burnin/start") @router.post("/api/v1/burnin/start")
@ -838,12 +620,24 @@ async def burnin_start(request: Request, req: StartBurninRequest):
@router.post("/api/v1/drives/{drive_id}/unlock") @router.post("/api/v1/drives/{drive_id}/unlock")
async def unlock_pool_drive(drive_id: int, request: Request, req: UnlockPoolDriveRequest): async def unlock_pool_drive(drive_id: int, request: Request, req: UnlockPoolDriveRequest):
operator = _operator_for(request, req.operator) operator = _operator_for(request, req.operator)
ip = _client_ip(request)
# Rate-limit by drive AND by source IP. A typo on the confirm token
# is the common case so the threshold is loose, but a brute-force
# attempt to guess the token still hits the IP cap.
keys = (("drive", drive_id), ("ip", ip))
attempt = auth.unlock_limiter.register(*keys)
if attempt != "ok":
raise HTTPException(
status_code=429,
detail="Too many unlock attempts on this drive. Try again later.",
)
try: try:
expiry = await burnin.grant_pool_unlock( expiry = await burnin.grant_pool_unlock(
drive_id, req.confirm_token, operator, req.reason, drive_id, req.confirm_token, operator, req.reason,
) )
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) raise HTTPException(status_code=400, detail=str(exc))
auth.unlock_limiter.clear(*keys)
return {"unlocked": True, "expires_at": expiry, return {"unlocked": True, "expires_at": expiry,
# Read from the submodule, not the package-root snapshot # Read from the submodule, not the package-root snapshot
# alias — keeps tests that monkey-patch UNLOCK_TTL_SECONDS # alias — keeps tests that monkey-patch UNLOCK_TTL_SECONDS
@ -1333,42 +1127,7 @@ async def settings_page(
}) })
# Field names that hold secrets and must never be rendered to the UI or # _SECRET_FIELDS and _secret_status are now imported from ._helpers above.
# included in the redacted-settings dump verbatim.
_SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key")
def _secret_status() -> dict[str, str]:
"""Per-secret display string for the settings page so the operator can
see whether each secret is configured (and how) without ever rendering
the value. Distinguishes env-var, mounted-file, and DB-stored sources
for ssh_key the others can only come from the live settings object."""
import os as _os
from app.ssh_client import _MOUNTED_KEY_PATH
def _has(field: str) -> bool:
v = getattr(settings, field, "")
return bool(v)
# ssh_key gets the most granular treatment because we actively prefer
# the mounted file path in production but the textarea is still wired.
if _os.environ.get("SSH_KEY"):
ssh_key_status = "set (environment variable)"
elif _has("ssh_key"):
ssh_key_status = "set (stored in settings DB — prefer a mounted secret in production)"
elif _os.path.exists(
_os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
):
ssh_key_status = "set (mounted secret)"
else:
ssh_key_status = "unset"
return {
"smtp_password": "set" if _has("smtp_password") else "unset",
"ssh_password": "set" if _has("ssh_password") else "unset",
"ssh_key": ssh_key_status,
"truenas_api_key": "set" if _has("truenas_api_key") else "unset",
}
@router.get("/api/v1/settings/redacted") @router.get("/api/v1/settings/redacted")
@ -1445,43 +1204,6 @@ async def test_ssh(request: Request):
return {"ok": True} return {"ok": True}
@router.websocket("/ws/terminal")
async def terminal_ws(websocket: WebSocket):
"""WebSocket endpoint bridging the browser xterm.js terminal to an SSH PTY."""
from app import terminal as _term
await _term.handle(websocket)
@router.get("/api/v1/updates/check")
async def check_updates():
"""Check for a newer release on Forgejo."""
import httpx
current = settings.app_version
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.get(
"https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest",
headers={"Accept": "application/json"},
)
if r.status_code == 200:
data = r.json()
latest = data.get("tag_name", "").lstrip("v")
up_to_date = not latest or latest == current
return {
"current": current,
"latest": latest or None,
"update_available": not up_to_date,
"message": None,
}
elif r.status_code == 404:
return {"current": current, "latest": None, "update_available": False,
"message": "No releases published yet"}
else:
return {"current": current, "latest": None, "update_available": False,
"message": f"Forgejo API returned {r.status_code}"}
except Exception as exc:
return {"current": current, "latest": None, "update_available": False,
"message": f"Could not reach update server: {exc}"}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

97
app/routes/_helpers.py Normal file
View file

@ -0,0 +1,97 @@
"""Shared helpers used across multiple route modules.
Anything more than one route file needs lives here. Single-use helpers
stay in their owning route module.
"""
from __future__ import annotations
from datetime import datetime, timezone
from fastapi import HTTPException, Request
from app.config import settings
def client_ip(request: Request) -> str:
"""Best-effort source IP. Trusts X-Forwarded-For when present (we
sit behind nginx-proxy-manager) but falls back to the direct peer."""
fwd = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip()
return fwd or (request.client.host if request.client else "unknown")
def operator_for(request: Request, _ignored_body_value: str | None = None) -> str:
"""Always return the logged-in user's name for audit attribution.
The request body's `operator` field (if any) is ignored — clients
can't spoof the operator identity any more."""
user = getattr(request.state, "current_user", None)
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
return user.full_name or user.username
def is_stale(last_polled_at: str) -> bool:
"""True if the most recent poll is older than the stale threshold."""
try:
last = datetime.fromisoformat(last_polled_at)
if last.tzinfo is None:
last = last.replace(tzinfo=timezone.utc)
return (datetime.now(timezone.utc) - last).total_seconds() > settings.stale_threshold_seconds
except Exception:
return True
def stale_context(ps: dict) -> dict:
"""Returns the {stale, stale_seconds} dict every HTML page passes
to the layout for the warning banner."""
last = ps.get("last_poll_at")
if not last:
return {"stale": False, "stale_seconds": 0}
try:
t = datetime.fromisoformat(last)
if t.tzinfo is None:
t = t.replace(tzinfo=timezone.utc)
age = (datetime.now(timezone.utc) - t).total_seconds()
return {
"stale": age > settings.stale_threshold_seconds,
"stale_seconds": int(age),
}
except Exception:
return {"stale": False, "stale_seconds": 0}
# Field names that hold secrets and must never be rendered to the UI
# verbatim or included in the redacted-settings dump.
SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key")
def secret_status() -> dict[str, str]:
"""Per-secret display string for the settings page so the operator
can see whether each secret is configured (and how) without ever
rendering the value. Distinguishes env-var, mounted-file, and
DB-stored sources for ssh_key the others can only come from the
live settings object."""
import os as _os
from app.ssh_client import _MOUNTED_KEY_PATH
def _has(field: str) -> bool:
v = getattr(settings, field, "")
return bool(v)
if _os.environ.get("SSH_KEY"):
ssh_key_status = "set (environment variable)"
elif _has("ssh_key"):
ssh_key_status = "set (stored in settings DB — prefer a mounted secret in production)"
elif _os.path.exists(
_os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
):
ssh_key_status = "set (mounted secret)"
else:
ssh_key_status = "unset"
return {
"smtp_password": "set" if _has("smtp_password") else "unset",
"ssh_password": "set" if _has("ssh_password") else "unset",
"ssh_key": ssh_key_status,
"truenas_api_key": "set" if _has("truenas_api_key") else "unset",
}

170
app/routes/auth.py Normal file
View file

@ -0,0 +1,170 @@
"""Login / logout / first-user setup / password change routes.
Public path mounting:
GET /login render login or first-user setup form
POST /login credential check + session bootstrap
POST /api/v1/auth/setup first-user creation (only when zero users)
GET /logout clear session, redirect
POST /logout same, for explicit POST clients
POST /api/v1/auth/change-password rotate password + audit
"""
from __future__ import annotations
import time as _time
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from app import auth
from app.renderer import templates
from ._helpers import client_ip
router = APIRouter()
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request, next: str = "/", error: str | None = None):
needs_setup = (await auth.user_count()) == 0
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": needs_setup,
"error": error,
"next": next if next.startswith("/") else "/",
})
@router.post("/login")
async def login_submit(request: Request):
form = await request.form()
username = (form.get("username") or "").strip()
password = form.get("password") or ""
next_url = form.get("next") or "/"
if not next_url.startswith("/"):
next_url = "/"
ip = client_ip(request)
# Atomic register-and-check: increments the counter NOW (before any
# await), so a parallel burst of guesses can't all slip past the
# threshold. Cleared on successful auth via clear_login_failures.
attempt = auth.register_login_attempt(username, ip)
if attempt != "ok":
if attempt == "now_locked_out":
await auth.audit_auth_event(
"user_login_locked_out", username,
f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min",
)
locked_until = auth.login_locked_until(username, ip)
remaining = int((locked_until or _time.time()) - _time.time())
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": False,
"error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.",
"next": next_url,
}, status_code=429)
found = await auth.get_user_by_username(username)
if not found or not auth.verify_password(password, found[1]):
# Constant-ish-time: still call verify on a junk hash if user missing
# so the timing of "user not found" matches "wrong password."
if not found:
auth.verify_password(password, "$2b$12$" + "x" * 53)
await auth.audit_auth_event(
"user_login_failed", username, f"Failed login from {ip}",
)
return templates.TemplateResponse(request, "login.html", {
"request": request,
"needs_setup": False,
"error": "Invalid username or password.",
"next": next_url,
}, status_code=401)
user = found[0]
auth.clear_login_failures(username, ip)
# Clear any pre-login session keys before populating the new identity.
# Closes session-fixation: if an attacker had somehow seeded the
# browser with a session cookie, this discards everything in it
# before issuing the new authenticated payload.
request.session.clear()
request.session["user_id"] = user.id
request.session["username"] = user.username
await auth.touch_last_login(user.id)
await auth.audit_auth_event(
"user_login", user.username, f"Signed in from {ip}",
)
return RedirectResponse(url=next_url, status_code=303)
@router.post("/api/v1/auth/setup")
async def auth_first_user_setup(request: Request):
"""Create the first admin from the login page when the users table is
empty. Public endpoint but only does anything when zero users exist."""
if (await auth.user_count()) > 0:
raise HTTPException(status_code=409, detail="Users already exist.")
form = await request.form()
username = (form.get("username") or "").strip()
password = form.get("password") or ""
full_name = (form.get("full_name") or "").strip() or None
try:
# bootstrap_only=True wraps the existence check + insert in an
# IMMEDIATE transaction so two concurrent setup requests can't
# both create admin accounts during the bootstrap window.
user = await auth.create_user(
username, password, full_name, is_admin=True, bootstrap_only=True
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
# Same fixation defense as the login flow — discard any pre-existing
# session payload before issuing the authenticated identity.
request.session.clear()
request.session["user_id"] = user.id
request.session["username"] = user.username
await auth.touch_last_login(user.id)
return RedirectResponse(url="/", status_code=303)
@router.get("/logout")
@router.post("/logout")
async def logout(request: Request):
user = request.state.current_user if hasattr(request.state, "current_user") else None
if user:
await auth.audit_auth_event(
"user_logout", user.username, f"Signed out from {client_ip(request)}",
)
request.session.clear()
return RedirectResponse(url="/login", status_code=303)
@router.post("/api/v1/auth/change-password")
async def change_password(request: Request):
user = request.state.current_user if hasattr(request.state, "current_user") else None
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
ip = client_ip(request)
# Rate-limit before bcrypt to keep an attacker-controlled session
# from burning CPU brute-forcing the current_password field.
keys = (("user", user.username.lower()), ("ip", ip))
attempt = auth.pwchange_limiter.register(*keys)
if attempt != "ok":
raise HTTPException(
status_code=429,
detail="Too many password-change attempts. Try again later.",
)
form = await request.form()
current = form.get("current_password") or ""
new_pw = form.get("new_password") or ""
confirm = form.get("confirm_password") or ""
if new_pw != confirm:
raise HTTPException(status_code=400, detail="New passwords do not match.")
try:
await auth.change_password(user.id, current, new_pw)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
auth.pwchange_limiter.clear(*keys)
await auth.audit_auth_event(
"user_password_changed", user.username,
f"Password changed from {ip}",
)
return {"ok": True}

136
app/routes/system.py Normal file
View file

@ -0,0 +1,136 @@
"""System-level endpoints with no business-logic dependencies.
GET /health readiness probe (DB write + poller + SSH)
GET /api/v1/updates/check check Forgejo for newer release
WS /ws/terminal xterm.js bridge to TrueNAS SSH PTY
"""
from __future__ import annotations
from datetime import datetime, timezone
import aiosqlite
from fastapi import APIRouter, Depends, WebSocket
from fastapi.responses import JSONResponse
from app import poller
from app.config import settings
from app.database import get_db
router = APIRouter()
@router.get("/health")
async def health(db: aiosqlite.Connection = Depends(get_db)):
"""Real readiness check, not just process-is-running.
Verifies (a) DB writable, (b) poller has succeeded recently relative
to the configured stale_threshold_seconds, (c) SSH reachable when
configured. Returns 503 when any check fails so a proxy/orchestrator
health probe can take the container out of rotation.
"""
from app import ssh_client as _ssh
checks: dict[str, dict] = {}
# DB probe — actually exercise the write path (read-only mounts,
# full disks, broken WAL all silently pass a journal_mode read).
# Uses a temp table that lives only inside the connection so the
# round-trip touches the writer without polluting real data.
try:
await db.execute(
"CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)"
)
await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)",
(datetime.now(timezone.utc).isoformat(),))
cur = await db.execute("SELECT v FROM _hc WHERE k=1")
row = await cur.fetchone()
await db.commit()
checks["db"] = {"ok": bool(row)}
except Exception as exc:
checks["db"] = {"ok": False, "error": str(exc)}
ps = poller.get_state()
last = ps.get("last_poll_at")
poll_age = None
if last:
try:
t = datetime.fromisoformat(last)
if t.tzinfo is None:
t = t.replace(tzinfo=timezone.utc)
poll_age = (datetime.now(timezone.utc) - t).total_seconds()
except Exception:
poll_age = None
poll_ok = ps.get("healthy") and (
poll_age is None or poll_age <= settings.stale_threshold_seconds * 3
)
checks["poller"] = {
"ok": bool(poll_ok),
"last_error": ps.get("last_error"),
"last_poll_at": last,
"age_seconds": int(poll_age) if poll_age is not None else None,
}
# SSH probe — only when configured. Cheap (single sensors -j).
if _ssh.is_configured():
try:
r = await _ssh.test_connection()
checks["ssh"] = {"ok": bool(r.get("ok")),
"error": r.get("error")}
except Exception as exc:
checks["ssh"] = {"ok": False, "error": str(exc)}
else:
checks["ssh"] = {"ok": True, "skipped": True}
cur = await db.execute("SELECT COUNT(*) FROM drives")
row = await cur.fetchone()
drives_tracked = row[0] if row else 0
status_ok = all(c["ok"] for c in checks.values())
body = {
"status": "ok" if status_ok else "degraded",
"checks": checks,
"drives_tracked": drives_tracked,
"poll_interval_s": settings.poll_interval_seconds,
"version": settings.app_version,
}
return JSONResponse(body, status_code=200 if status_ok else 503)
@router.websocket("/ws/terminal")
async def terminal_ws(websocket: WebSocket):
"""WebSocket endpoint bridging the browser xterm.js terminal to an SSH PTY."""
from app import terminal as _term
await _term.handle(websocket)
@router.get("/api/v1/updates/check")
async def check_updates():
"""Check for a newer release on Forgejo."""
import httpx
current = settings.app_version
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.get(
"https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest",
headers={"Accept": "application/json"},
)
if r.status_code == 200:
data = r.json()
latest = data.get("tag_name", "").lstrip("v")
up_to_date = not latest or latest == current
return {
"current": current,
"latest": latest or None,
"update_available": not up_to_date,
"message": None,
}
elif r.status_code == 404:
return {"current": current, "latest": None, "update_available": False,
"message": "No releases published yet"}
else:
return {"current": current, "latest": None, "update_available": False,
"message": f"Forgejo API returned {r.status_code}"}
except Exception as exc:
return {"current": current, "latest": None, "update_available": False,
"message": f"Could not reach update server: {exc}"}

View file

@ -87,6 +87,20 @@ docker run --rm \
BANDITS=$? BANDITS=$?
echo " exit=$BANDITS ($OUT_DIR/bandit.txt)" | tee -a "$OUT_DIR/summary.txt" echo " exit=$BANDITS ($OUT_DIR/bandit.txt)" | tee -a "$OUT_DIR/summary.txt"
# --- mypy against the deploy dir (informational only) -------------------
# Type checker — surfaces None-handling bugs and missing-attribute errors
# the runtime would have caught at the worst possible moment. Doesn't
# count toward the failure exit-code sum until the codebase is annotated
# enough to make findings actionable.
echo "--- mypy (informational) ---" | tee -a "$OUT_DIR/summary.txt"
docker run --rm \
-v "$DEPLOY_DIR/app:/src:ro" \
python:3.12-slim sh -c \
"pip install --quiet --no-cache-dir --disable-pip-version-check mypy 2>&1 | tail -3 && mypy --ignore-missing-imports --no-strict-optional /src" \
> "$OUT_DIR/mypy.txt" 2>&1
MYPY=$?
echo " exit=$MYPY ($OUT_DIR/mypy.txt) — informational only" | tee -a "$OUT_DIR/summary.txt"
# --- gitleaks against the full git history ------------------------------ # --- gitleaks against the full git history ------------------------------
echo "--- gitleaks ---" | tee -a "$OUT_DIR/summary.txt" echo "--- gitleaks ---" | tee -a "$OUT_DIR/summary.txt"
docker run --rm \ docker run --rm \

328
tests/test_lifecycle.py Normal file
View file

@ -0,0 +1,328 @@
"""Burn-in lifecycle tests covering the DB helpers in burnin._common,
plus the public surface of start_job + cancel_job that doesn't require
spinning up _run_job (which would need a mocked TrueNASClient + SSH).
These are the safety net Codex flagged was missing the orchestration
paths were entirely untested before this. Run inside the container
image so app deps (aiosqlite, pydantic-settings, bcrypt) are present.
"""
from __future__ import annotations
import os
import tempfile
import unittest
import aiosqlite
async def _setup_temp_db() -> str:
"""Same pattern as test_unlock_flow.py — temp DB + init_db, returning
the path. Caller must unlink in tearDown."""
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
from app.config import settings
settings.db_path = path
from app.database import init_db
await init_db()
# Seed two drives so start_job has something to attach to.
async with aiosqlite.connect(path) as db:
await db.execute("""
INSERT INTO drives
(truenas_disk_id, devname, serial, model, size_bytes,
temperature_c, smart_health, last_seen_at, last_polled_at)
VALUES ('id-1', 'sda', 'SER1', 'TestModel', 1000, 30, 'PASSED',
'2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00')
""")
await db.execute("""
INSERT INTO drives
(truenas_disk_id, devname, serial, model, size_bytes,
temperature_c, smart_health, last_seen_at, last_polled_at)
VALUES ('id-2', 'sdb', 'SER2', 'TestModel', 1000, 30, 'PASSED',
'2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00')
""")
await db.commit()
return path
class TestCommonHelpers(unittest.IsolatedAsyncioTestCase):
"""The per-stage DB mutators in app.burnin._common — pure SQLite
writes, no asyncssh, no orchestration. Trivially regression-testable."""
async def asyncSetUp(self):
self.db_path = await _setup_temp_db()
# Insert a queued job + 2 stages we can mutate.
async with aiosqlite.connect(self.db_path) as db:
cur = await db.execute(
"""INSERT INTO burnin_jobs
(drive_id, profile, state, percent, operator, created_at)
VALUES (?,?,?,?,?,?) RETURNING id""",
(1, "full", "running", 0, "test", "2026-05-03T00:00:00+00:00"),
)
self.job_id = (await cur.fetchone())[0]
for stage_name in ("precheck", "surface_validate", "final_check"):
await db.execute(
"INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)",
(self.job_id, stage_name, "pending"),
)
await db.commit()
async def asyncTearDown(self):
try:
os.unlink(self.db_path)
except OSError:
pass
async def test_start_stage_marks_running(self):
from app.burnin import _common
await _common._start_stage(self.job_id, "precheck")
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cur = await db.execute(
"SELECT state, started_at FROM burnin_stages "
"WHERE burnin_job_id=? AND stage_name='precheck'",
(self.job_id,),
)
row = await cur.fetchone()
self.assertEqual(row["state"], "running")
self.assertIsNotNone(row["started_at"])
async def test_finish_stage_success_records_duration(self):
from app.burnin import _common
await _common._start_stage(self.job_id, "precheck")
await _common._finish_stage(self.job_id, "precheck", success=True)
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cur = await db.execute(
"SELECT state, percent, duration_seconds FROM burnin_stages "
"WHERE burnin_job_id=? AND stage_name='precheck'",
(self.job_id,),
)
row = await cur.fetchone()
self.assertEqual(row["state"], "passed")
self.assertEqual(row["percent"], 100)
# Duration is float seconds since started_at — should be tiny but >0.
self.assertIsNotNone(row["duration_seconds"])
self.assertGreaterEqual(row["duration_seconds"], 0)
async def test_finish_stage_failure_carries_error_text(self):
from app.burnin import _common
await _common._start_stage(self.job_id, "surface_validate")
await _common._finish_stage(
self.job_id, "surface_validate",
success=False, error_text="mock failure",
)
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cur = await db.execute(
"SELECT state, percent, error_text FROM burnin_stages "
"WHERE burnin_job_id=? AND stage_name='surface_validate'",
(self.job_id,),
)
row = await cur.fetchone()
self.assertEqual(row["state"], "failed")
self.assertIsNone(row["percent"])
self.assertEqual(row["error_text"], "mock failure")
async def test_finish_stage_preserves_existing_error(self):
"""When called with error_text=None, the existing column value
from _set_stage_error must be preserved (not overwritten with NULL).
This is the bug that 1.0.0-12-ish fixed."""
from app.burnin import _common
await _common._start_stage(self.job_id, "surface_validate")
await _common._set_stage_error(
self.job_id, "surface_validate", "set by stage",
)
await _common._finish_stage(
self.job_id, "surface_validate", success=False, error_text=None,
)
async with aiosqlite.connect(self.db_path) as db:
cur = await db.execute(
"SELECT error_text FROM burnin_stages "
"WHERE burnin_job_id=? AND stage_name='surface_validate'",
(self.job_id,),
)
row = await cur.fetchone()
self.assertEqual(row[0], "set by stage")
async def test_recalculate_progress_weights_correctly(self):
from app.burnin import _common
# Mark precheck passed, surface_validate at 50% running.
await _common._start_stage(self.job_id, "precheck")
await _common._finish_stage(self.job_id, "precheck", success=True)
await _common._start_stage(self.job_id, "surface_validate")
await _common._update_stage_percent(self.job_id, "surface_validate", 50)
await _common._recalculate_progress(self.job_id)
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cur = await db.execute(
"SELECT percent, stage_name FROM burnin_jobs WHERE id=?",
(self.job_id,),
)
row = await cur.fetchone()
# Weights: precheck=5, surface=65, final=5. Total = 75 across these
# 3 stages. Completed = 5 (precheck) + 32.5 (half of 65) = 37.5.
# 37.5 / 75 = 50%.
self.assertEqual(row["percent"], 50)
self.assertEqual(row["stage_name"], "surface_validate")
async def test_is_cancelled_reads_job_state(self):
from app.burnin import _common
self.assertFalse(await _common._is_cancelled(self.job_id))
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"UPDATE burnin_jobs SET state='cancelled' WHERE id=?",
(self.job_id,),
)
await db.commit()
self.assertTrue(await _common._is_cancelled(self.job_id))
async def test_append_stage_log_concatenates(self):
from app.burnin import _common
await _common._append_stage_log(self.job_id, "precheck", "alpha\n")
await _common._append_stage_log(self.job_id, "precheck", "beta\n")
async with aiosqlite.connect(self.db_path) as db:
cur = await db.execute(
"SELECT log_text FROM burnin_stages "
"WHERE burnin_job_id=? AND stage_name='precheck'",
(self.job_id,),
)
row = await cur.fetchone()
self.assertEqual(row[0], "alpha\nbeta\n")
class TestStartCancelJob(unittest.IsolatedAsyncioTestCase):
"""start_job + cancel_job touch the burnin orchestrator state. We
spawn _run_job tasks that try to acquire the semaphore we cancel
immediately after to avoid running real burn-in stages. The real
test value here is "did start_job create the right DB rows" and
"does cancel_job mark them correctly."""
async def asyncSetUp(self):
self.db_path = await _setup_temp_db()
# Initialise burnin without a real TrueNASClient — pass None.
# _run_job will hit the assert at top, but the test cancels
# before _run_job's first await actually runs.
from app import burnin
burnin._unlock_grants.clear()
burnin._active_tasks.clear()
import asyncio
burnin._semaphore = asyncio.Semaphore(2)
burnin._client = None # unused by start_job itself
async def asyncTearDown(self):
# Cancel any outstanding tasks so they don't bleed into later tests.
from app import burnin
for t in list(burnin._active_tasks.values()):
t.cancel()
try:
os.unlink(self.db_path)
except OSError:
pass
async def test_start_job_inserts_queued_row_and_stages(self):
from app import burnin
job_id = await burnin.start_job(1, "surface", "test")
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cur = await db.execute(
"SELECT state, profile, operator FROM burnin_jobs WHERE id=?",
(job_id,),
)
row = await cur.fetchone()
cur = await db.execute(
"SELECT stage_name FROM burnin_stages "
"WHERE burnin_job_id=? ORDER BY id",
(job_id,),
)
stages = [r[0] for r in await cur.fetchall()]
# State should be queued OR running (the spawned _run_job may
# have raced into the semaphore by now).
self.assertIn(row["state"], ("queued", "running"))
self.assertEqual(row["profile"], "surface")
self.assertEqual(row["operator"], "test")
# surface profile = precheck + surface_validate + final_check.
self.assertEqual(stages, ["precheck", "surface_validate", "final_check"])
async def test_start_job_rejects_duplicate_active(self):
from app import burnin
await burnin.start_job(1, "surface", "test")
# Second start on the same drive should be refused at the
# ValueError level (caught by the inline duplicate check or by
# the partial unique index).
with self.assertRaises(ValueError):
await burnin.start_job(1, "surface", "test")
async def test_cancel_job_marks_state(self):
from app import burnin
job_id = await burnin.start_job(1, "surface", "test")
ok = await burnin.cancel_job(job_id, "test")
self.assertTrue(ok)
async with aiosqlite.connect(self.db_path) as db:
cur = await db.execute(
"SELECT state FROM burnin_jobs WHERE id=?", (job_id,)
)
row = await cur.fetchone()
self.assertEqual(row[0], "cancelled")
async def test_cancel_job_returns_false_for_terminal_state(self):
from app import burnin
# Create a passed job manually
async with aiosqlite.connect(self.db_path) as db:
cur = await db.execute(
"""INSERT INTO burnin_jobs
(drive_id, profile, state, operator, created_at)
VALUES (?,?,?,?,?) RETURNING id""",
(2, "surface", "passed", "x", "2026-05-03T00:00:00+00:00"),
)
job_id = (await cur.fetchone())[0]
await db.commit()
ok = await burnin.cancel_job(job_id, "test")
self.assertFalse(ok)
class TestRateLimiter(unittest.TestCase):
"""The generic rate-limit class added in 1.0.0-33 for the
unlock + password-change endpoints."""
def test_register_allows_under_threshold(self):
from app.auth import _RateLimiter
rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60)
self.assertEqual(rl.register(("k", "alice")), "ok")
self.assertEqual(rl.register(("k", "alice")), "ok")
def test_register_trips_at_threshold(self):
from app.auth import _RateLimiter
rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60)
self.assertEqual(rl.register(("k", "alice")), "ok")
self.assertEqual(rl.register(("k", "alice")), "ok")
# 3rd attempt brings us to threshold — counts as the trip.
self.assertEqual(rl.register(("k", "alice")), "now_locked_out")
# 4th sees the lockout from the prior call.
self.assertEqual(rl.register(("k", "alice")), "locked_out")
def test_clear_removes_counter_and_lockout(self):
from app.auth import _RateLimiter
rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60)
rl.register(("k", "alice"))
rl.register(("k", "alice")) # trips
self.assertIsNotNone(rl.locked_until(("k", "alice")))
rl.clear(("k", "alice"))
self.assertIsNone(rl.locked_until(("k", "alice")))
# Subsequent register should start fresh.
self.assertEqual(rl.register(("k", "alice")), "ok")
def test_separate_keys_dont_interfere(self):
from app.auth import _RateLimiter
rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60)
rl.register(("k", "alice"))
rl.register(("k", "alice")) # trips alice
# Bob's attempt should be allowed and untouched by alice's lockout.
self.assertEqual(rl.register(("k", "bob")), "ok")
self.assertIsNone(rl.locked_until(("k", "bob")))
if __name__ == "__main__":
unittest.main()