nas-burnin/app/ssh_client.py
Brandon Walter d4c0770b9e feat: app-level login + hardening sweep (1.0.0-22 -> 1.0.0-23)
Two layered changes shipped in this branch:

== 1.0.0-22: app-level authentication ==

The dashboard previously had only an IP allowlist. Adds username +
bcrypt password auth, signed-cookie sessions, and a "first user setup"
flow.

* New app/auth.py: User dataclass, bcrypt hash/verify, get_user_by_id/
  username, create_user, touch_last_login, FastAPI `get_current_user`
  dependency. Session secret loaded from SESSION_SECRET env or persisted
  to /data/session_secret.
* New app/auth_cli.py: `python -m app.auth_cli list|reset|add` for
  out-of-band user management. Passwords always read from a TTY prompt.
* Schema: idempotent ALTER for `users` table (id, username unique,
  password_hash, full_name, is_admin, created_at, last_login_at).
* main.py: SessionMiddleware (HMAC-signed cookie, max-age 7 days,
  SameSite=strict — see hardening section) + _AuthGateMiddleware that
  populates request.state.current_user and bounces unauth'd HTML GETs
  to /login while returning 401 JSON for everything else.
* Routes: GET /login renders the first-user-setup form when the users
  table is empty, otherwise the sign-in form; POST /login; POST
  /api/v1/auth/setup (only works while the table is empty); GET|POST /logout.
* Bootstrap: env vars INITIAL_ADMIN_USERNAME + INITIAL_ADMIN_PASSWORD
  create the first admin on startup if both set AND users table empty.
  Ignored thereafter — change passwords via UI or CLI.
* Layout: header shows current_user.full_name|username + Logout link.
  Modal operator field auto-fills from the logged-in user via
  <meta name="default-operator"> rendered in layout (replaces the
  localStorage-only previous behaviour).
* requirements.txt: added version constraints bcrypt>=4.0,<5.0,
  itsdangerous>=2.1, python-multipart>=0.0.7. First step toward
  addressing the unpinned-deps gotcha.
* New app/templates/login.html with first-user-setup variant.

== 1.0.0-23: hardening sweep ==

Closes the eight-item gap audit:

* DB retention + automated backup. New app/retention.py runs daily at
  03:00 local. Nulls burnin_stages.log_text on stages older than
  retention_log_days (default 35), VACUUMs to reclaim pages, then runs
  `sqlite3 .backup` to /data/backups/app-YYYY-MM-DD.db keeping the
  retention_backup_keep most recent (default 14). Wired into the
  lifespan supervisor next to mailer/poller.

* CSRF mitigation. SessionMiddleware bumped to SameSite=strict so the
  browser refuses to send the session cookie on cross-site POSTs —
  removes the actual CSRF vector. Trade-off: external links into the
  app require re-auth.

* Login rate limiting. In-memory per-username AND per-source-IP failure
  counters in auth.py. 10 failures within 10 min trips a 15-min lockout
  for both keys. Returns HTTP 429 with a clear "try again in N min"
  message. Cleared on successful login.

* Login audit events. New event types in audit_events: user_login,
  user_login_failed, user_login_locked_out, user_logout,
  user_password_changed. All include source IP. Recorded via
  auth.audit_auth_event().

* Password change UI. Header link "Change password" opens
  templates/components/modal_password.html (current/new/confirm).
  Posts to POST /api/v1/auth/change-password — bcrypt-verifies current,
  requires >=8 char new pw, writes audit event.

* NVMe burn-in path. _stage_surface_validate now detects nvme*
  devnames and routes to _stage_surface_validate_nvme(), which runs
  `nvme format -s 1 --force` (secure-erase setting 1, User Data Erase;
  cryptographic erase would be -s 2). Takes seconds instead of hours of
  badblocks and exercises the controller's secure-erase path. Falls back
  to badblocks if nvme-cli isn't installed. Post-format SMART check.

* Mounted-FS detection. ssh_client.get_mounted_drives() runs
  `findmnt -no SOURCE`, parses non-ZFS sources back to base devnames.
  Poller treats them as pool_name='(mounted)', pool_role='mounted'.
  Confirm token DESTROY MOUNTED FILESYSTEM, distinct purple styling,
  audit event mounted_drive_unlocked, daily-report banner picks it up.

* Deeper /health. Real readiness check — DB write probe (PRAGMA
  journal_mode), poller freshness (age <= 3x stale_threshold), SSH
  test_connection() when configured. Returns 503 when any check fails
  so a proxy/orchestrator can take the container out of rotation.
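The login rate limiting item above can be sketched as a small in-memory class. Its rules are: 10 failures within 10 minutes trips a 15-minute lockout on both the username key and the source-IP key, and a successful login clears both. LoginLimiter and its method names are hypothetical and are not auth.py's actual API. The injectable clock exists only so the sliding-window logic can be tested.

```python
import time
from collections import defaultdict, deque
from typing import Callable

MAX_FAILURES = 10    # failures allowed inside the window
WINDOW_S = 10 * 60   # sliding window: 10 minutes
LOCKOUT_S = 15 * 60  # lockout duration: 15 minutes


class LoginLimiter:
    """In-memory failure counters keyed by username and by source IP (hypothetical API)."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._failures: dict[str, deque[float]] = defaultdict(deque)
        self._locked_until: dict[str, float] = {}

    @staticmethod
    def _keys(username: str, ip: str) -> tuple[str, str]:
        return (f"user:{username}", f"ip:{ip}")

    def lockout_remaining(self, username: str, ip: str) -> float:
        """Seconds until both keys are unlocked (0.0 if neither is locked)."""
        now = self._clock()
        remaining = 0.0
        for key in self._keys(username, ip):
            remaining = max(remaining, self._locked_until.get(key, 0.0) - now)
        return remaining

    def record_failure(self, username: str, ip: str) -> None:
        now = self._clock()
        tripped = False
        for key in self._keys(username, ip):
            window = self._failures[key]
            window.append(now)
            # Drop failures that have slid out of the 10-minute window.
            while window and now - window[0] > WINDOW_S:
                window.popleft()
            if len(window) >= MAX_FAILURES:
                tripped = True
        if tripped:
            # Either counter tripping locks out BOTH keys.
            for key in self._keys(username, ip):
                self._locked_until[key] = now + LOCKOUT_S
                self._failures[key].clear()

    def record_success(self, username: str, ip: str) -> None:
        """A successful login clears counters and lockouts for both keys."""
        for key in self._keys(username, ip):
            self._failures.pop(key, None)
            self._locked_until.pop(key, None)
```

A login handler would call lockout_remaining() before verifying the password. If the result is positive, it would return HTTP 429 with math.ceil(remaining / 60) minutes in the message. Because the state lives in process memory, it resets on restart and is not shared between workers.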

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-02 11:08:29 -04:00

655 lines
24 KiB
Python

"""
SSH client for direct TrueNAS command execution (Stage 7).
When ssh_host is configured, burn-in stages use SSH to run smartctl and
badblocks directly on the TrueNAS host instead of going through the REST API.
Falls back to REST API / simulation when SSH is not configured (dev/mock mode).
TrueNAS CORE (FreeBSD) device paths: /dev/ada0, /dev/da0, etc.
TrueNAS SCALE (Linux) device paths: /dev/sda, /dev/sdb, etc.
The devname from the TrueNAS API is used as-is in /dev/{devname}.
"""
import asyncio
import logging
import re
from typing import Callable
log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Monitored SMART attributes
# True  → any non-zero raw value is a hard failure (drive rejected)
# False → non-zero is a warning (flagged but test continues)
# ---------------------------------------------------------------------------
SMART_ATTRS: dict[int, tuple[str, bool]] = {
    5: ("Reallocated_Sector_Ct", True),     # reallocation = FAIL
    10: ("Spin_Retry_Count", False),        # mechanical stress = WARN
    188: ("Command_Timeout", False),        # drive not responding = WARN
    197: ("Current_Pending_Sector", True),  # pending reallocation = FAIL
    198: ("Offline_Uncorrectable", True),   # unrecoverable read error = FAIL
    199: ("UDMA_CRC_Error_Count", False),   # cable/controller issue = WARN
}


# ---------------------------------------------------------------------------
# Configuration check
# ---------------------------------------------------------------------------
def is_configured() -> bool:
    """Returns True when SSH host + at least one auth method is available."""
    import os
    from app.config import settings

    if not settings.ssh_host:
        return False
    has_creds = bool(
        settings.ssh_key
        or settings.ssh_password
        or os.path.exists(os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH))
    )
    return has_creds


# ---------------------------------------------------------------------------
# Low-level connection
# ---------------------------------------------------------------------------
_MOUNTED_KEY_PATH = "/run/secrets/ssh_key"


async def _connect():
    """Open a single-use SSH connection. Caller must use `async with`."""
    import asyncssh
    from app.config import settings

    kwargs: dict = {
        "host": settings.ssh_host,
        "port": settings.ssh_port,
        "username": settings.ssh_user,
        "known_hosts": None,  # trust all hosts (same spirit as TRUENAS_VERIFY_TLS=false)
    }
    if settings.ssh_key:
        # Key material provided via env var (base case)
        kwargs["client_keys"] = [asyncssh.import_private_key(settings.ssh_key)]
    elif settings.ssh_password:
        kwargs["password"] = settings.ssh_password
    else:
        # Fall back to mounted key file (preferred for production: no key in env vars)
        import os

        key_path = os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
        if os.path.exists(key_path):
            kwargs["client_keys"] = [key_path]
        # If nothing is configured, asyncssh will attempt agent/default key lookup
    return asyncssh.connect(**kwargs)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def test_connection() -> dict:
    """Test SSH connectivity. Returns {"ok": True} or {"ok": False, "error": str}."""
    if not is_configured():
        return {"ok": False, "error": "SSH not configured (ssh_host or credentials missing)"}
    try:
        async with await _connect() as conn:
            result = await conn.run("echo ok", check=False)
            if "ok" in result.stdout:
                return {"ok": True}
            return {"ok": False, "error": result.stderr.strip() or "unexpected output"}
    except Exception as exc:
        return {"ok": False, "error": str(exc)}


async def get_smart_attributes(devname: str) -> dict:
    """
    Run `smartctl -a /dev/{devname}` and parse the output.

    Returns a dict with:
        health: "PASSED" | "FAILED" | "UNKNOWN"
        raw_output: full smartctl output (or the exception text on SSH error)
        attributes: {attr_id: {"name": str, "raw": int}}
        warnings: "<name> = <raw>" for each monitored non-critical attribute
            with a non-zero raw value
        failures: same format for monitored critical attributes, plus
            "SSH error: ..." if the command could not be run
    """
    cmd = f"smartctl -a /dev/{devname}"
    try:
        async with await _connect() as conn:
            result = await conn.run(cmd, check=False)
            output = result.stdout + result.stderr
            return _parse_smartctl(output)
    except Exception as exc:
        return {
            "health": "UNKNOWN",
            "raw_output": str(exc),
            "attributes": {},
            "warnings": [],
            "failures": [f"SSH error: {exc}"],
        }


async def start_smart_test(devname: str, test_type: str) -> str:
    """
    Run `smartctl -t short|long /dev/{devname}`.

    test_type: "SHORT" or "LONG" (any value other than "SHORT" starts a long test).
    Returns the raw output. Raises RuntimeError if the test does not appear
    to have started.
    """
    arg = "short" if test_type.upper() == "SHORT" else "long"
    cmd = f"smartctl -t {arg} /dev/{devname}"
    async with await _connect() as conn:
        result = await conn.run(cmd, check=False)
        output = result.stdout + result.stderr
        # smartctl exits 0 or 4 when the test is successfully started on most drives
        started = (
            "Testing has begun" in output
            or "test has begun" in output.lower()
            or result.returncode in (0, 4)
        )
        if not started:
            raise RuntimeError(f"smartctl returned exit {result.returncode}: {output[:400]}")
        return output


async def poll_smart_progress(devname: str) -> dict:
    """
    Run `smartctl -a /dev/{devname}` and extract self-test status.

    Returns:
        state: "running" | "passed" | "failed" | "unknown"
        percent_remaining: int while a test is running and smartctl reports
            "N% of test remaining"; None otherwise (including a running test
            whose percentage line has not been parsed)
        output: str
    """
    cmd = f"smartctl -a /dev/{devname}"
    async with await _connect() as conn:
        result = await conn.run(cmd, check=False)
        output = result.stdout + result.stderr
    return _parse_smart_progress(output)


async def abort_smart_test(devname: str) -> None:
    """Send `smartctl -X /dev/{devname}` to abort an in-progress test."""
    cmd = f"smartctl -X /dev/{devname}"
    async with await _connect() as conn:
        await conn.run(cmd, check=False)


async def run_badblocks(
    devname: str,
    on_progress: Callable[[int, int, str], None],
    cancelled_fn: Callable[[], bool] | None = None,
) -> dict:
    """
    Run `badblocks -wsv -b 4096 -p 1 /dev/{devname}` and stream output.

    on_progress(percent, bad_blocks, line) is called for each line of output
    (percent is capped at 99 while the run is in progress).
    cancelled_fn() is polled after each line to support mid-test cancellation.
    Returns: {"bad_blocks": int, "output": str, "aborted": bool}
    """
    from app.config import settings

    cmd = f"badblocks -wsv -b 4096 -p 1 /dev/{devname}"
    lines: list[str] = []
    bad_blocks = 0
    aborted = False
    last_pct = 0
    try:
        async with await _connect() as conn:
            async with conn.create_process(cmd) as proc:
                # badblocks writes progress to stderr, bad block numbers to stdout
                async def _read_stream(stream, is_stderr: bool):
                    nonlocal bad_blocks, last_pct, aborted
                    async for raw_line in stream:
                        line = raw_line if isinstance(raw_line, str) else raw_line.decode("utf-8", errors="replace")
                        lines.append(line)
                        if is_stderr:
                            m = re.search(r"([\d.]+)%\s+done", line)
                            if m:
                                last_pct = min(99, int(float(m.group(1))))
                        else:
                            # Each non-empty stdout line during badblocks is a bad block number
                            stripped = line.strip()
                            if stripped and stripped.isdigit():
                                bad_blocks += 1
                        on_progress(last_pct, bad_blocks, line)
                        # Abort if threshold exceeded
                        if bad_blocks > settings.bad_block_threshold:
                            aborted = True
                            proc.kill()
                            lines.append(
                                f"\n[ABORTED] Bad block count ({bad_blocks}) exceeded "
                                f"threshold ({settings.bad_block_threshold})\n"
                            )
                            return
                        # Abort on cancellation
                        if cancelled_fn and cancelled_fn():
                            aborted = True
                            proc.kill()
                            return

                stdout_task = asyncio.create_task(_read_stream(proc.stdout, False))
                stderr_task = asyncio.create_task(_read_stream(proc.stderr, True))
                await asyncio.gather(stdout_task, stderr_task, return_exceptions=True)
                await proc.wait()
    except Exception as exc:
        lines.append(f"\n[SSH error] {exc}\n")
    return {
        "bad_blocks": bad_blocks,
        "output": "".join(lines),
        "aborted": aborted,
    }


def _parse_zpool_list_output(stdout: str) -> dict:
    """Pure parser for `zpool list -vHP` stdout. Exposed for unit tests.

    See get_pool_membership() for output semantics. This function never
    raises: malformed lines are silently skipped.
    """

    def _strip_partition(name: str) -> str:
        m = re.match(r"^(nvme\d+n\d+)", name)
        if m:
            return m.group(1)
        m = re.match(r"^(sd[a-z]+)", name)
        if m:
            return m.group(1)
        return name

    SECTION_MARKERS = {"cache", "log", "logs", "spare", "spares",
                       "special", "dedup"}
    SECTION_NORMALIZE = {"logs": "log", "spares": "spare"}
    out: dict = {}
    current_pool: str | None = None
    current_role: str = "data"
    for raw in stdout.splitlines():
        if not raw.strip():
            continue
        depth = 0
        while depth < len(raw) and raw[depth] == "\t":
            depth += 1
        first = raw[depth:].split("\t", 1)[0].strip()
        if depth == 0:
            current_pool = first
            current_role = "data"
            continue
        if depth == 1:
            if first in SECTION_MARKERS:
                current_role = SECTION_NORMALIZE.get(first, first)
                continue
            if first.startswith(("mirror", "raidz", "draid")):
                continue
            if first.startswith("/dev/") and current_pool:
                dn = _strip_partition(first[len("/dev/"):])
                out[dn] = {"pool": current_pool, "role": current_role}
            continue
        if first.startswith("/dev/") and current_pool:
            dn = _strip_partition(first[len("/dev/"):])
            out[dn] = {"pool": current_pool, "role": current_role}
    return out


def _parse_lsblk_zfs_output(stdout: str) -> set:
    """Pure parser for `lsblk -no NAME,FSTYPE -l` stdout. Returns base
    devnames carrying ZFS labels (whole-disk OR via partition). Exposed
    for unit tests."""
    out: set = set()
    for line in stdout.splitlines():
        parts = line.split()
        if len(parts) < 2:
            continue
        name, fstype = parts[0], parts[1]
        if fstype != "zfs_member":
            continue
        if name.startswith("nvme"):
            m = re.match(r"^(nvme\d+n\d+)", name)
            if m:
                out.add(m.group(1))
        else:
            m = re.match(r"^(sd[a-z]+)", name)
            if m:
                out.add(m.group(1))
    return out


async def get_pool_membership() -> dict | None:
    """Return {devname: {"pool": str, "role": str}} for every drive in any zpool.

    Parses `zpool list -vHP` output. Tab-indent depth tells us structure:
        depth 0  pool name line
        depth 1  vdev type line (mirror-N, raidz*N, draid*) OR section
                 marker (cache/log/logs/spare/special/dedup) OR a single-disk
                 vdev that is itself a /dev/... entry
        depth 2  device line within a vdev: '/dev/sdX', '/dev/nvmeXnY', etc.
                 May carry a partition suffix, which is stripped back to the
                 base devname so it matches what TrueNAS reports.
    Roles: data | cache | log | spare | special | dedup

    Returns:
        - {} when SSH is not configured, or when the SSH call succeeded
          (rc 0, possibly empty stdout) and there are genuinely no pools
        - None on failure (SSH error, non-zero exit). Callers MUST treat
          None differently from {}: an empty dict is "definitely no pool
          members," None is "we couldn't tell." Treating None as "no pool
          members" is a fail-open security regression.
    """
    if not is_configured():
        return {}
    cmd = "zpool list -vHP 2>/dev/null"
    try:
        async with await _connect() as conn:
            r = await conn.run(cmd, check=False)
        if r.returncode != 0:
            return None
    except Exception:
        return None
    if not r.stdout:
        # rc==0 with empty output = host has no pools. (`zpool list -H`
        # returns no rows when zero pools are imported.) That's a real
        # answer, not a failure.
        return {}
    return _parse_zpool_list_output(r.stdout)


async def get_mounted_drives() -> set | None:
    """Return base devnames of every drive whose partitions are mounted
    anywhere right now. Defense-in-depth on top of pool detection: catches
    XFS/ext4/etc. scratch disks the operator forgot about. Returns set()
    when SSH is not configured and None on any failure (the caller treats
    None as 'preserve previous state')."""
    if not is_configured():
        return set()
    cmd = "findmnt -no SOURCE 2>/dev/null"
    try:
        async with await _connect() as conn:
            r = await conn.run(cmd, check=False)
        if r.returncode != 0 or not r.stdout:
            # findmnt always has at least / mounted on a Linux host;
            # empty output is itself suspicious. Treat as failure.
            return None
    except Exception:
        return None
    return _parse_findmnt_sources(r.stdout)


def _parse_findmnt_sources(stdout: str) -> set:
    """Pure parser for findmnt output. Strips partitions. Sources that are
    not /dev/ paths (tmpfs, overlay, ZFS datasets) are ignored, as are
    zvols; ZFS is handled by pool detection."""
    out: set = set()
    for raw in stdout.splitlines():
        s = raw.strip()
        if not s.startswith("/dev/"):
            continue
        # Skip zvols (/dev/zd*, /dev/zvol/*): they belong to pools that are
        # handled separately and shouldn't double-lock as 'mounted'.
        if "/dev/zd" in s or "/dev/zvol" in s:
            continue
        name = s[len("/dev/"):].split("[")[0]  # bind mounts can have [subdir]
        if name.startswith("nvme"):
            m = re.match(r"^(nvme\d+n\d+)", name)
            if m:
                out.add(m.group(1))
        else:
            m = re.match(r"^(sd[a-z]+)", name)
            if m:
                out.add(m.group(1))
    return out


async def get_smart_health_map(devnames: list[str]) -> dict | None:
    """Return {devname: 'PASSED'|'FAILED'|'UNKNOWN'} for every devname.

    Runs `smartctl -H` for each disk in a single SSH session, which is much
    faster than one connection per disk. Returns None on any SSH failure so
    the poller can fall back to the previously-stored health value rather
    than silently overwriting everything as 'UNKNOWN'.

    `smartctl -H` is the cheap SMART self-assessment lookup (no full
    attribute scan) and takes milliseconds per drive. The output format is stable:
        SMART overall-health self-assessment test result: PASSED
        SMART overall-health self-assessment test result: FAILED!
    A drive whose output contains neither verdict (e.g. the command is
    unsupported) is recorded as UNKNOWN individually. Devnames that fail the
    shell-safety check below are omitted from the result.
    """
    if not is_configured() or not devnames:
        return {} if devnames else None
    # Build one shell command list that prefixes each result with "@@DEVNAME@@"
    # so we can split the combined stdout deterministically.
    parts = []
    for d in devnames:
        # Reject anything that doesn't look like a basic devname so we
        # never inject shell metacharacters into the remote command.
        if not d.replace("nvme", "").replace("n", "").replace("p", "").replace("sd", "").isalnum():
            continue
        parts.append(f"echo '@@{d}@@'; smartctl -H /dev/{d} 2>&1; echo '@@END@@'")
    if not parts:
        return {}
    cmd = "; ".join(parts)
    try:
        async with await _connect() as conn:
            r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30)
    except Exception:
        return None
    if not r.stdout:
        return None
    return _parse_smart_health_batch(r.stdout)


def _parse_smart_health_batch(stdout: str) -> dict:
    """Pure parser for the batched smartctl -H output. Exposed for tests."""
    result: dict[str, str] = {}
    current: str | None = None
    buf: list[str] = []

    def _flush():
        if current is None:
            return
        text = "\n".join(buf)
        if "PASSED" in text:
            result[current] = "PASSED"
        elif "FAILED" in text or "FAILURE" in text:
            result[current] = "FAILED"
        else:
            result[current] = "UNKNOWN"

    for raw in stdout.splitlines():
        line = raw.strip()
        if line.startswith("@@") and line.endswith("@@"):
            inner = line[2:-2]
            if inner == "END":
                _flush()
                current = None
                buf = []
            else:
                _flush()
                current = inner
                buf = []
        else:
            buf.append(line)
    _flush()
    return result


async def get_zfs_member_drives() -> set | None:
    """Return devnames of every drive whose partitions carry a ZFS label.

    Combined with get_pool_membership(): a drive in this set but NOT in the
    active-pool map carries ZFS data from a previously-imported pool that
    was exported (or imported on a different system). We treat those as
    locked too; wiping them would silently destroy a pool.

    Returns:
        - set() when SSH is not configured, or when lsblk succeeded and no
          drives carry ZFS labels
        - None on any failure. Same fail-closed semantics as
          get_pool_membership(): callers must NOT treat None as
          "no exported drives," that's a security regression.
    """
    if not is_configured():
        return set()
    cmd = "lsblk -no NAME,FSTYPE -l 2>/dev/null"
    try:
        async with await _connect() as conn:
            r = await conn.run(cmd, check=False)
        if r.returncode != 0:
            return None
    except Exception:
        return None
    if not r.stdout:
        # lsblk with rc==0 and no output is impossible on a normal Linux
        # host; treat as failure rather than "no drives at all."
        return None
    return _parse_lsblk_zfs_output(r.stdout)


async def get_system_sensors() -> dict:
    """
    Run `sensors -j` on TrueNAS and extract system-level temperatures.

    Returns {"cpu_c": int|None, "pch_c": int|None}:
        cpu_c = CPU package temp (coretemp chip)
        pch_c = PCH/chipset temp (pch_* chip), a proxy for storage I/O lane thermals
    Returns {} if SSH is not configured, lm-sensors is unavailable, or its
    JSON output cannot be parsed.
    """
    if not is_configured():
        return {}
    try:
        async with await _connect() as conn:
            result = await conn.run("sensors -j 2>/dev/null", check=False)
            output = result.stdout.strip()
            if not output:
                return {}
            return _parse_sensors_json(output)
    except Exception as exc:
        log.debug("get_system_sensors failed: %s", exc)
        return {}


def _parse_sensors_json(output: str) -> dict:
    """Pure parser for `sensors -j` output. Returns {} on invalid or non-object JSON."""
    import json

    try:
        data = json.loads(output)
    except Exception:
        return {}
    if not isinstance(data, dict):
        return {}
    cpu_c: int | None = None
    pch_c: int | None = None
    for chip_name, chip_data in data.items():
        if not isinstance(chip_data, dict):
            continue
        # CPU package temp: coretemp chip, "Package id N" sensor
        if chip_name.startswith("coretemp") and cpu_c is None:
            for sensor_name, sensor_vals in chip_data.items():
                if not isinstance(sensor_vals, dict):
                    continue
                if "package" in sensor_name.lower():
                    for k, v in sensor_vals.items():
                        if k.endswith("_input") and isinstance(v, (int, float)):
                            cpu_c = int(round(v))
                            break
                if cpu_c is not None:
                    break
        # PCH / chipset temp: manages PCIe lanes including HBA / storage I/O
        elif chip_name.startswith("pch_") and pch_c is None:
            for sensor_name, sensor_vals in chip_data.items():
                if not isinstance(sensor_vals, dict):
                    continue
                for k, v in sensor_vals.items():
                    if k.endswith("_input") and isinstance(v, (int, float)):
                        pch_c = int(round(v))
                        break
                if pch_c is not None:
                    break
    return {"cpu_c": cpu_c, "pch_c": pch_c}


# ---------------------------------------------------------------------------
# Parsers
# ---------------------------------------------------------------------------
def _parse_smartctl(output: str) -> dict:
    health = "UNKNOWN"
    attributes: dict[int, dict] = {}
    warnings: list[str] = []
    failures: list[str] = []
    m = re.search(r"self-assessment test result:\s+(\w+)", output, re.IGNORECASE)
    if m:
        health = m.group(1).upper()
    # Attribute table: ID# NAME FLAG VALUE WORST THRESH TYPE UPDATED WHEN_FAILED RAW_VALUE
    for line in output.splitlines():
        am = re.match(
            r"\s*(\d+)\s+(\S+)\s+\S+\s+\d+\s+\d+\s+\d+\s+\S+\s+\S+\s+\S+\s+(\d+)",
            line,
        )
        if not am:
            continue
        attr_id = int(am.group(1))
        attr_name = am.group(2)
        raw_val = int(am.group(3))
        attributes[attr_id] = {"name": attr_name, "raw": raw_val}
        if attr_id in SMART_ATTRS:
            _, is_critical = SMART_ATTRS[attr_id]
            if raw_val > 0:
                msg = f"{attr_name} = {raw_val}"
                if is_critical:
                    failures.append(msg)
                else:
                    warnings.append(msg)
    return {
        "health": health,
        "raw_output": output,
        "attributes": attributes,
        "warnings": warnings,
        "failures": failures,
    }


def _parse_smart_progress(output: str) -> dict:
    state = "unknown"
    percent_remaining = None  # None = no "% of test remaining" line parsed (or not running)
    lower = output.lower()
    if "self-test routine in progress" in lower:
        state = "running"
        m = re.search(r"(\d+)%\s+of\s+test\s+remaining", output, re.IGNORECASE)
        if m:
            percent_remaining = int(m.group(1))
    elif "completed without error" in lower:
        state = "passed"
    elif (
        "completed: read failure" in lower
        or "completed: write failure" in lower
        or "aborted by host" in lower
        or ("completed" in lower and "failure" in lower)
    ):
        state = "failed"
    elif "in progress" in lower:
        state = "running"
    return {
        "state": state,
        "percent_remaining": percent_remaining,
        "output": output,
    }