feat: rate limiter + mypy + lifecycle tests + routes/ split (1.0.0-33/-34)
Closes the four remaining items from the post-Codex hardening list.

#1 Rate-limit unlock + change-password endpoints (1.0.0-33)
* Generalised the existing login limiter into a reusable `_RateLimiter` class in app/auth.py. Atomic check-then-increment in synchronous code so a parallel asyncio burst can't slip past the threshold.
* `unlock_limiter` (5 attempts in 10 min → 10 min lockout) gates POST /api/v1/drives/{id}/unlock per-drive AND per-source-IP.
* `pwchange_limiter` (5 in 10 min → 15 min lockout) gates POST /api/v1/auth/change-password per-user AND per-IP.
* Both clear on successful operation. The login limiter keeps its existing `register_login_attempt` / `clear_login_failures` facade names so external callers don't change.

#3 mypy in security-scan (1.0.0-33)
* Added a 4th tool to the daily scan + forge workflow. Runs in a throwaway python:3.12-slim container against the deploy dir; the exit code is informational only (NOT included in the `TOTAL_EXIT` failure sum). Findings land in ~/security-scans/scan-YYYY-MM-DD/mypy.txt for ratchet-down work over time.
* Forge job uses `continue-on-error: true` so it doesn't fail the workflow until the type-debt baseline is annotated down.

#4 Lifecycle test coverage (1.0.0-33)
* New tests/test_lifecycle.py with 15 cases:
  - TestCommonHelpers (7 tests): _start_stage, _finish_stage success/failure/error-preservation, _recalculate_progress weighted math, _is_cancelled, _append_stage_log.
  - TestStartCancelJob (4 tests): start_job inserts queued row + correct stage list, duplicate-active rejection, cancel marks state, cancel returns False on terminal-state jobs.
  - TestRateLimiter (4 tests): under-threshold ok, trips at threshold, clear removes both counter + lockout, separate keys don't interfere.
* Total goes from 44 to 59 tests; closes the orchestration-path coverage gap Codex flagged.

#2 Partial routes.py split (1.0.0-34)
* routes.py → routes/ package. Same staged-extraction pattern as the burnin.py split.
* routes/auth.py — login/logout/setup/change-password (170 LoC).
* routes/system.py — /health, /ws/terminal, /api/v1/updates/check (136 LoC).
* routes/_helpers.py — shared utilities used by both extracted modules and the still-monolithic remainder: client_ip, operator_for, is_stale, stale_context, secret_status, SECRET_FIELDS (97 LoC).
* routes/__init__.py shrank from 1568 LoC to 1261. Future slices can extract drives, burnin, history, settings the same way.
* GOTCHA recorded in commit body: `from app import auth` at the top of __init__.py binds `auth` as an attribute on the package namespace, so `from . import auth as _auth_routes` finds the OUTER module and yields `app.auth` instead of the submodule. Fix is `import app.routes.auth as _auth_routes` (absolute). This bit me once at deploy time; container failed to start with `module 'app.auth' has no attribute 'router'`.

Verification: 59/59 tests pass (44 existing + 15 new); container boots clean at 1.0.0-34; /health 200 with all checks green; security scan still clean (mypy informational findings ignored from totals).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
eb2a964171
commit
aa7822d6ce
9 changed files with 895 additions and 391 deletions
|
|
@ -59,3 +59,18 @@ jobs:
|
||||||
chmod +x gitleaks
|
chmod +x gitleaks
|
||||||
- name: Scan git history for secrets
|
- name: Scan git history for secrets
|
||||||
run: ./gitleaks detect --source . --no-banner --redact --verbose
|
run: ./gitleaks detect --source . --no-banner --redact --verbose
|
||||||
|
|
||||||
|
mypy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
# Informational — does not fail the workflow. Use `continue-on-error`
|
||||||
|
# so the build stays green while we work down the type-debt baseline.
|
||||||
|
continue-on-error: true
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- name: Install mypy
|
||||||
|
run: pip install --upgrade mypy
|
||||||
|
- name: Type check
|
||||||
|
run: mypy --ignore-missing-imports --no-strict-optional app
|
||||||
|
|
|
||||||
128
app/auth.py
128
app/auth.py
|
|
@ -213,89 +213,111 @@ async def change_password(user_id: int, current_password: str,
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Login rate limiting (in-memory, per-username + per-source-IP)
|
# Generic rate limiting (in-memory, multi-key per category)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Each instance is a self-contained limiter for one category (login,
|
||||||
|
# unlock, password change). The atomicity guarantee is "no awaits between
|
||||||
|
# check and increment" — CPython's asyncio loop is single-threaded so
|
||||||
|
# concurrent requests cannot interleave the synchronous register() call.
|
||||||
|
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
LOGIN_FAILURE_WINDOW_SECONDS = 600 # 10 min
|
|
||||||
LOGIN_FAILURE_THRESHOLD = 10 # this many failures within the window
|
|
||||||
LOGIN_LOCKOUT_SECONDS = 900 # then block for 15 min
|
|
||||||
|
|
||||||
# {(key,): [(timestamp, ...), ...]} key = (kind, value), kind in {"user","ip"}
|
class _RateLimiter:
|
||||||
_login_failures: dict = {}
|
def __init__(self, name: str, threshold: int, window_s: int, lockout_s: int):
|
||||||
_login_lockouts: dict = {} # key -> unix expiry
|
self.name = name
|
||||||
|
self.threshold = threshold
|
||||||
|
self.window_s = window_s
|
||||||
|
self.lockout_s = lockout_s
|
||||||
|
self._failures: dict = {} # key -> [unix timestamps within window]
|
||||||
|
self._lockouts: dict = {} # key -> unix expiry
|
||||||
|
|
||||||
|
def _gc(self, key) -> None:
|
||||||
def _gc_failures(key) -> None:
|
cutoff = _time.time() - self.window_s
|
||||||
"""Drop failure timestamps older than the window."""
|
arr = self._failures.get(key, [])
|
||||||
arr = _login_failures.get(key, [])
|
|
||||||
cutoff = _time.time() - LOGIN_FAILURE_WINDOW_SECONDS
|
|
||||||
fresh = [t for t in arr if t >= cutoff]
|
fresh = [t for t in arr if t >= cutoff]
|
||||||
if fresh:
|
if fresh:
|
||||||
_login_failures[key] = fresh
|
self._failures[key] = fresh
|
||||||
elif key in _login_failures:
|
elif key in self._failures:
|
||||||
del _login_failures[key]
|
del self._failures[key]
|
||||||
|
|
||||||
|
def locked_until(self, *keys) -> float | None:
|
||||||
def login_locked_until(username: str, ip: str) -> float | None:
|
"""Soonest active lockout expiry across `keys`, or None."""
|
||||||
"""Returns the lockout expiry (unix ts) if either dimension is locked,
|
|
||||||
else None. Lazily reaps expired lockouts."""
|
|
||||||
now = _time.time()
|
now = _time.time()
|
||||||
soonest = None
|
soonest = None
|
||||||
for key in (("user", username.lower()), ("ip", ip)):
|
for k in keys:
|
||||||
exp = _login_lockouts.get(key)
|
exp = self._lockouts.get(k)
|
||||||
if exp is None:
|
if exp is None:
|
||||||
continue
|
continue
|
||||||
if now >= exp:
|
if now >= exp:
|
||||||
del _login_lockouts[key]
|
del self._lockouts[k]
|
||||||
continue
|
continue
|
||||||
soonest = exp if soonest is None else min(soonest, exp)
|
soonest = exp if soonest is None else min(soonest, exp)
|
||||||
return soonest
|
return soonest
|
||||||
|
|
||||||
|
def register(self, *keys) -> str:
|
||||||
def register_login_attempt(username: str, ip: str) -> str:
|
"""Returns "ok" | "locked_out" | "now_locked_out"."""
|
||||||
"""Atomic check-then-increment for a login attempt.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"ok" — allowed, counter incremented
|
|
||||||
"locked_out" — already locked from a prior attempt
|
|
||||||
"now_locked_out" — THIS attempt is what tripped the lockout
|
|
||||||
|
|
||||||
The increment runs synchronously (no awaits) so concurrent requests
|
|
||||||
can't slip past the threshold in CPython's single-threaded asyncio
|
|
||||||
loop. Caller must invoke clear_login_failures() on successful auth
|
|
||||||
to roll back this attempt's contribution.
|
|
||||||
"""
|
|
||||||
now = _time.time()
|
now = _time.time()
|
||||||
# Check existing lockouts first; if already locked, don't even
|
for k in keys:
|
||||||
# increment — the lockout window absorbs everything.
|
exp = self._lockouts.get(k)
|
||||||
for key in (("user", username.lower()), ("ip", ip)):
|
|
||||||
exp = _login_lockouts.get(key)
|
|
||||||
if exp is None:
|
if exp is None:
|
||||||
continue
|
continue
|
||||||
if now >= exp:
|
if now >= exp:
|
||||||
del _login_lockouts[key]
|
del self._lockouts[k]
|
||||||
continue
|
continue
|
||||||
return "locked_out"
|
return "locked_out"
|
||||||
# Increment + arm lockout if this push crosses the threshold.
|
|
||||||
tripped = False
|
tripped = False
|
||||||
for key in (("user", username.lower()), ("ip", ip)):
|
for k in keys:
|
||||||
_gc_failures(key)
|
self._gc(k)
|
||||||
_login_failures.setdefault(key, []).append(now)
|
self._failures.setdefault(k, []).append(now)
|
||||||
if len(_login_failures[key]) >= LOGIN_FAILURE_THRESHOLD:
|
if len(self._failures[k]) >= self.threshold:
|
||||||
_login_lockouts[key] = now + LOGIN_LOCKOUT_SECONDS
|
self._lockouts[k] = now + self.lockout_s
|
||||||
_login_failures[key] = []
|
self._failures[k] = []
|
||||||
tripped = True
|
tripped = True
|
||||||
return "now_locked_out" if tripped else "ok"
|
return "now_locked_out" if tripped else "ok"
|
||||||
|
|
||||||
|
def clear(self, *keys) -> None:
|
||||||
|
for k in keys:
|
||||||
|
self._failures.pop(k, None)
|
||||||
|
self._lockouts.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Login: 10 failures in 10 min → 15 min lockout.
|
||||||
|
LOGIN_FAILURE_WINDOW_SECONDS = 600
|
||||||
|
LOGIN_FAILURE_THRESHOLD = 10
|
||||||
|
LOGIN_LOCKOUT_SECONDS = 900
|
||||||
|
|
||||||
|
# Unlock + password change: tighter caps; both are post-auth so a
|
||||||
|
# legitimate operator typoing a token shouldn't be locked out for long.
|
||||||
|
UNLOCK_FAILURE_THRESHOLD = 5
|
||||||
|
UNLOCK_LOCKOUT_SECONDS = 600
|
||||||
|
PWCHANGE_FAILURE_THRESHOLD = 5
|
||||||
|
PWCHANGE_LOCKOUT_SECONDS = 900
|
||||||
|
|
||||||
|
login_limiter = _RateLimiter(
|
||||||
|
"login", LOGIN_FAILURE_THRESHOLD, LOGIN_FAILURE_WINDOW_SECONDS,
|
||||||
|
LOGIN_LOCKOUT_SECONDS,
|
||||||
|
)
|
||||||
|
unlock_limiter = _RateLimiter(
|
||||||
|
"unlock", UNLOCK_FAILURE_THRESHOLD, 600, UNLOCK_LOCKOUT_SECONDS,
|
||||||
|
)
|
||||||
|
pwchange_limiter = _RateLimiter(
|
||||||
|
"pwchange", PWCHANGE_FAILURE_THRESHOLD, 600, PWCHANGE_LOCKOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compat facades — preserve the names existing routes.py reaches for.
|
||||||
|
def login_locked_until(username: str, ip: str) -> float | None:
|
||||||
|
return login_limiter.locked_until(("user", username.lower()), ("ip", ip))
|
||||||
|
|
||||||
|
|
||||||
|
def register_login_attempt(username: str, ip: str) -> str:
|
||||||
|
return login_limiter.register(("user", username.lower()), ("ip", ip))
|
||||||
|
|
||||||
|
|
||||||
def clear_login_failures(username: str, ip: str) -> None:
|
def clear_login_failures(username: str, ip: str) -> None:
|
||||||
"""Erase counters AND any lockout for a successful auth — caller
|
login_limiter.clear(("user", username.lower()), ("ip", ip))
|
||||||
proved they have credentials, so the brute-force ladder resets."""
|
|
||||||
for key in (("user", username.lower()), ("ip", ip)):
|
|
||||||
_login_failures.pop(key, None)
|
|
||||||
_login_lockouts.pop(key, None)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class Settings(BaseSettings):
|
||||||
ssh_key: str = "" # PEM private key content (paste full key including headers)
|
ssh_key: str = "" # PEM private key content (paste full key including headers)
|
||||||
|
|
||||||
# Application version — used by the /api/v1/updates/check endpoint
|
# Application version — used by the /api/v1/updates/check endpoint
|
||||||
app_version: str = "1.0.0-32"
|
app_version: str = "1.0.0-34"
|
||||||
|
|
||||||
# ---- Authentication (1.0.0-22) ----
|
# ---- Authentication (1.0.0-22) ----
|
||||||
# session_secret: HMAC key for signing session cookies. Empty = generate
|
# session_secret: HMAC key for signing session cookies. Empty = generate
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,11 @@ import asyncio
|
||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import time as _time
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
from app import auth, burnin, mailer, poller, settings_store
|
from app import auth, burnin, mailer, poller, settings_store
|
||||||
|
|
@ -21,8 +20,35 @@ from app.models import (
|
||||||
)
|
)
|
||||||
from app.renderer import templates
|
from app.renderer import templates
|
||||||
|
|
||||||
|
# Helpers shared with the extracted sub-routers — keep the underscore-
|
||||||
|
# prefixed local names that existing in-file callers reach for.
|
||||||
|
from ._helpers import (
|
||||||
|
client_ip as _client_ip,
|
||||||
|
is_stale as _is_stale,
|
||||||
|
operator_for as _operator_for,
|
||||||
|
secret_status as _secret_status,
|
||||||
|
stale_context as _stale_context,
|
||||||
|
SECRET_FIELDS as _SECRET_FIELDS,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Sub-routers extracted as part of the routes/ package split (1.0.0-34).
|
||||||
|
# Their endpoints get registered against the same APIRouter, so the
|
||||||
|
# external `from app.routes import router` import in app/main.py keeps
|
||||||
|
# working unchanged. Future slices can extract more — drives, burnin,
|
||||||
|
# settings, history — using the same pattern.
|
||||||
|
#
|
||||||
|
# Absolute imports (vs `from . import auth`) because the line-12
|
||||||
|
# `from app import auth` binds `auth` as an attribute on this package's
|
||||||
|
# namespace, which would shadow the relative-submodule lookup and yield
|
||||||
|
# `app.auth` instead of `app.routes.auth`.
|
||||||
|
import app.routes.auth as _auth_routes # noqa: E402
|
||||||
|
import app.routes.system as _system_routes # noqa: E402
|
||||||
|
|
||||||
|
router.include_router(_auth_routes.router)
|
||||||
|
router.include_router(_system_routes.router)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -40,14 +66,7 @@ def _eta_seconds(eta_at: str | None) -> int | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _is_stale(last_polled_at: str) -> bool:
|
# _is_stale is now imported from ._helpers above.
|
||||||
try:
|
|
||||||
last = datetime.fromisoformat(last_polled_at)
|
|
||||||
if last.tzinfo is None:
|
|
||||||
last = last.replace(tzinfo=timezone.utc)
|
|
||||||
return (datetime.now(timezone.utc) - last).total_seconds() > settings.stale_threshold_seconds
|
|
||||||
except Exception:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_eta_seconds(started_at: str | None, percent: int) -> int | None:
|
def _compute_eta_seconds(started_at: str | None, percent: int) -> int | None:
|
||||||
|
|
@ -219,162 +238,9 @@ async def _fetch_drives_for_template(db: aiosqlite.Connection) -> list[dict]:
|
||||||
return drives
|
return drives
|
||||||
|
|
||||||
|
|
||||||
def _stale_context(poller_state: dict) -> dict:
|
# _stale_context is now imported from ._helpers above.
|
||||||
last = poller_state.get("last_poll_at")
|
|
||||||
if not last:
|
|
||||||
return {"stale": False, "stale_seconds": 0}
|
|
||||||
try:
|
|
||||||
dt = datetime.fromisoformat(last)
|
|
||||||
if dt.tzinfo is None:
|
|
||||||
dt = dt.replace(tzinfo=timezone.utc)
|
|
||||||
elapsed = int((datetime.now(timezone.utc) - dt).total_seconds())
|
|
||||||
stale = elapsed > settings.stale_threshold_seconds
|
|
||||||
return {"stale": stale, "stale_seconds": elapsed}
|
|
||||||
except Exception:
|
|
||||||
return {"stale": False, "stale_seconds": 0}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Auth — login / logout / first-user setup
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/login", response_class=HTMLResponse)
|
|
||||||
async def login_page(request: Request, next: str = "/", error: str | None = None):
|
|
||||||
needs_setup = (await auth.user_count()) == 0
|
|
||||||
return templates.TemplateResponse(request, "login.html", {
|
|
||||||
"request": request,
|
|
||||||
"needs_setup": needs_setup,
|
|
||||||
"error": error,
|
|
||||||
"next": next if next.startswith("/") else "/",
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def _client_ip(request: Request) -> str:
|
|
||||||
fwd = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip()
|
|
||||||
return fwd or (request.client.host if request.client else "unknown")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login")
|
|
||||||
async def login_submit(request: Request):
|
|
||||||
form = await request.form()
|
|
||||||
username = (form.get("username") or "").strip()
|
|
||||||
password = form.get("password") or ""
|
|
||||||
next_url = form.get("next") or "/"
|
|
||||||
if not next_url.startswith("/"):
|
|
||||||
next_url = "/"
|
|
||||||
ip = _client_ip(request)
|
|
||||||
|
|
||||||
# Atomic register-and-check: increments the counter NOW (before any
|
|
||||||
# await), so a parallel burst of guesses can't all slip past the
|
|
||||||
# threshold. Cleared on successful auth via clear_login_failures.
|
|
||||||
attempt = auth.register_login_attempt(username, ip)
|
|
||||||
if attempt != "ok":
|
|
||||||
if attempt == "now_locked_out":
|
|
||||||
await auth.audit_auth_event(
|
|
||||||
"user_login_locked_out", username,
|
|
||||||
f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min",
|
|
||||||
)
|
|
||||||
locked_until = auth.login_locked_until(username, ip)
|
|
||||||
remaining = int((locked_until or _time.time()) - _time.time())
|
|
||||||
return templates.TemplateResponse(request, "login.html", {
|
|
||||||
"request": request,
|
|
||||||
"needs_setup": False,
|
|
||||||
"error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.",
|
|
||||||
"next": next_url,
|
|
||||||
}, status_code=429)
|
|
||||||
|
|
||||||
found = await auth.get_user_by_username(username)
|
|
||||||
if not found or not auth.verify_password(password, found[1]):
|
|
||||||
# Constant-ish-time: still call verify on a junk hash if user missing
|
|
||||||
# so the timing of "user not found" matches "wrong password."
|
|
||||||
if not found:
|
|
||||||
auth.verify_password(password, "$2b$12$" + "x" * 53)
|
|
||||||
await auth.audit_auth_event(
|
|
||||||
"user_login_failed", username, f"Failed login from {ip}",
|
|
||||||
)
|
|
||||||
return templates.TemplateResponse(request, "login.html", {
|
|
||||||
"request": request,
|
|
||||||
"needs_setup": False,
|
|
||||||
"error": "Invalid username or password.",
|
|
||||||
"next": next_url,
|
|
||||||
}, status_code=401)
|
|
||||||
|
|
||||||
user = found[0]
|
|
||||||
auth.clear_login_failures(username, ip)
|
|
||||||
# Clear any pre-login session keys before populating the new identity.
|
|
||||||
# Closes session-fixation: if an attacker had somehow seeded the
|
|
||||||
# browser with a session cookie, this discards everything in it
|
|
||||||
# before issuing the new authenticated payload.
|
|
||||||
request.session.clear()
|
|
||||||
request.session["user_id"] = user.id
|
|
||||||
request.session["username"] = user.username
|
|
||||||
await auth.touch_last_login(user.id)
|
|
||||||
await auth.audit_auth_event(
|
|
||||||
"user_login", user.username, f"Signed in from {ip}"
|
|
||||||
)
|
|
||||||
return RedirectResponse(url=next_url, status_code=303)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/v1/auth/setup")
|
|
||||||
async def auth_first_user_setup(request: Request):
|
|
||||||
"""Create the first admin from the login page when the users table is
|
|
||||||
empty. Public endpoint — but only does anything when zero users exist."""
|
|
||||||
if (await auth.user_count()) > 0:
|
|
||||||
raise HTTPException(status_code=409, detail="Users already exist.")
|
|
||||||
form = await request.form()
|
|
||||||
username = (form.get("username") or "").strip()
|
|
||||||
password = form.get("password") or ""
|
|
||||||
full_name = (form.get("full_name") or "").strip() or None
|
|
||||||
try:
|
|
||||||
# bootstrap_only=True wraps the existence check + insert in an
|
|
||||||
# IMMEDIATE transaction so two concurrent setup requests can't
|
|
||||||
# both create admin accounts during the bootstrap window.
|
|
||||||
user = await auth.create_user(
|
|
||||||
username, password, full_name, is_admin=True, bootstrap_only=True
|
|
||||||
)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(status_code=400, detail=str(exc))
|
|
||||||
# Same fixation defense as the login flow — discard any pre-existing
|
|
||||||
# session payload before issuing the authenticated identity.
|
|
||||||
request.session.clear()
|
|
||||||
request.session["user_id"] = user.id
|
|
||||||
request.session["username"] = user.username
|
|
||||||
await auth.touch_last_login(user.id)
|
|
||||||
return RedirectResponse(url="/", status_code=303)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/logout")
|
|
||||||
@router.post("/logout")
|
|
||||||
async def logout(request: Request):
|
|
||||||
user = request.state.current_user if hasattr(request.state, "current_user") else None
|
|
||||||
if user:
|
|
||||||
await auth.audit_auth_event(
|
|
||||||
"user_logout", user.username, f"Signed out from {_client_ip(request)}"
|
|
||||||
)
|
|
||||||
request.session.clear()
|
|
||||||
return RedirectResponse(url="/login", status_code=303)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/v1/auth/change-password")
|
|
||||||
async def change_password(request: Request):
|
|
||||||
user = request.state.current_user if hasattr(request.state, "current_user") else None
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
form = await request.form()
|
|
||||||
current = form.get("current_password") or ""
|
|
||||||
new_pw = form.get("new_password") or ""
|
|
||||||
confirm = form.get("confirm_password") or ""
|
|
||||||
if new_pw != confirm:
|
|
||||||
raise HTTPException(status_code=400, detail="New passwords do not match.")
|
|
||||||
try:
|
|
||||||
await auth.change_password(user.id, current, new_pw)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(status_code=400, detail=str(exc))
|
|
||||||
await auth.audit_auth_event(
|
|
||||||
"user_password_changed", user.username,
|
|
||||||
f"Password changed from {_client_ip(request)}",
|
|
||||||
)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -459,83 +325,6 @@ async def sse_drives(request: Request):
|
||||||
# JSON API
|
# JSON API
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/health")
|
|
||||||
async def health(db: aiosqlite.Connection = Depends(get_db)):
|
|
||||||
"""Real readiness check, not just process-is-running.
|
|
||||||
|
|
||||||
Verifies (a) DB writable, (b) poller has succeeded recently relative
|
|
||||||
to the configured stale_threshold_seconds, (c) SSH reachable when
|
|
||||||
configured. Returns 503 when any check fails so a proxy/orchestrator
|
|
||||||
health probe can take the container out of rotation.
|
|
||||||
"""
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from app import ssh_client as _ssh
|
|
||||||
|
|
||||||
checks: dict[str, dict] = {}
|
|
||||||
|
|
||||||
# DB probe — actually exercise the write path (read-only mounts,
|
|
||||||
# full disks, broken WAL all silently pass a journal_mode read).
|
|
||||||
# Uses a temp table that lives only inside the connection so the
|
|
||||||
# round-trip touches the writer without polluting real data.
|
|
||||||
try:
|
|
||||||
await db.execute(
|
|
||||||
"CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)"
|
|
||||||
)
|
|
||||||
await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)",
|
|
||||||
(datetime.now(timezone.utc).isoformat(),))
|
|
||||||
cur = await db.execute("SELECT v FROM _hc WHERE k=1")
|
|
||||||
row = await cur.fetchone()
|
|
||||||
await db.commit()
|
|
||||||
checks["db"] = {"ok": bool(row)}
|
|
||||||
except Exception as exc:
|
|
||||||
checks["db"] = {"ok": False, "error": str(exc)}
|
|
||||||
|
|
||||||
ps = poller.get_state()
|
|
||||||
last = ps.get("last_poll_at")
|
|
||||||
poll_age = None
|
|
||||||
if last:
|
|
||||||
try:
|
|
||||||
t = datetime.fromisoformat(last)
|
|
||||||
if t.tzinfo is None:
|
|
||||||
t = t.replace(tzinfo=timezone.utc)
|
|
||||||
poll_age = (datetime.now(timezone.utc) - t).total_seconds()
|
|
||||||
except Exception:
|
|
||||||
poll_age = None
|
|
||||||
poll_ok = ps.get("healthy") and (
|
|
||||||
poll_age is None or poll_age <= settings.stale_threshold_seconds * 3
|
|
||||||
)
|
|
||||||
checks["poller"] = {
|
|
||||||
"ok": bool(poll_ok),
|
|
||||||
"last_error": ps.get("last_error"),
|
|
||||||
"last_poll_at": last,
|
|
||||||
"age_seconds": int(poll_age) if poll_age is not None else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# SSH probe — only when configured. Cheap (single sensors -j).
|
|
||||||
if _ssh.is_configured():
|
|
||||||
try:
|
|
||||||
r = await _ssh.test_connection()
|
|
||||||
checks["ssh"] = {"ok": bool(r.get("ok")),
|
|
||||||
"error": r.get("error")}
|
|
||||||
except Exception as exc:
|
|
||||||
checks["ssh"] = {"ok": False, "error": str(exc)}
|
|
||||||
else:
|
|
||||||
checks["ssh"] = {"ok": True, "skipped": True}
|
|
||||||
|
|
||||||
cur = await db.execute("SELECT COUNT(*) FROM drives")
|
|
||||||
row = await cur.fetchone()
|
|
||||||
drives_tracked = row[0] if row else 0
|
|
||||||
|
|
||||||
status_ok = all(c["ok"] for c in checks.values())
|
|
||||||
body = {
|
|
||||||
"status": "ok" if status_ok else "degraded",
|
|
||||||
"checks": checks,
|
|
||||||
"drives_tracked": drives_tracked,
|
|
||||||
"poll_interval_s": settings.poll_interval_seconds,
|
|
||||||
"version": settings.app_version,
|
|
||||||
}
|
|
||||||
return JSONResponse(body, status_code=200 if status_ok else 503)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/v1/drives", response_model=list[DriveResponse])
|
@router.get("/api/v1/drives", response_model=list[DriveResponse])
|
||||||
|
|
@ -797,14 +586,7 @@ def _row_to_burnin(row: aiosqlite.Row, stages: list[aiosqlite.Row]) -> BurninJob
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _operator_for(request: Request, _ignored_body_value: str | None = None) -> str:
|
# _operator_for is now imported from ._helpers above.
|
||||||
"""Always return the logged-in user's name for audit attribution.
|
|
||||||
The request body's `operator` field (if any) is ignored — clients
|
|
||||||
can't spoof the operator identity any more."""
|
|
||||||
user = getattr(request.state, "current_user", None)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
return user.full_name or user.username
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/v1/burnin/start")
|
@router.post("/api/v1/burnin/start")
|
||||||
|
|
@ -838,12 +620,24 @@ async def burnin_start(request: Request, req: StartBurninRequest):
|
||||||
@router.post("/api/v1/drives/{drive_id}/unlock")
|
@router.post("/api/v1/drives/{drive_id}/unlock")
|
||||||
async def unlock_pool_drive(drive_id: int, request: Request, req: UnlockPoolDriveRequest):
|
async def unlock_pool_drive(drive_id: int, request: Request, req: UnlockPoolDriveRequest):
|
||||||
operator = _operator_for(request, req.operator)
|
operator = _operator_for(request, req.operator)
|
||||||
|
ip = _client_ip(request)
|
||||||
|
# Rate-limit by drive AND by source IP. A typo on the confirm token
|
||||||
|
# is the common case so the threshold is loose, but a brute-force
|
||||||
|
# attempt to guess the token still hits the IP cap.
|
||||||
|
keys = (("drive", drive_id), ("ip", ip))
|
||||||
|
attempt = auth.unlock_limiter.register(*keys)
|
||||||
|
if attempt != "ok":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="Too many unlock attempts on this drive. Try again later.",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
expiry = await burnin.grant_pool_unlock(
|
expiry = await burnin.grant_pool_unlock(
|
||||||
drive_id, req.confirm_token, operator, req.reason,
|
drive_id, req.confirm_token, operator, req.reason,
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=400, detail=str(exc))
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
auth.unlock_limiter.clear(*keys)
|
||||||
return {"unlocked": True, "expires_at": expiry,
|
return {"unlocked": True, "expires_at": expiry,
|
||||||
# Read from the submodule, not the package-root snapshot
|
# Read from the submodule, not the package-root snapshot
|
||||||
# alias — keeps tests that monkey-patch UNLOCK_TTL_SECONDS
|
# alias — keeps tests that monkey-patch UNLOCK_TTL_SECONDS
|
||||||
|
|
@ -1333,42 +1127,7 @@ async def settings_page(
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# Field names that hold secrets and must never be rendered to the UI or
|
# _SECRET_FIELDS and _secret_status are now imported from ._helpers above.
|
||||||
# included in the redacted-settings dump verbatim.
|
|
||||||
_SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key")
|
|
||||||
|
|
||||||
|
|
||||||
def _secret_status() -> dict[str, str]:
    """Describe, per secret field, whether a value is configured.

    Only the presence of each secret (and, for ``ssh_key``, where it comes
    from) is reported; the secret values themselves are never returned.
    ``ssh_key`` is checked in this order: the ``SSH_KEY`` environment
    variable, the value on the live settings object, then a key file at
    ``SSH_KEY_FILE`` (defaulting to the mounted-secret path).
    """
    import os as _os
    from app.ssh_client import _MOUNTED_KEY_PATH

    def _label(field: str) -> str:
        # Any truthy attribute on the settings object counts as configured.
        return "set" if getattr(settings, field, "") else "unset"

    def _ssh_key_label() -> str:
        if _os.environ.get("SSH_KEY"):
            return "set (environment variable)"
        if getattr(settings, "ssh_key", ""):
            return "set (stored in settings DB — prefer a mounted secret in production)"
        key_file = _os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
        if _os.path.exists(key_file):
            return "set (mounted secret)"
        return "unset"

    return {
        "smtp_password": _label("smtp_password"),
        "ssh_password": _label("ssh_password"),
        "ssh_key": _ssh_key_label(),
        "truenas_api_key": _label("truenas_api_key"),
    }
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/v1/settings/redacted")
|
@router.get("/api/v1/settings/redacted")
|
||||||
|
|
@ -1445,43 +1204,6 @@ async def test_ssh(request: Request):
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/terminal")
async def terminal_ws(websocket: WebSocket):
    """Hand the browser's xterm.js WebSocket over to the SSH PTY bridge.

    All protocol handling lives in ``app.terminal.handle``; this route only
    delegates to it.
    """
    # Imported at call time, as in the rest of this module's lazy imports.
    from app import terminal as ssh_terminal

    await ssh_terminal.handle(websocket)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/v1/updates/check")
async def check_updates():
    """Check Forgejo for a release newer than the running version.

    Returns a dict with the keys ``current``, ``latest``,
    ``update_available`` and ``message``. The endpoint never raises: any
    transport or decoding problem is reported through ``message`` with
    ``update_available`` set to ``False``.

    An update is reported whenever the published tag (minus a single
    leading "v") differs from ``settings.app_version``; no version
    ordering is attempted.
    """
    import httpx

    current = settings.app_version

    def _result(latest, update_available, message):
        # Single place that fixes the response shape for every branch.
        return {
            "current": current,
            "latest": latest,
            "update_available": update_available,
            "message": message,
        }

    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            r = await client.get(
                "https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest",
                headers={"Accept": "application/json"},
            )
            if r.status_code == 404:
                return _result(None, False, "No releases published yet")
            if r.status_code != 200:
                return _result(None, False, f"Forgejo API returned {r.status_code}")

            data = r.json()
            # A release payload without a usable string tag (missing key,
            # JSON null, non-object body) is treated as "no release
            # information" rather than falling into the transport-error
            # branch below.
            tag = data.get("tag_name") if isinstance(data, dict) else None
            latest = tag.removeprefix("v") if isinstance(tag, str) else ""
            up_to_date = not latest or latest == current
            return _result(latest or None, not up_to_date, None)
    except Exception as exc:
        # Deliberate best-effort: the UI shows the message instead of a 500.
        return _result(None, False, f"Could not reach update server: {exc}")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
97
app/routes/_helpers.py
Normal file
97
app/routes/_helpers.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Shared helpers used across multiple route modules.
|
||||||
|
|
||||||
|
Anything more than one route file needs lives here. Single-use helpers
|
||||||
|
stay in their owning route module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def client_ip(request: Request) -> str:
    """Return the caller's IP address on a best-effort basis.

    The first entry of ``X-Forwarded-For`` wins when the header carries a
    non-empty value; otherwise the direct peer address is used, or
    ``"unknown"`` when the request has no client information.

    NOTE(review): the header is taken at face value, so the result is only
    trustworthy when every request arrives through a proxy that sets or
    sanitises it — confirm the deployment guarantees that.
    """
    header_value = request.headers.get("X-Forwarded-For") or ""
    first_hop = header_value.partition(",")[0].strip()
    if first_hop:
        return first_hop
    if request.client:
        return request.client.host
    return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def operator_for(request: Request, _ignored_body_value: str | None = None) -> str:
    """Return the audit name of the user attached to ``request``.

    The name is the user's ``full_name`` when set, otherwise the
    ``username``. The second argument is accepted only so existing call
    sites keep working; its value is never used, which means a client
    cannot choose the operator name recorded for an action.

    Raises:
        HTTPException: 401 when ``request.state`` carries no current user.
    """
    current = getattr(request.state, "current_user", None)
    if current:
        return current.full_name or current.username
    raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
|
||||||
|
|
||||||
|
def is_stale(last_polled_at: str | None) -> bool:
    """Return True when the last poll is older than the stale threshold.

    Args:
        last_polled_at: ISO-8601 timestamp of the most recent poll. A naive
            timestamp is interpreted as UTC.

    Returns:
        True when the timestamp is missing or unparseable, or when its age
        exceeds ``settings.stale_threshold_seconds``; False otherwise.
    """
    # Only the parse is guarded: a bad or absent timestamp means "stale",
    # while a broken settings object should surface as an error instead of
    # silently marking every drive stale.
    try:
        polled = datetime.fromisoformat(last_polled_at)
    except (TypeError, ValueError):
        return True
    if polled.tzinfo is None:
        polled = polled.replace(tzinfo=timezone.utc)
    age_seconds = (datetime.now(timezone.utc) - polled).total_seconds()
    return age_seconds > settings.stale_threshold_seconds
|
||||||
|
|
||||||
|
|
||||||
|
def stale_context(ps: dict) -> dict:
    """Build the ``{stale, stale_seconds}`` values for the warning banner.

    Args:
        ps: Poller state mapping; only its ``last_poll_at`` entry (an
            ISO-8601 timestamp, naive values taken as UTC) is read.

    Returns:
        ``{"stale": bool, "stale_seconds": int}``. When no poll has been
        recorded yet, or the timestamp cannot be parsed, the banner is
        suppressed (``stale`` False, ``stale_seconds`` 0).
    """
    not_stale = {"stale": False, "stale_seconds": 0}
    raw = ps.get("last_poll_at")
    if not raw:
        return not_stale
    # Only the parse is guarded so that a misconfigured threshold raises
    # instead of silently hiding the stale banner forever.
    try:
        polled = datetime.fromisoformat(raw)
    except (TypeError, ValueError):
        return not_stale
    if polled.tzinfo is None:
        polled = polled.replace(tzinfo=timezone.utc)
    age = (datetime.now(timezone.utc) - polled).total_seconds()
    return {
        "stale": age > settings.stale_threshold_seconds,
        "stale_seconds": int(age),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Field names that hold secrets and must never be rendered to the UI
|
||||||
|
# verbatim or included in the redacted-settings dump.
|
||||||
|
SECRET_FIELDS = ("smtp_password", "ssh_password", "ssh_key", "truenas_api_key")
|
||||||
|
|
||||||
|
|
||||||
|
def secret_status() -> dict[str, str]:
    """Describe, per secret field, whether a value is configured.

    Only the presence of each secret (and, for ``ssh_key``, where it comes
    from) is reported; the secret values themselves are never returned.
    ``ssh_key`` is checked in this order: the ``SSH_KEY`` environment
    variable, the value on the live settings object, then a key file at
    ``SSH_KEY_FILE`` (defaulting to the mounted-secret path).
    """
    import os as _os
    from app.ssh_client import _MOUNTED_KEY_PATH

    def _label(field: str) -> str:
        # Any truthy attribute on the settings object counts as configured.
        return "set" if getattr(settings, field, "") else "unset"

    def _ssh_key_label() -> str:
        if _os.environ.get("SSH_KEY"):
            return "set (environment variable)"
        if getattr(settings, "ssh_key", ""):
            return "set (stored in settings DB — prefer a mounted secret in production)"
        key_file = _os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
        if _os.path.exists(key_file):
            return "set (mounted secret)"
        return "unset"

    return {
        "smtp_password": _label("smtp_password"),
        "ssh_password": _label("ssh_password"),
        "ssh_key": _ssh_key_label(),
        "truenas_api_key": _label("truenas_api_key"),
    }
|
||||||
170
app/routes/auth.py
Normal file
170
app/routes/auth.py
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
"""Login / logout / first-user setup / password change routes.
|
||||||
|
|
||||||
|
Public path mounting:
|
||||||
|
GET /login — render login or first-user setup form
|
||||||
|
POST /login — credential check + session bootstrap
|
||||||
|
POST /api/v1/auth/setup — first-user creation (only when zero users)
|
||||||
|
GET /logout — clear session, redirect
|
||||||
|
POST /logout — same, for explicit POST clients
|
||||||
|
POST /api/v1/auth/change-password — rotate password + audit
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
|
||||||
|
from app import auth
|
||||||
|
from app.renderer import templates
|
||||||
|
|
||||||
|
from ._helpers import client_ip
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request, next: str = "/", error: str | None = None):
    """Render the login form, or the first-user setup form when no users exist.

    Args:
        request: Incoming request, passed through to the template.
        next: Path to return to after a successful login. Anything that is
            not a plain same-site path is replaced with ``"/"``.
        error: Optional message shown on the form.
    """
    # A leading "/" alone is not enough: browsers resolve "//host" and
    # "/\host" as absolute URLs, and they drop tab/CR/LF characters before
    # resolving, so those forms are rejected to avoid an open redirect.
    same_site = (
        next.startswith("/")
        and not next.startswith(("//", "/\\"))
        and not any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in next)
    )
    needs_setup = (await auth.user_count()) == 0
    return templates.TemplateResponse(request, "login.html", {
        "request": request,
        "needs_setup": needs_setup,
        "error": error,
        "next": next if same_site else "/",
    })
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_next_url(candidate: object) -> str:
    """Return ``candidate`` when it is a plain same-site path, else ``"/"``.

    Rejects non-strings, values not starting with "/", the "//host" and
    "/\\host" forms that browsers resolve as absolute URLs, and anything
    containing control characters (browsers strip tab/CR/LF before
    resolving, which can turn "/<tab>/host" into "//host").
    """
    if not isinstance(candidate, str) or not candidate.startswith("/"):
        return "/"
    if candidate.startswith(("//", "/\\")):
        return "/"
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        return "/"
    return candidate


@router.post("/login")
async def login_submit(request: Request):
    """Check submitted credentials and start an authenticated session.

    Responses:
        303 redirect to the (validated) ``next`` path on success.
        401 with the login form re-rendered on bad credentials.
        429 with the login form re-rendered while the user/IP is locked out.
    """
    form = await request.form()
    username = (form.get("username") or "").strip()
    password = form.get("password") or ""
    # Validated before use: this value ends up in a redirect Location header.
    next_url = _safe_next_url(form.get("next"))
    ip = client_ip(request)

    # Atomic register-and-check: increments the counter NOW (before any
    # await), so a parallel burst of guesses can't all slip past the
    # threshold. Cleared on successful auth via clear_login_failures.
    attempt = auth.register_login_attempt(username, ip)
    if attempt != "ok":
        if attempt == "now_locked_out":
            await auth.audit_auth_event(
                "user_login_locked_out", username,
                f"Failed login from {ip} — IP/user locked out for {auth.LOGIN_LOCKOUT_SECONDS // 60} min",
            )
        now = _time.time()
        locked_until = auth.login_locked_until(username, ip)
        # Clamp at zero so a lock that expired a moment ago can never
        # produce a zero or negative "try again in N min" figure.
        remaining = max(0, int((locked_until or now) - now))
        return templates.TemplateResponse(request, "login.html", {
            "request": request,
            "needs_setup": False,
            "error": f"Too many failed attempts. Try again in {remaining // 60 + 1} min.",
            "next": next_url,
        }, status_code=429)

    found = await auth.get_user_by_username(username)
    if not found or not auth.verify_password(password, found[1]):
        # Constant-ish-time: still call verify on a junk hash if user missing
        # so the timing of "user not found" matches "wrong password."
        if not found:
            auth.verify_password(password, "$2b$12$" + "x" * 53)
        await auth.audit_auth_event(
            "user_login_failed", username, f"Failed login from {ip}",
        )
        return templates.TemplateResponse(request, "login.html", {
            "request": request,
            "needs_setup": False,
            "error": "Invalid username or password.",
            "next": next_url,
        }, status_code=401)

    user = found[0]
    auth.clear_login_failures(username, ip)
    # Clear any pre-login session keys before populating the new identity.
    # Closes session-fixation: if an attacker had somehow seeded the
    # browser with a session cookie, this discards everything in it
    # before issuing the new authenticated payload.
    request.session.clear()
    request.session["user_id"] = user.id
    request.session["username"] = user.username
    await auth.touch_last_login(user.id)
    await auth.audit_auth_event(
        "user_login", user.username, f"Signed in from {ip}",
    )
    return RedirectResponse(url=next_url, status_code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/v1/auth/setup")
async def auth_first_user_setup(request: Request):
    """Create the initial admin account and sign it in.

    The endpoint is reachable without authentication but refuses to act
    (409) once at least one user exists. A ``ValueError`` from user
    creation is returned to the client as a 400.
    """
    existing_users = await auth.user_count()
    if existing_users > 0:
        raise HTTPException(status_code=409, detail="Users already exist.")

    fields = await request.form()
    username = (fields.get("username") or "").strip()
    password = fields.get("password") or ""
    full_name = (fields.get("full_name") or "").strip() or None

    try:
        # bootstrap_only=True wraps the existence check + insert in an
        # IMMEDIATE transaction so two concurrent setup requests can't
        # both create admin accounts during the bootstrap window.
        admin = await auth.create_user(
            username, password, full_name, is_admin=True, bootstrap_only=True
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    # Same fixation defense as the login flow: drop whatever the session
    # held before the authenticated identity is written into it.
    session = request.session
    session.clear()
    session["user_id"] = admin.id
    session["username"] = admin.username
    await auth.touch_last_login(admin.id)
    return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logout")
@router.post("/logout")
async def logout(request: Request):
    """End the session and send the browser back to the login page.

    An audit event is written only when the request carries a current
    user; the session is cleared and the redirect issued either way.
    """
    current = getattr(request.state, "current_user", None)
    if current:
        source_ip = client_ip(request)
        await auth.audit_auth_event(
            "user_logout", current.username, f"Signed out from {source_ip}",
        )
    request.session.clear()
    return RedirectResponse(url="/login", status_code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/v1/auth/change-password")
async def change_password(request: Request):
    """Change the signed-in user's password.

    Every call is counted against ``auth.pwchange_limiter`` under both the
    lower-cased username and the source IP; the counters are cleared once
    a change succeeds.

    Raises:
        HTTPException: 401 without a current user, 429 when the limiter
            refuses the attempt, 400 when the new passwords differ or
            ``auth.change_password`` raises ``ValueError``.
    """
    current_user = getattr(request.state, "current_user", None)
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")

    source_ip = client_ip(request)
    # The limiter runs before any password hashing so a hijacked session
    # cannot burn CPU guessing the current_password field.
    limiter_keys = (("user", current_user.username.lower()), ("ip", source_ip))
    if auth.pwchange_limiter.register(*limiter_keys) != "ok":
        raise HTTPException(
            status_code=429,
            detail="Too many password-change attempts. Try again later.",
        )

    fields = await request.form()
    current = fields.get("current_password") or ""
    new_pw = fields.get("new_password") or ""
    confirm = fields.get("confirm_password") or ""
    if new_pw != confirm:
        raise HTTPException(status_code=400, detail="New passwords do not match.")

    try:
        await auth.change_password(current_user.id, current, new_pw)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    auth.pwchange_limiter.clear(*limiter_keys)
    await auth.audit_auth_event(
        "user_password_changed", current_user.username,
        f"Password changed from {source_ip}",
    )
    return {"ok": True}
|
||||||
136
app/routes/system.py
Normal file
136
app/routes/system.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""System-level endpoints with no business-logic dependencies.
|
||||||
|
|
||||||
|
GET /health — readiness probe (DB write + poller + SSH)
|
||||||
|
GET /api/v1/updates/check — check Forgejo for newer release
|
||||||
|
WS /ws/terminal — xterm.js bridge to TrueNAS SSH PTY
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from fastapi import APIRouter, Depends, WebSocket
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app import poller
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
async def health(db: aiosqlite.Connection = Depends(get_db)):
    """Readiness check covering the database, the poller and SSH.

    Checks performed:
        db: a write/read round-trip through a connection-local temp table,
            plus the drive-count query.
        poller: the poller reports healthy and, when a last-poll timestamp
            is available, it is no older than three times
            ``settings.stale_threshold_seconds``.
        ssh: ``ssh_client.test_connection()`` when SSH is configured;
            reported as skipped (and ok) otherwise.

    Returns:
        JSONResponse with ``status``, ``checks``, ``drives_tracked``,
        ``poll_interval_s`` and ``version``. The status code is 200 when
        every check passed and 503 otherwise. ``drives_tracked`` is
        ``null`` when the drive count could not be read.
    """
    from app import ssh_client as _ssh

    checks: dict[str, dict] = {}

    # DB probe — actually exercise the write path (read-only mounts,
    # full disks, broken WAL all silently pass a journal_mode read).
    # Uses a temp table that lives only inside the connection so the
    # round-trip touches the writer without polluting real data.
    try:
        await db.execute(
            "CREATE TEMP TABLE IF NOT EXISTS _hc (k INTEGER PRIMARY KEY, v TEXT)"
        )
        await db.execute("INSERT OR REPLACE INTO _hc (k, v) VALUES (1, ?)",
                         (datetime.now(timezone.utc).isoformat(),))
        cur = await db.execute("SELECT v FROM _hc WHERE k=1")
        row = await cur.fetchone()
        await db.commit()
        checks["db"] = {"ok": bool(row)}
    except Exception as exc:
        checks["db"] = {"ok": False, "error": str(exc)}

    ps = poller.get_state()
    last = ps.get("last_poll_at")
    poll_age = None
    if last:
        try:
            t = datetime.fromisoformat(last)
            if t.tzinfo is None:
                t = t.replace(tzinfo=timezone.utc)
            poll_age = (datetime.now(timezone.utc) - t).total_seconds()
        except Exception:
            # An unparseable timestamp is treated the same as "no poll yet".
            poll_age = None
    poll_ok = ps.get("healthy") and (
        poll_age is None or poll_age <= settings.stale_threshold_seconds * 3
    )
    checks["poller"] = {
        "ok": bool(poll_ok),
        "last_error": ps.get("last_error"),
        "last_poll_at": last,
        "age_seconds": int(poll_age) if poll_age is not None else None,
    }

    # SSH probe — only when configured. Cheap (single sensors -j).
    if _ssh.is_configured():
        try:
            r = await _ssh.test_connection()
            checks["ssh"] = {"ok": bool(r.get("ok")),
                             "error": r.get("error")}
        except Exception as exc:
            checks["ssh"] = {"ok": False, "error": str(exc)}
    else:
        checks["ssh"] = {"ok": True, "skipped": True}

    # The drive count is guarded as well: a failing query must degrade the
    # response to a structured 503, not escape as an unhandled 500.
    drives_tracked = None
    try:
        cur = await db.execute("SELECT COUNT(*) FROM drives")
        row = await cur.fetchone()
        drives_tracked = row[0] if row else 0
    except Exception as exc:
        # Keep the first DB error if the write probe already failed.
        if checks["db"]["ok"]:
            checks["db"] = {"ok": False, "error": str(exc)}

    status_ok = all(c["ok"] for c in checks.values())
    body = {
        "status": "ok" if status_ok else "degraded",
        "checks": checks,
        "drives_tracked": drives_tracked,
        "poll_interval_s": settings.poll_interval_seconds,
        "version": settings.app_version,
    }
    return JSONResponse(body, status_code=200 if status_ok else 503)
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws/terminal")
async def terminal_ws(websocket: WebSocket):
    """Hand the browser's xterm.js WebSocket over to the SSH PTY bridge.

    All protocol handling lives in ``app.terminal.handle``; this route only
    delegates to it.
    """
    # Imported at call time, matching the other lazy imports in this module.
    from app import terminal as ssh_terminal

    await ssh_terminal.handle(websocket)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/v1/updates/check")
async def check_updates():
    """Check Forgejo for a release newer than the running version.

    Returns a dict with the keys ``current``, ``latest``,
    ``update_available`` and ``message``. The endpoint never raises: any
    transport or decoding problem is reported through ``message`` with
    ``update_available`` set to ``False``.

    An update is reported whenever the published tag (minus a single
    leading "v") differs from ``settings.app_version``; no version
    ordering is attempted.
    """
    import httpx

    current = settings.app_version

    def _result(latest, update_available, message):
        # Single place that fixes the response shape for every branch.
        return {
            "current": current,
            "latest": latest,
            "update_available": update_available,
            "message": message,
        }

    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            r = await client.get(
                "https://git.hellocomputer.xyz/api/v1/repos/brandon/truenas-burnin/releases/latest",
                headers={"Accept": "application/json"},
            )
            if r.status_code == 404:
                return _result(None, False, "No releases published yet")
            if r.status_code != 200:
                return _result(None, False, f"Forgejo API returned {r.status_code}")

            data = r.json()
            # A release payload without a usable string tag (missing key,
            # JSON null, non-object body) is treated as "no release
            # information" rather than falling into the transport-error
            # branch below.
            tag = data.get("tag_name") if isinstance(data, dict) else None
            latest = tag.removeprefix("v") if isinstance(tag, str) else ""
            up_to_date = not latest or latest == current
            return _result(latest or None, not up_to_date, None)
    except Exception as exc:
        # Deliberate best-effort: the UI shows the message instead of a 500.
        return _result(None, False, f"Could not reach update server: {exc}")
|
||||||
|
|
@ -87,6 +87,20 @@ docker run --rm \
|
||||||
BANDITS=$?
|
BANDITS=$?
|
||||||
echo " exit=$BANDITS ($OUT_DIR/bandit.txt)" | tee -a "$OUT_DIR/summary.txt"
|
echo " exit=$BANDITS ($OUT_DIR/bandit.txt)" | tee -a "$OUT_DIR/summary.txt"
|
||||||
|
|
||||||
|
# --- mypy against the deploy dir (informational only) -------------------
|
||||||
|
# Type checker — surfaces missing-attribute and type-mismatch errors the
# runtime would otherwise hit at the worst possible moment. Optional/None
# checking is currently switched off via --no-strict-optional, so
# None-handling bugs are NOT reported until that flag is dropped. The exit
# code does not count toward the failure exit-code sum until the codebase
# is annotated enough to make findings actionable.
|
||||||
|
echo "--- mypy (informational) ---" | tee -a "$OUT_DIR/summary.txt"
|
||||||
|
docker run --rm \
|
||||||
|
-v "$DEPLOY_DIR/app:/src:ro" \
|
||||||
|
python:3.12-slim sh -c \
|
||||||
|
"pip install --quiet --no-cache-dir --disable-pip-version-check mypy 2>&1 | tail -3 && mypy --ignore-missing-imports --no-strict-optional /src" \
|
||||||
|
> "$OUT_DIR/mypy.txt" 2>&1
|
||||||
|
MYPY=$?
|
||||||
|
echo " exit=$MYPY ($OUT_DIR/mypy.txt) — informational only" | tee -a "$OUT_DIR/summary.txt"
|
||||||
|
|
||||||
# --- gitleaks against the full git history ------------------------------
|
# --- gitleaks against the full git history ------------------------------
|
||||||
echo "--- gitleaks ---" | tee -a "$OUT_DIR/summary.txt"
|
echo "--- gitleaks ---" | tee -a "$OUT_DIR/summary.txt"
|
||||||
docker run --rm \
|
docker run --rm \
|
||||||
|
|
|
||||||
328
tests/test_lifecycle.py
Normal file
328
tests/test_lifecycle.py
Normal file
|
|
@ -0,0 +1,328 @@
|
||||||
|
"""Burn-in lifecycle tests covering the DB helpers in burnin._common,
|
||||||
|
plus the public surface of start_job + cancel_job that doesn't require
|
||||||
|
spinning up _run_job (which would need a mocked TrueNASClient + SSH).
|
||||||
|
|
||||||
|
These are the safety net Codex flagged was missing — the orchestration
|
||||||
|
paths were entirely untested before this. Run inside the container
|
||||||
|
image so app deps (aiosqlite, pydantic-settings, bcrypt) are present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_temp_db() -> str:
    """Create an initialised temporary database seeded with two drives.

    Points ``settings.db_path`` at a fresh temp file, runs ``init_db()``
    against it, inserts drives ``sda`` and ``sdb``, and returns the file
    path. The caller is responsible for unlinking the file in tearDown.
    """
    handle, db_file = tempfile.mkstemp(suffix=".db")
    os.close(handle)

    from app.config import settings
    settings.db_path = db_file

    from app.database import init_db
    await init_db()

    # Two drives so start_job has something to attach to.
    seen_at = "2026-05-03T00:00:00+00:00"
    seed_rows = (
        ("id-1", "sda", "SER1", "TestModel", 1000, 30, "PASSED", seen_at, seen_at),
        ("id-2", "sdb", "SER2", "TestModel", 1000, 30, "PASSED", seen_at, seen_at),
    )
    insert_drive = """
        INSERT INTO drives
            (truenas_disk_id, devname, serial, model, size_bytes,
             temperature_c, smart_health, last_seen_at, last_polled_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    """
    async with aiosqlite.connect(db_file) as conn:
        for drive in seed_rows:
            await conn.execute(insert_drive, drive)
        await conn.commit()
    return db_file
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommonHelpers(unittest.IsolatedAsyncioTestCase):
    """Tests for the per-stage DB mutators in ``app.burnin._common``.

    Each test runs against a fresh temporary database holding one running
    job with three pending stages (precheck, surface_validate,
    final_check). The helpers under test only write to SQLite, so no SSH
    or orchestration is involved.
    """

    async def asyncSetUp(self):
        self.db_path = await _setup_temp_db()
        # Insert a running job + 3 pending stages the tests can mutate.
        async with aiosqlite.connect(self.db_path) as db:
            cur = await db.execute(
                """INSERT INTO burnin_jobs
                   (drive_id, profile, state, percent, operator, created_at)
                   VALUES (?,?,?,?,?,?) RETURNING id""",
                (1, "full", "running", 0, "test", "2026-05-03T00:00:00+00:00"),
            )
            self.job_id = (await cur.fetchone())[0]
            for stage_name in ("precheck", "surface_validate", "final_check"):
                await db.execute(
                    "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)",
                    (self.job_id, stage_name, "pending"),
                )
            await db.commit()

    async def asyncTearDown(self):
        try:
            os.unlink(self.db_path)
        except OSError:
            pass

    async def _fetch_one(self, sql: str, params: tuple):
        """Run ``sql`` on a fresh connection and return the first row.

        Rows come back as ``aiosqlite.Row`` so tests can index them by
        column name or by position.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cur = await db.execute(sql, params)
            return await cur.fetchone()

    async def _stage_row(self, stage_name: str, columns: str):
        """Return ``columns`` of this job's ``stage_name`` stage row."""
        # ``columns`` is always a literal supplied by the tests below.
        return await self._fetch_one(
            f"SELECT {columns} FROM burnin_stages "
            "WHERE burnin_job_id=? AND stage_name=?",
            (self.job_id, stage_name),
        )

    async def test_start_stage_marks_running(self):
        from app.burnin import _common
        await _common._start_stage(self.job_id, "precheck")
        row = await self._stage_row("precheck", "state, started_at")
        self.assertEqual(row["state"], "running")
        self.assertIsNotNone(row["started_at"])

    async def test_finish_stage_success_records_duration(self):
        from app.burnin import _common
        await _common._start_stage(self.job_id, "precheck")
        await _common._finish_stage(self.job_id, "precheck", success=True)
        row = await self._stage_row("precheck", "state, percent, duration_seconds")
        self.assertEqual(row["state"], "passed")
        self.assertEqual(row["percent"], 100)
        # Duration is measured from started_at; in a test it is tiny and
        # may legitimately be 0, so only non-negativity is asserted.
        self.assertIsNotNone(row["duration_seconds"])
        self.assertGreaterEqual(row["duration_seconds"], 0)

    async def test_finish_stage_failure_carries_error_text(self):
        from app.burnin import _common
        await _common._start_stage(self.job_id, "surface_validate")
        await _common._finish_stage(
            self.job_id, "surface_validate",
            success=False, error_text="mock failure",
        )
        row = await self._stage_row("surface_validate", "state, percent, error_text")
        self.assertEqual(row["state"], "failed")
        self.assertIsNone(row["percent"])
        self.assertEqual(row["error_text"], "mock failure")

    async def test_finish_stage_preserves_existing_error(self):
        """When called with error_text=None, the existing column value
        from _set_stage_error must be preserved (not overwritten with NULL).
        This is the bug that 1.0.0-12-ish fixed."""
        from app.burnin import _common
        await _common._start_stage(self.job_id, "surface_validate")
        await _common._set_stage_error(
            self.job_id, "surface_validate", "set by stage",
        )
        await _common._finish_stage(
            self.job_id, "surface_validate", success=False, error_text=None,
        )
        row = await self._stage_row("surface_validate", "error_text")
        self.assertEqual(row[0], "set by stage")

    async def test_recalculate_progress_weights_correctly(self):
        from app.burnin import _common
        # Mark precheck passed, surface_validate at 50% running.
        await _common._start_stage(self.job_id, "precheck")
        await _common._finish_stage(self.job_id, "precheck", success=True)
        await _common._start_stage(self.job_id, "surface_validate")
        await _common._update_stage_percent(self.job_id, "surface_validate", 50)
        await _common._recalculate_progress(self.job_id)
        row = await self._fetch_one(
            "SELECT percent, stage_name FROM burnin_jobs WHERE id=?",
            (self.job_id,),
        )
        # Weights: precheck=5, surface=65, final=5. Total = 75 across these
        # 3 stages. Completed = 5 (precheck) + 32.5 (half of 65) = 37.5.
        # 37.5 / 75 = 50%.
        self.assertEqual(row["percent"], 50)
        self.assertEqual(row["stage_name"], "surface_validate")

    async def test_is_cancelled_reads_job_state(self):
        from app.burnin import _common
        self.assertFalse(await _common._is_cancelled(self.job_id))
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                "UPDATE burnin_jobs SET state='cancelled' WHERE id=?",
                (self.job_id,),
            )
            await db.commit()
        self.assertTrue(await _common._is_cancelled(self.job_id))

    async def test_append_stage_log_concatenates(self):
        from app.burnin import _common
        await _common._append_stage_log(self.job_id, "precheck", "alpha\n")
        await _common._append_stage_log(self.job_id, "precheck", "beta\n")
        row = await self._stage_row("precheck", "log_text")
        self.assertEqual(row[0], "alpha\nbeta\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestStartCancelJob(unittest.IsolatedAsyncioTestCase):
    """start_job + cancel_job touch the burnin orchestrator state.

    start_job spawns _run_job tasks that try to acquire the semaphore;
    they are cancelled right away to avoid running real burn-in stages.
    The real test value here is "did start_job create the right DB rows"
    and "does cancel_job mark them correctly".
    """

    async def asyncSetUp(self):
        """Create a temp DB and reset the burnin module's global state."""
        import asyncio

        self.db_path = await _setup_temp_db()
        # Initialise burnin without a real TrueNASClient — pass None.
        # _run_job will hit the assert at top, but the test cancels
        # before _run_job's first await actually runs.
        from app import burnin

        burnin._unlock_grants.clear()
        burnin._active_tasks.clear()
        burnin._semaphore = asyncio.Semaphore(2)
        burnin._client = None  # unused by start_job itself

    async def asyncTearDown(self):
        """Cancel spawned job tasks, wait for them, then remove the DB."""
        import asyncio

        from app import burnin

        # Snapshot first: the dict may shrink while tasks unwind.
        pending = list(burnin._active_tasks.values())
        for task in pending:
            task.cancel()
        if pending:
            # Task.cancel() only *requests* cancellation; the task is not
            # torn down until the loop runs it again. Wait (bounded, so a
            # task that swallows the cancel cannot hang the suite) before
            # the DB file disappears, so nothing bleeds into later tests.
            # asyncio.wait never raises for task failures/cancellation.
            await asyncio.wait(pending, timeout=5)
        try:
            os.unlink(self.db_path)
        except OSError:
            # Best-effort cleanup: the temp file may already be gone.
            pass

    async def test_start_job_inserts_queued_row_and_stages(self):
        """start_job writes the job row and the profile's stage rows."""
        from app import burnin

        job_id = await burnin.start_job(1, "surface", "test")
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cur = await db.execute(
                "SELECT state, profile, operator FROM burnin_jobs WHERE id=?",
                (job_id,),
            )
            row = await cur.fetchone()
            cur = await db.execute(
                "SELECT stage_name FROM burnin_stages "
                "WHERE burnin_job_id=? ORDER BY id",
                (job_id,),
            )
            stages = [r[0] for r in await cur.fetchall()]
        # State should be queued OR running (the spawned _run_job may
        # have raced into the semaphore by now).
        self.assertIn(row["state"], ("queued", "running"))
        self.assertEqual(row["profile"], "surface")
        self.assertEqual(row["operator"], "test")
        # surface profile = precheck + surface_validate + final_check.
        self.assertEqual(stages, ["precheck", "surface_validate", "final_check"])

    async def test_start_job_rejects_duplicate_active(self):
        """A second start_job on a drive with an active job raises."""
        from app import burnin

        await burnin.start_job(1, "surface", "test")
        # Second start on the same drive should be refused at the
        # ValueError level (caught by the inline duplicate check or by
        # the partial unique index).
        with self.assertRaises(ValueError):
            await burnin.start_job(1, "surface", "test")

    async def test_cancel_job_marks_state(self):
        """cancel_job returns truthy and persists state='cancelled'."""
        from app import burnin

        job_id = await burnin.start_job(1, "surface", "test")
        ok = await burnin.cancel_job(job_id, "test")
        self.assertTrue(ok)
        async with aiosqlite.connect(self.db_path) as db:
            cur = await db.execute(
                "SELECT state FROM burnin_jobs WHERE id=?", (job_id,)
            )
            row = await cur.fetchone()
        self.assertEqual(row[0], "cancelled")

    async def test_cancel_job_returns_false_for_terminal_state(self):
        """cancel_job is falsy for a job already in the 'passed' state."""
        from app import burnin

        # Create a passed job manually
        async with aiosqlite.connect(self.db_path) as db:
            cur = await db.execute(
                """INSERT INTO burnin_jobs
(drive_id, profile, state, operator, created_at)
VALUES (?,?,?,?,?) RETURNING id""",
                (2, "surface", "passed", "x", "2026-05-03T00:00:00+00:00"),
            )
            job_id = (await cur.fetchone())[0]
            await db.commit()
        ok = await burnin.cancel_job(job_id, "test")
        self.assertFalse(ok)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter(unittest.TestCase):
    """Checks for the generic rate-limit class added in 1.0.0-33 for the
    unlock + password-change endpoints."""

    def test_register_allows_under_threshold(self):
        """Two attempts against a threshold of three both report "ok"."""
        from app.auth import _RateLimiter

        limiter = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60)
        alice = ("k", "alice")
        for _ in range(2):
            self.assertEqual(limiter.register(alice), "ok")

    def test_register_trips_at_threshold(self):
        """The attempt that reaches the threshold trips the lockout."""
        from app.auth import _RateLimiter

        limiter = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60)
        alice = ("k", "alice")
        # Attempts 1-2 pass, the 3rd reaches the threshold and counts as
        # the trip, the 4th sees the lockout left by the prior call.
        for expected in ("ok", "ok", "now_locked_out", "locked_out"):
            self.assertEqual(limiter.register(alice), expected)

    def test_clear_removes_counter_and_lockout(self):
        """clear() drops the lockout and lets the key start over."""
        from app.auth import _RateLimiter

        limiter = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60)
        alice = ("k", "alice")
        limiter.register(alice)
        limiter.register(alice)  # second attempt trips the lockout
        self.assertIsNotNone(limiter.locked_until(alice))

        limiter.clear(alice)
        self.assertIsNone(limiter.locked_until(alice))
        # Subsequent register should start fresh.
        self.assertEqual(limiter.register(alice), "ok")

    def test_separate_keys_dont_interfere(self):
        """Locking out one key leaves a different key unaffected."""
        from app.auth import _RateLimiter

        limiter = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60)
        alice = ("k", "alice")
        bob = ("k", "bob")
        limiter.register(alice)
        limiter.register(alice)  # trips alice
        # Bob's attempt should be allowed and untouched by alice's lockout.
        self.assertEqual(limiter.register(bob), "ok")
        self.assertIsNone(limiter.locked_until(bob))
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run every test case in this module when the file
# is executed directly rather than imported by a test runner.
if __name__ == "__main__":
    unittest.main()
|
||||||
Loading…
Add table
Reference in a new issue