refactor: extract _common.py + stages.py from burnin (1.0.0-31)

Continues the staged burnin.py module split started in 1.0.0-30.
Two more clean extractions; orchestration (init, _run_job,
start_job, cancel_job, check_stuck_jobs, semaphore) intentionally
stays in __init__.py for now to avoid threading the TrueNASClient
through cross-module setters.

* app/burnin/_common.py — shared helpers with no upward deps:
  STAGE_ORDER + _STAGE_BASE_WEIGHTS + POLL_INTERVAL constants;
  _now / _db connection helper; _is_cancelled, _start_stage,
  _finish_stage, _cancel_stage, _set_stage_error, _update_stage_*,
  _append_stage_log, _store_smart_*, _recalculate_progress; SSE
  _push_update. Imports nothing from sibling burnin modules.

* app/burnin/stages.py — every per-stage implementation moved
  verbatim: _stage_precheck, _stage_smart_test +
  _stage_smart_test_api / _ssh, _stage_surface_validate +
  _surface_validate_nvme / _ssh / _truenas, _stage_timed_simulate,
  _stage_final_check, plus _badblocks_available, _nvme_cli_available,
  and _dispatch_stage. Pulls the shared helpers from _common,
  remote-PID setters from kill, and the live TrueNASClient via a
  lazy `_get_client()` helper that defers `from app import burnin`
  until call time so we don't trip a circular import.

* __init__.py shrank from ~1480 LoC to ~600. Re-exports every
  public name (start_job, cancel_job, init, check_stuck_jobs,
  PoolMemberError, UNLOCK_TTL_SECONDS, etc.) so external callers
  in routes.py / mailer.py / poller.py see the same surface.

State that didn't move: _semaphore, _client, _active_tasks remain
on the package root (with a runtime _client reference from routes.py
preserved). _run_job and start_job still live in __init__.py — full
task.py extraction would require giving stages access to _client
through a setter rather than the lazy lookup, deferred to a future
slice.

Verification: 44/44 unit tests pass in container; /health 200;
container boots clean. No public API change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Brandon Walter 2026-05-03 01:18:04 -04:00
parent 9cbae44495
commit 19c2c0dc0f
4 changed files with 1041 additions and 918 deletions

View file

@ -29,36 +29,17 @@ from app.truenas import TrueNASClient
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Stage definitions
# ---------------------------------------------------------------------------
# Stage configuration + DB helpers extracted to _common.py in 1.0.0-31.
from ._common import ( # noqa: E402
STAGE_ORDER, _STAGE_BASE_WEIGHTS, POLL_INTERVAL,
_now, _db,
_is_cancelled,
_start_stage, _finish_stage, _cancel_stage, _set_stage_error,
_update_stage_percent, _update_stage_bad_blocks, _append_stage_log,
_store_smart_attrs, _store_smart_raw_output,
_recalculate_progress, _push_update,
)
STAGE_ORDER: dict[str, list[str]] = {
# Legacy
"quick": ["precheck", "short_smart", "io_validate", "final_check"],
# Single-stage selectable profiles
"surface": ["precheck", "surface_validate", "final_check"],
"short": ["precheck", "short_smart", "final_check"],
"long": ["precheck", "long_smart", "final_check"],
# Two-stage combos
"surface_short": ["precheck", "surface_validate", "short_smart", "final_check"],
"surface_long": ["precheck", "surface_validate", "long_smart", "final_check"],
"short_long": ["precheck", "short_smart", "long_smart", "final_check"],
# All three
"full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"],
}
# Per-stage base weights used to compute overall job % progress dynamically
_STAGE_BASE_WEIGHTS: dict[str, int] = {
"precheck": 5,
"surface_validate": 65,
"short_smart": 12,
"long_smart": 13,
"io_validate": 10,
"final_check": 5,
}
POLL_INTERVAL = 5.0 # seconds between progress checks during active stages
# ---------------------------------------------------------------------------
# Module-level state (initialized in init())
@ -97,17 +78,7 @@ _is_unlocked = _unlock.is_unlocked # legacy private name
_kill_remote_process = _kill.kill_remote_process
def _now() -> str:
return datetime.now(timezone.utc).isoformat()
@asynccontextmanager
async def _db():
"""Open a WAL-mode connection with busy_timeout so writers wait for the lock
instead of immediately raising 'database is locked' under contention."""
async with aiosqlite.connect(settings.db_path) as db:
await db.execute("PRAGMA busy_timeout=10000")
yield db
# _now() and _db() are re-exported from _common above.
# ---------------------------------------------------------------------------
@ -533,891 +504,31 @@ async def _execute_stages(job_id: int, stages: list[str], devname: str, drive_id
return True
async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool:
if stage_name == "precheck":
return await _stage_precheck(job_id, drive_id)
elif stage_name == "short_smart":
return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id)
elif stage_name == "long_smart":
return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id)
elif stage_name == "surface_validate":
return await _stage_surface_validate(job_id, devname, drive_id)
elif stage_name == "io_validate":
return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds)
elif stage_name == "final_check":
return await _stage_final_check(job_id, devname, drive_id)
return True
# ---------------------------------------------------------------------------
# Individual stage implementations
# ---------------------------------------------------------------------------
async def _stage_precheck(job_id: int, drive_id: int) -> bool:
"""Check SMART health and temperature before starting destructive work."""
async with _db() as db:
cur = await db.execute(
"SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,)
)
row = await cur.fetchone()
if not row:
return False
health, temp = row[0], row[1]
if health == "FAILED":
await _set_stage_error(job_id, "precheck", "Drive SMART health is FAILED — refusing to burn in")
return False
if temp and temp > settings.temp_crit_c:
await _set_stage_error(job_id, "precheck", f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit")
return False
await asyncio.sleep(1) # Simulate brief check
return True
async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str,
drive_id: int | None = None) -> bool:
"""Start a SMART test. Uses SSH if configured, TrueNAS REST API otherwise."""
from app import ssh_client
if ssh_client.is_configured():
return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id)
return await _stage_smart_test_api(job_id, devname, test_type, stage_name)
async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool:
"""TrueNAS REST API path for SMART test (mock / dev mode)."""
tn_job_id = await _client.start_smart_test([devname], test_type)
while True:
if await _is_cancelled(job_id):
try:
await _client.abort_job(tn_job_id)
except Exception:
pass
return False
jobs = await _client.get_smart_jobs()
job = next((j for j in jobs if j["id"] == tn_job_id), None)
if not job:
return False
state = job["state"]
pct = job["progress"]["percent"]
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id, None)
_push_update()
if state == "SUCCESS":
return True
elif state in ("FAILED", "ABORTED"):
await _set_stage_error(job_id, stage_name,
job.get("error") or f"SMART {test_type} test failed")
return False
await asyncio.sleep(POLL_INTERVAL)
async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str,
drive_id: int | None) -> bool:
"""SSH path for SMART test — runs smartctl directly on TrueNAS."""
from app import ssh_client
# Start the test
try:
startup = await ssh_client.start_smart_test(devname, test_type)
await _append_stage_log(job_id, stage_name, startup + "\n")
except Exception as exc:
await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}")
return False
# Brief pause to let the test register in smartctl output
await asyncio.sleep(3)
# Throttle log_text appends — every poll on a multi-hour long_smart bloated
# log_text to 50+ MB and triggered SQLite "database is locked" because each
# COALESCE-then-append rewrites the whole column. Append every ~60s, on the
# first poll, and on any state change.
LOG_EVERY_N_POLLS = 12
poll_count = 0
last_state: str | None = None
# Poll until complete
while True:
if await _is_cancelled(job_id):
try:
await ssh_client.abort_smart_test(devname)
except Exception:
pass
return False
await asyncio.sleep(POLL_INTERVAL)
try:
progress = await ssh_client.poll_smart_progress(devname)
except Exception as exc:
log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id})
await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n")
continue
poll_count += 1
state_changed = progress["state"] != last_state
last_state = progress["state"]
if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed:
await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n")
if progress["state"] == "running":
pct = max(0, 100 - progress["percent_remaining"])
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id)
_push_update()
elif progress["state"] == "passed":
await _update_stage_percent(job_id, stage_name, 100)
# Run attribute check
if drive_id is not None:
try:
attrs = await ssh_client.get_smart_attributes(devname)
await _store_smart_attrs(drive_id, attrs)
await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"])
if attrs["failures"]:
error = "SMART attribute failures: " + "; ".join(attrs["failures"])
await _set_stage_error(job_id, stage_name, error)
return False
if attrs["warnings"]:
await _append_stage_log(
job_id, stage_name,
"[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
)
except Exception as exc:
log.warning("Failed to retrieve SMART attributes: %s", exc)
await _recalculate_progress(job_id)
_push_update()
return True
elif progress["state"] == "failed":
await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed")
return False
# "unknown" → keep polling
async def _badblocks_available() -> bool:
"""Check if badblocks is installed on the remote host (Linux/SCALE only)."""
from app import ssh_client
try:
async with await ssh_client._connect() as conn:
result = await conn.run("which badblocks", check=False)
return result.returncode == 0
except Exception:
return False
async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool:
"""
Surface validation stage auto-routes to the right implementation:
1. NVMe device + SSH + nvme-cli available (TrueNAS SCALE):
`nvme format -s 1 /dev/{devname}` (cryptographic erase).
Far faster than badblocks on NVMe (seconds vs hours) and
exercises the controller's secure-erase path, not just user-LBA
writes.
2. SSH configured + badblocks available (TrueNAS SCALE / Linux):
badblocks -wsv -b N -c N -p N /dev/{devname} directly over SSH.
3. SSH configured + badblocks NOT available (TrueNAS CORE / FreeBSD):
uses TrueNAS REST API disk.wipe FULL job + post-wipe SMART check.
4. No SSH:
simulated timed progress (dev/mock mode).
"""
from app import ssh_client
if ssh_client.is_configured():
if devname.startswith("nvme") and await _nvme_cli_available():
return await _stage_surface_validate_nvme(job_id, devname, drive_id)
if await _badblocks_available():
return await _stage_surface_validate_ssh(job_id, devname, drive_id)
# TrueNAS CORE/FreeBSD: badblocks not available — use native wipe API
await _append_stage_log(
job_id, "surface_validate",
"[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — "
"using TrueNAS disk.wipe API (FULL write pass).\n\n"
)
return await _stage_surface_validate_truenas(job_id, devname, drive_id)
return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds)
async def _nvme_cli_available() -> bool:
"""Check if nvme-cli is installed on the remote host."""
from app import ssh_client
try:
async with await ssh_client._connect() as conn:
r = await conn.run("which nvme", check=False)
return r.returncode == 0
except Exception:
return False
async def _stage_surface_validate_nvme(job_id: int, devname: str,
drive_id: int) -> bool:
"""NVMe destructive surface test via `nvme format -s 1` (crypto erase).
Crypto-erase nukes the data encryption key on the drive's controller,
rendering all stored data unrecoverable in milliseconds; the actual
flash is then implicitly trim-able. This is the canonical destructive
burn-in for NVMe badblocks would write the entire LBA space, which
is slower AND wears the flash unnecessarily.
Post-format we re-read SMART attributes; the drive should report all
counters reset (life used + spare) and PASSED health.
"""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] nvme format -s 1 /dev/{devname}\n"
f"[NOTE] Cryptographic erase — destroys all data on /dev/{devname}.\n\n"
# Per-stage implementations and the dispatch router live in stages.py.
from .stages import ( # noqa: E402
_dispatch_stage,
_badblocks_available,
_nvme_cli_available,
_stage_precheck,
_stage_smart_test,
_stage_smart_test_api,
_stage_smart_test_ssh,
_stage_surface_validate,
_stage_surface_validate_nvme,
_stage_surface_validate_ssh,
_stage_surface_validate_truenas,
_stage_timed_simulate,
_stage_final_check,
)
cmd = f"nvme format -s 1 --force /dev/{devname}"
try:
async with await ssh_client._connect() as conn:
r = await asyncio.wait_for(
conn.run(cmd, check=False), timeout=600
)
except Exception as exc:
await _append_stage_log(
job_id, "surface_validate", f"\n[SSH error] {exc}\n"
)
await _set_stage_error(
job_id, "surface_validate", f"NVMe format SSH error: {exc}"
)
return False
output = (r.stdout or "") + (r.stderr or "")
await _append_stage_log(job_id, "surface_validate", output + "\n")
if r.returncode != 0:
await _set_stage_error(
job_id, "surface_validate",
f"nvme format exited {r.returncode}: {output.strip()[:200]}"
)
return False
# Sanity-check post-format SMART health. Mirrors the surface_validate
# SSH path's check parity — fail on FAILED health, fail on real
# SMART attribute failures, log warnings but don't fail. A transport
# error here is treated as a soft pass (log + continue) so a single
# SSH blip after a successful format doesn't undo the work.
try:
attrs = await ssh_client.get_smart_attributes(devname)
ssh_only_failures = [
f for f in (attrs.get("failures") or []) if f.startswith("SSH error:")
]
real_failures = [
f for f in (attrs.get("failures") or []) if not f.startswith("SSH error:")
]
if attrs.get("health") == "FAILED":
await _set_stage_error(
job_id, "surface_validate",
"NVMe SMART health FAILED after format",
)
return False
if real_failures:
await _set_stage_error(
job_id, "surface_validate",
"NVMe SMART attribute failures after format: "
+ "; ".join(real_failures),
)
return False
if ssh_only_failures:
await _append_stage_log(
job_id, "surface_validate",
"[WARN] post-format SMART check had SSH errors "
"(soft-passing): " + "; ".join(ssh_only_failures) + "\n",
)
if attrs.get("warnings"):
await _append_stage_log(
job_id, "surface_validate",
"[WARN] " + "; ".join(attrs["warnings"]) + "\n",
)
except Exception as exc:
log.warning("Post-format SMART check error on %s: %s", devname, exc)
await _append_stage_log(
job_id, "surface_validate",
f"[WARN] post-format SMART check raised: {exc}\n",
)
await _update_stage_percent(job_id, "surface_validate", 100)
await _recalculate_progress(job_id)
_push_update()
return True
async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool:
"""Run badblocks over SSH, streaming output to stage log."""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] badblocks -wsv -b {settings.surface_validate_block_size} "
f"-c {settings.surface_validate_block_buffer} "
f"-p {settings.surface_validate_passes} /dev/{devname}\n"
f"[NOTE] This is a DESTRUCTIVE write test. "
f"All data on /dev/{devname} will be overwritten.\n\n"
)
def _is_cancelled_sync() -> bool:
# Synchronous version — we check the DB state flag set by cancel_job()
import asyncio
loop = asyncio.get_event_loop()
try:
return loop.run_until_complete(_is_cancelled(job_id))
except Exception:
return False
last_logged_pct = [-1]
def on_progress(pct: int, bad_blocks: int, line: str) -> None:
nonlocal last_logged_pct
# Write to log (fire-and-forget via asyncio.create_task from sync context)
# The log append is done in the async flush below
pass
accumulated_lines: list[str] = []
async def on_progress_async(pct: int, bad_blocks: int, line: str) -> None:
accumulated_lines.append(line)
# Flush to DB and update progress every ~25 lines to avoid excessive DB writes
if len(accumulated_lines) % 25 == 0:
await _append_stage_log(job_id, "surface_validate", "".join(accumulated_lines[-25:]))
await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks)
await _update_stage_percent(job_id, "surface_validate", pct)
await _recalculate_progress(job_id)
_push_update()
if await _is_cancelled(job_id):
raise asyncio.CancelledError
# Run badblocks — we adapt the callback pattern to async by collecting then flushing
result = {"bad_blocks": 0, "output": "", "aborted": False}
try:
# The actual streaming; we handle progress via the accumulated_lines pattern
bad_blocks_total = 0
output_lines: list[str] = []
async with await ssh_client._connect() as conn:
# Wrap in `sh -c 'echo PID:$$; exec ...'` so we get the remote
# PID on the first stdout line. asyncssh's proc.kill() sends an
# SSH signal request that OpenSSH's sshd ignores by default, so
# we need the PID to issue an out-of-band `kill -9` over a fresh
# session when we want to abort.
#
# Block geometry is operator-tunable (Settings → Burn-in):
# -b N block size in bytes (settings.surface_validate_block_size)
# -c N blocks held per IO (settings.surface_validate_block_buffer)
# -p N pass count (settings.surface_validate_passes)
# Defaults preserve original behavior (-b 4096 -c 64 -p 1).
bb_args = (
f"-wsv "
f"-b {settings.surface_validate_block_size} "
f"-c {settings.surface_validate_block_buffer} "
f"-p {settings.surface_validate_passes}"
)
cmd = (
f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'"
)
async with conn.create_process(cmd) as proc:
import re as _re
pid_seen = False
async def _drain(stream, is_stderr: bool):
nonlocal bad_blocks_total, pid_seen
async for raw in stream:
line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace")
# First stdout line is "PID:<n>" from the wrapping shell.
# Capture it and don't append it to the user-visible log.
if not is_stderr and not pid_seen and line.startswith("PID:"):
pid_seen = True
try:
_remote_pids[job_id] = int(line[4:].strip())
log.info(
"Captured remote PID %d for job %d (badblocks)",
_remote_pids[job_id], job_id,
)
except ValueError:
pass
continue
output_lines.append(line)
if is_stderr:
m = _re.search(r"([\d.]+)%\s+done", line)
if m:
pct = min(99, int(float(m.group(1))))
await _update_stage_percent(job_id, "surface_validate", pct)
await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total)
await _recalculate_progress(job_id)
_push_update()
else:
stripped = line.strip()
if stripped and stripped.isdigit():
bad_blocks_total += 1
# Append to DB log in chunks
if len(output_lines) % 20 == 0:
chunk = "".join(output_lines[-20:])
await _append_stage_log(job_id, "surface_validate", chunk)
# Abort on bad block threshold
if bad_blocks_total > settings.bad_block_threshold:
await _kill_remote_process(job_id)
output_lines.append(
f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded "
f"threshold ({settings.bad_block_threshold})\n"
)
return
if await _is_cancelled(job_id):
await _kill_remote_process(job_id)
return
await asyncio.gather(
_drain(proc.stdout, False),
_drain(proc.stderr, True),
return_exceptions=True,
)
# Bound proc.wait so a remote process that ignored our kill
# signal (or that we never managed to kill) can't pin this
# task in the semaphore forever. Closing the connection on
# exit will deliver SIGPIPE to the remote on its next write.
try:
await asyncio.wait_for(proc.wait(), timeout=15)
except asyncio.TimeoutError:
log.warning(
"proc.wait() timed out for job %d — abandoning channel",
job_id,
)
# Flush only lines we haven't already written in 20-line chunks.
# Previously we appended the FULL accumulated output here too,
# doubling the stored log_text size for every surface_validate
# stage and pushing app.db into hundreds of MB.
flushed_count = (len(output_lines) // 20) * 20
tail = "".join(output_lines[flushed_count:])
if tail:
await _append_stage_log(job_id, "surface_validate", tail)
result["bad_blocks"] = bad_blocks_total
result["output"] = "".join(output_lines) # in-memory only, not re-stored
result["aborted"] = bad_blocks_total > settings.bad_block_threshold
except asyncio.CancelledError:
return False
except Exception as exc:
await _append_stage_log(job_id, "surface_validate", f"\n[SSH error] {exc}\n")
await _set_stage_error(job_id, "surface_validate", f"SSH badblocks error: {exc}")
return False
await _update_stage_bad_blocks(job_id, "surface_validate", result["bad_blocks"])
if result["aborted"] or result["bad_blocks"] > settings.bad_block_threshold:
await _set_stage_error(
job_id, "surface_validate",
f"Surface validate FAILED: {result['bad_blocks']} bad block(s) found "
f"(threshold: {settings.bad_block_threshold})"
)
return False
return True
async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool:
"""
Surface validation via TrueNAS CORE disk.wipe REST API.
Used on FreeBSD (TrueNAS CORE) where badblocks is unavailable.
Sends a FULL write-zero pass across the entire disk, polls progress,
then runs a post-wipe SMART attribute check to catch reallocated sectors.
"""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] TrueNAS disk.wipe FULL — {devname}\n"
f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n"
)
# Start the wipe job
try:
tn_job_id = await _client.wipe_disk(devname, "FULL")
except Exception as exc:
await _set_stage_error(job_id, "surface_validate", f"Failed to start disk.wipe: {exc}")
return False
await _append_stage_log(
job_id, "surface_validate",
f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n"
)
# Poll until complete
log_flush_counter = 0
while True:
if await _is_cancelled(job_id):
try:
await _client.abort_job(tn_job_id)
except Exception:
pass
return False
await asyncio.sleep(POLL_INTERVAL)
try:
job = await _client.get_job(tn_job_id)
except Exception as exc:
log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id})
await _append_stage_log(job_id, "surface_validate", f"[poll error] {exc}\n")
continue
if not job:
await _set_stage_error(job_id, "surface_validate", f"Wipe job {tn_job_id} not found")
return False
state = job.get("state", "")
pct = int(job.get("progress", {}).get("percent", 0) or 0)
desc = job.get("progress", {}).get("description", "")
await _update_stage_percent(job_id, "surface_validate", min(pct, 99))
await _recalculate_progress(job_id)
_push_update()
# Log progress description every ~5 polls to avoid DB spam
log_flush_counter += 1
if desc and log_flush_counter % 5 == 0:
await _append_stage_log(job_id, "surface_validate", f"[{pct}%] {desc}\n")
if state == "SUCCESS":
await _update_stage_percent(job_id, "surface_validate", 100)
await _append_stage_log(
job_id, "surface_validate",
f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n"
)
# Post-wipe SMART check — catch any sectors that failed under write stress
if ssh_client.is_configured() and drive_id is not None:
await _append_stage_log(
job_id, "surface_validate",
"[CHECK] Running post-wipe SMART attribute check...\n"
)
try:
attrs = await ssh_client.get_smart_attributes(devname)
await _store_smart_attrs(drive_id, attrs)
if attrs["failures"]:
error = "Post-wipe SMART check: " + "; ".join(attrs["failures"])
await _set_stage_error(job_id, "surface_validate", error)
return False
if attrs["warnings"]:
await _append_stage_log(
job_id, "surface_validate",
"[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
)
await _append_stage_log(
job_id, "surface_validate",
f"[CHECK] SMART health: {attrs['health']} — no critical attributes.\n"
)
except Exception as exc:
log.warning("Post-wipe SMART check failed: %s", exc)
await _append_stage_log(
job_id, "surface_validate",
f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n"
)
return True
elif state in ("FAILED", "ABORTED", "ERROR"):
error_msg = job.get("error") or f"Disk wipe failed (state={state})"
await _set_stage_error(
job_id, "surface_validate",
f"TrueNAS disk.wipe FAILED: {error_msg}"
)
return False
# RUNNING or WAITING — keep polling
async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool:
"""Simulate a timed stage with progress updates (mock / dev mode)."""
start = time.monotonic()
while True:
if await _is_cancelled(job_id):
return False
elapsed = time.monotonic() - start
pct = min(100, int(elapsed / duration_seconds * 100))
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id, None)
_push_update()
if pct >= 100:
return True
await asyncio.sleep(POLL_INTERVAL)
async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool:
"""
Verify drive passed all tests.
SSH mode: run smartctl -a and check critical attributes.
Mock mode: check SMART health field in DB.
A transient SSH connectivity failure here must NOT invalidate a prior
multi-day surface_validate. Retry SSH-only failures, then soft-pass.
"""
await asyncio.sleep(1)
from app import ssh_client
def _ssh_only(failures: list[str]) -> bool:
return bool(failures) and all(f.startswith("SSH error:") for f in failures)
if ssh_client.is_configured() and drive_id is not None:
try:
attrs = await ssh_client.get_smart_attributes(devname)
for attempt in range(2):
if not _ssh_only(attrs.get("failures") or []):
break
log.warning(
"final_check SSH unreachable (attempt %d/3); retrying in 30s",
attempt + 1,
extra={"job_id": job_id, "devname": devname},
)
await asyncio.sleep(30)
attrs = await ssh_client.get_smart_attributes(devname)
failures = attrs.get("failures") or []
if _ssh_only(failures):
log.warning(
"final_check soft-pass: SSH unreachable after retries; prior stages stand",
extra={"job_id": job_id, "devname": devname, "ssh_error": failures},
)
return True
await _store_smart_attrs(drive_id, attrs)
if attrs["health"] == "FAILED" or failures:
msg = failures or [f"SMART health: {attrs['health']}"]
await _set_stage_error(job_id, "final_check",
"Final check failed: " + "; ".join(msg))
return False
return True
except Exception as exc:
log.warning("SSH final_check raised, falling back to DB check: %s", exc)
# DB check (mock mode fallback)
async with _db() as db:
cur = await db.execute(
"SELECT smart_health FROM drives WHERE devname=?", (devname,)
)
row = await cur.fetchone()
if not row or row[0] == "FAILED":
await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in")
return False
return True
# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------
async def _is_cancelled(job_id: int) -> bool:
async with _db() as db:
cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
row = await cur.fetchone()
return bool(row and row[0] == "cancelled")
async def _start_stage(job_id: int, stage_name: str) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?",
(_now(), job_id, stage_name),
)
await db.execute(
"UPDATE burnin_jobs SET stage_name=? WHERE id=?",
(stage_name, job_id),
)
await db.commit()
async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None:
now = _now()
state = "passed" if success else "failed"
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
cur = await db.execute(
"SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?",
(job_id, stage_name),
)
row = await cur.fetchone()
duration = None
if row and row[0]:
try:
start = datetime.fromisoformat(row[0])
if start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
duration = (datetime.now(timezone.utc) - start).total_seconds()
except Exception:
pass
# Only overwrite error_text if one is passed; otherwise preserve what the stage already wrote
if error_text is not None:
await db.execute(
"""UPDATE burnin_stages
SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=?
WHERE burnin_job_id=? AND stage_name=?""",
(state, 100 if success else None, now, duration, error_text, job_id, stage_name),
)
else:
await db.execute(
"""UPDATE burnin_stages
SET state=?, percent=?, finished_at=?, duration_seconds=?
WHERE burnin_job_id=? AND stage_name=?""",
(state, 100 if success else None, now, duration, job_id, stage_name),
)
await db.commit()
async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?",
(pct, job_id, stage_name),
)
await db.commit()
async def _cancel_stage(job_id: int, stage_name: str) -> None:
now = _now()
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?",
(now, job_id, stage_name),
)
await db.commit()
async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None:
"""Append text to the log_text column of a burnin_stages row."""
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"""UPDATE burnin_stages
SET log_text = COALESCE(log_text, '') || ?
WHERE burnin_job_id=? AND stage_name=?""",
(text, job_id, stage_name),
)
await db.commit()
async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?",
(count, job_id, stage_name),
)
await db.commit()
async def _store_smart_attrs(drive_id: int, attrs: dict) -> None:
"""Persist latest SMART attribute dict to drives.smart_attrs (JSON)."""
import json
# Convert int keys to str for JSON serialisation
serialisable = {str(k): v for k, v in attrs.get("attributes", {}).items()}
blob = json.dumps({
"health": attrs.get("health", "UNKNOWN"),
"attrs": serialisable,
"warnings": attrs.get("warnings", []),
"failures": attrs.get("failures", []),
})
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id))
await db.commit()
async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None:
"""Store raw smartctl output in smart_tests.raw_output."""
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?",
(raw, drive_id, test_type.lower()),
)
await db.commit()
async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?",
(error_text, job_id, stage_name),
)
await db.commit()
async def _recalculate_progress(job_id: int, profile: str | None = None) -> None:
"""Recompute overall job % from actual stage rows. profile param is unused (kept for compat)."""
async with _db() as db:
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA journal_mode=WAL")
cur = await db.execute(
"SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
(job_id,),
)
stages = await cur.fetchall()
if not stages:
return
total_weight = sum(_STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) for s in stages)
if total_weight == 0:
return
completed = 0.0
current = None
for s in stages:
w = _STAGE_BASE_WEIGHTS.get(s["stage_name"], 5)
st = s["state"]
if st == "passed":
completed += w
elif st == "running":
completed += w * (s["percent"] or 0) / 100
current = s["stage_name"]
pct = int(completed / total_weight * 100)
await db.execute(
"UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?",
(pct, current, job_id),
)
await db.commit()
# ---------------------------------------------------------------------------
# SSE push
# ---------------------------------------------------------------------------
def _push_update(alert: dict | None = None) -> None:
"""Notify SSE subscribers that data has changed, with optional browser notification payload."""
try:
from app import poller
poller._notify_subscribers(alert=alert)
except Exception:
pass
# DB helpers / progress / SSE re-exported from _common above.
# ---------------------------------------------------------------------------

277
app/burnin/_common.py Normal file
View file

@ -0,0 +1,277 @@
"""Shared helpers for the burnin package.
Lives below stages.py / task.py / __init__.py these all import from
here. _common itself imports nothing from sibling burnin modules so we
stay free of circular-import landmines.
Owns:
* Stage configuration constants (STAGE_ORDER, _STAGE_BASE_WEIGHTS,
POLL_INTERVAL).
* The connection-helper context manager `_db()` and the `_now()` ISO
timestamp helper used everywhere.
* Per-stage DB mutators called by stage implementations and by the
job orchestrator (`_start_stage`, `_finish_stage`, `_cancel_stage`,
`_set_stage_error`, `_update_stage_percent`,
`_update_stage_bad_blocks`, `_append_stage_log`).
* Drive-row mutators for SMART caches
(`_store_smart_attrs`, `_store_smart_raw_output`).
* The job-state read (`_is_cancelled`) + progress aggregator
(`_recalculate_progress`).
* SSE notifier (`_push_update`).
"""
from __future__ import annotations
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
import aiosqlite
from app.config import settings
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Stage configuration
# ---------------------------------------------------------------------------
STAGE_ORDER: dict[str, list[str]] = {
# Legacy
"quick": ["precheck", "short_smart", "io_validate", "final_check"],
# Single-stage selectable profiles
"surface": ["precheck", "surface_validate", "final_check"],
"short": ["precheck", "short_smart", "final_check"],
"long": ["precheck", "long_smart", "final_check"],
# Two-stage combos
"surface_short": ["precheck", "surface_validate", "short_smart", "final_check"],
"surface_long": ["precheck", "surface_validate", "long_smart", "final_check"],
"short_long": ["precheck", "short_smart", "long_smart", "final_check"],
# All three
"full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"],
}
# Per-stage base weights used to compute overall job % progress dynamically
_STAGE_BASE_WEIGHTS: dict[str, int] = {
"precheck": 5,
"surface_validate": 65,
"short_smart": 12,
"long_smart": 13,
"io_validate": 10,
"final_check": 5,
}
POLL_INTERVAL = 5.0 # seconds between progress checks during active stages
# ---------------------------------------------------------------------------
# Connection helpers
# ---------------------------------------------------------------------------
def _now() -> str:
return datetime.now(timezone.utc).isoformat()
@asynccontextmanager
async def _db():
"""Open a WAL-mode connection with busy_timeout so writers wait for the lock
instead of immediately raising 'database is locked' under contention."""
async with aiosqlite.connect(settings.db_path) as db:
await db.execute("PRAGMA busy_timeout=10000")
yield db
# ---------------------------------------------------------------------------
# Job / stage DB mutators
# ---------------------------------------------------------------------------
async def _is_cancelled(job_id: int) -> bool:
async with _db() as db:
cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
row = await cur.fetchone()
return bool(row and row[0] == "cancelled")
async def _start_stage(job_id: int, stage_name: str) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?",
(_now(), job_id, stage_name),
)
await db.execute(
"UPDATE burnin_jobs SET stage_name=? WHERE id=?",
(stage_name, job_id),
)
await db.commit()
async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None:
now = _now()
state = "passed" if success else "failed"
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
cur = await db.execute(
"SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?",
(job_id, stage_name),
)
row = await cur.fetchone()
duration = None
if row and row[0]:
try:
start = datetime.fromisoformat(row[0])
if start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
duration = (datetime.now(timezone.utc) - start).total_seconds()
except Exception:
pass
# Only overwrite error_text if one is passed; otherwise preserve what the stage already wrote
if error_text is not None:
await db.execute(
"""UPDATE burnin_stages
SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=?
WHERE burnin_job_id=? AND stage_name=?""",
(state, 100 if success else None, now, duration, error_text, job_id, stage_name),
)
else:
await db.execute(
"""UPDATE burnin_stages
SET state=?, percent=?, finished_at=?, duration_seconds=?
WHERE burnin_job_id=? AND stage_name=?""",
(state, 100 if success else None, now, duration, job_id, stage_name),
)
await db.commit()
async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?",
(pct, job_id, stage_name),
)
await db.commit()
async def _cancel_stage(job_id: int, stage_name: str) -> None:
now = _now()
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?",
(now, job_id, stage_name),
)
await db.commit()
async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None:
"""Append text to the log_text column of a burnin_stages row."""
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"""UPDATE burnin_stages
SET log_text = COALESCE(log_text, '') || ?
WHERE burnin_job_id=? AND stage_name=?""",
(text, job_id, stage_name),
)
await db.commit()
async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?",
(count, job_id, stage_name),
)
await db.commit()
async def _store_smart_attrs(drive_id: int, attrs: dict) -> None:
"""Persist latest SMART attribute dict to drives.smart_attrs (JSON)."""
# Convert int keys to str for JSON serialisation
serialisable = {str(k): v for k, v in attrs.get("attributes", {}).items()}
blob = json.dumps({
"health": attrs.get("health", "UNKNOWN"),
"attrs": serialisable,
"warnings": attrs.get("warnings", []),
"failures": attrs.get("failures", []),
})
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id))
await db.commit()
async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None:
"""Store raw smartctl output in smart_tests.raw_output."""
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?",
(raw, drive_id, test_type.lower()),
)
await db.commit()
async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None:
async with _db() as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(
"UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?",
(error_text, job_id, stage_name),
)
await db.commit()
async def _recalculate_progress(job_id: int, profile: str | None = None) -> None:
"""Recompute overall job % from actual stage rows. profile param is unused (kept for compat)."""
async with _db() as db:
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA journal_mode=WAL")
cur = await db.execute(
"SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
(job_id,),
)
stages = await cur.fetchall()
if not stages:
return
total_weight = sum(_STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) for s in stages)
if total_weight == 0:
return
completed = 0.0
current = None
for s in stages:
w = _STAGE_BASE_WEIGHTS.get(s["stage_name"], 5)
st = s["state"]
if st == "passed":
completed += w
elif st == "running":
completed += w * (s["percent"] or 0) / 100
current = s["stage_name"]
pct = int(completed / total_weight * 100)
await db.execute(
"UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?",
(pct, current, job_id),
)
await db.commit()
# ---------------------------------------------------------------------------
# SSE notifier
# ---------------------------------------------------------------------------
def _push_update(alert: dict | None = None) -> None:
"""Notify SSE subscribers that data has changed, with optional browser notification payload."""
try:
from app import poller
poller._notify_subscribers(alert=alert)
except Exception:
pass

735
app/burnin/stages.py Normal file
View file

@ -0,0 +1,735 @@
"""Per-stage burn-in implementations.
Each ``_stage_*`` function runs to completion or returns False. They share
state (DB, helpers, configuration) via ``app.burnin._common`` and pull
the live ``TrueNASClient`` instance lazily from the package root so the
extraction stays free of circular imports at module load.
``_dispatch_stage`` is the per-stage_name router used by the orchestrator
in ``app.burnin.__init__._execute_stages``.
"""
from __future__ import annotations
import asyncio
import logging
import time
from app.config import settings
from . import kill
from ._common import (
POLL_INTERVAL,
_append_stage_log,
_db,
_is_cancelled,
_push_update,
_recalculate_progress,
_set_stage_error,
_store_smart_attrs,
_store_smart_raw_output,
_update_stage_bad_blocks,
_update_stage_percent,
)
log = logging.getLogger(__name__)
def _get_client():
"""Lazy access to the TrueNASClient set by ``burnin.init()``. Lives on
the package root for backward compat with routes.py which reaches
for ``burnin._client`` directly."""
from app import burnin
return burnin._client
async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool:
if stage_name == "precheck":
return await _stage_precheck(job_id, drive_id)
elif stage_name == "short_smart":
return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id)
elif stage_name == "long_smart":
return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id)
elif stage_name == "surface_validate":
return await _stage_surface_validate(job_id, devname, drive_id)
elif stage_name == "io_validate":
return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds)
elif stage_name == "final_check":
return await _stage_final_check(job_id, devname, drive_id)
return True
# ---------------------------------------------------------------------------
# Individual stage implementations
# ---------------------------------------------------------------------------
async def _stage_precheck(job_id: int, drive_id: int) -> bool:
"""Check SMART health and temperature before starting destructive work."""
async with _db() as db:
cur = await db.execute(
"SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,)
)
row = await cur.fetchone()
if not row:
return False
health, temp = row[0], row[1]
if health == "FAILED":
await _set_stage_error(job_id, "precheck", "Drive SMART health is FAILED — refusing to burn in")
return False
if temp and temp > settings.temp_crit_c:
await _set_stage_error(job_id, "precheck", f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit")
return False
await asyncio.sleep(1) # Simulate brief check
return True
async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str,
drive_id: int | None = None) -> bool:
"""Start a SMART test. Uses SSH if configured, TrueNAS REST API otherwise."""
from app import ssh_client
if ssh_client.is_configured():
return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id)
return await _stage_smart_test_api(job_id, devname, test_type, stage_name)
async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool:
"""TrueNAS REST API path for SMART test (mock / dev mode)."""
tn_job_id = await _get_client().start_smart_test([devname], test_type)
while True:
if await _is_cancelled(job_id):
try:
await _get_client().abort_job(tn_job_id)
except Exception:
pass
return False
jobs = await _get_client().get_smart_jobs()
job = next((j for j in jobs if j["id"] == tn_job_id), None)
if not job:
return False
state = job["state"]
pct = job["progress"]["percent"]
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id, None)
_push_update()
if state == "SUCCESS":
return True
elif state in ("FAILED", "ABORTED"):
await _set_stage_error(job_id, stage_name,
job.get("error") or f"SMART {test_type} test failed")
return False
await asyncio.sleep(POLL_INTERVAL)
async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str,
drive_id: int | None) -> bool:
"""SSH path for SMART test — runs smartctl directly on TrueNAS."""
from app import ssh_client
# Start the test
try:
startup = await ssh_client.start_smart_test(devname, test_type)
await _append_stage_log(job_id, stage_name, startup + "\n")
except Exception as exc:
await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}")
return False
# Brief pause to let the test register in smartctl output
await asyncio.sleep(3)
# Throttle log_text appends — every poll on a multi-hour long_smart bloated
# log_text to 50+ MB and triggered SQLite "database is locked" because each
# COALESCE-then-append rewrites the whole column. Append every ~60s, on the
# first poll, and on any state change.
LOG_EVERY_N_POLLS = 12
poll_count = 0
last_state: str | None = None
# Poll until complete
while True:
if await _is_cancelled(job_id):
try:
await ssh_client.abort_smart_test(devname)
except Exception:
pass
return False
await asyncio.sleep(POLL_INTERVAL)
try:
progress = await ssh_client.poll_smart_progress(devname)
except Exception as exc:
log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id})
await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n")
continue
poll_count += 1
state_changed = progress["state"] != last_state
last_state = progress["state"]
if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed:
await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n")
if progress["state"] == "running":
pct = max(0, 100 - progress["percent_remaining"])
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id)
_push_update()
elif progress["state"] == "passed":
await _update_stage_percent(job_id, stage_name, 100)
# Run attribute check
if drive_id is not None:
try:
attrs = await ssh_client.get_smart_attributes(devname)
await _store_smart_attrs(drive_id, attrs)
await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"])
if attrs["failures"]:
error = "SMART attribute failures: " + "; ".join(attrs["failures"])
await _set_stage_error(job_id, stage_name, error)
return False
if attrs["warnings"]:
await _append_stage_log(
job_id, stage_name,
"[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
)
except Exception as exc:
log.warning("Failed to retrieve SMART attributes: %s", exc)
await _recalculate_progress(job_id)
_push_update()
return True
elif progress["state"] == "failed":
await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed")
return False
# "unknown" → keep polling
async def _badblocks_available() -> bool:
"""Check if badblocks is installed on the remote host (Linux/SCALE only)."""
from app import ssh_client
try:
async with await ssh_client._connect() as conn:
result = await conn.run("which badblocks", check=False)
return result.returncode == 0
except Exception:
return False
async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool:
"""
Surface validation stage auto-routes to the right implementation:
1. NVMe device + SSH + nvme-cli available (TrueNAS SCALE):
`nvme format -s 1 /dev/{devname}` (cryptographic erase).
Far faster than badblocks on NVMe (seconds vs hours) and
exercises the controller's secure-erase path, not just user-LBA
writes.
2. SSH configured + badblocks available (TrueNAS SCALE / Linux):
badblocks -wsv -b N -c N -p N /dev/{devname} directly over SSH.
3. SSH configured + badblocks NOT available (TrueNAS CORE / FreeBSD):
uses TrueNAS REST API disk.wipe FULL job + post-wipe SMART check.
4. No SSH:
simulated timed progress (dev/mock mode).
"""
from app import ssh_client
if ssh_client.is_configured():
if devname.startswith("nvme") and await _nvme_cli_available():
return await _stage_surface_validate_nvme(job_id, devname, drive_id)
if await _badblocks_available():
return await _stage_surface_validate_ssh(job_id, devname, drive_id)
# TrueNAS CORE/FreeBSD: badblocks not available — use native wipe API
await _append_stage_log(
job_id, "surface_validate",
"[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — "
"using TrueNAS disk.wipe API (FULL write pass).\n\n"
)
return await _stage_surface_validate_truenas(job_id, devname, drive_id)
return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds)
async def _nvme_cli_available() -> bool:
"""Check if nvme-cli is installed on the remote host."""
from app import ssh_client
try:
async with await ssh_client._connect() as conn:
r = await conn.run("which nvme", check=False)
return r.returncode == 0
except Exception:
return False
async def _stage_surface_validate_nvme(job_id: int, devname: str,
drive_id: int) -> bool:
"""NVMe destructive surface test via `nvme format -s 1` (crypto erase).
Crypto-erase nukes the data encryption key on the drive's controller,
rendering all stored data unrecoverable in milliseconds; the actual
flash is then implicitly trim-able. This is the canonical destructive
burn-in for NVMe badblocks would write the entire LBA space, which
is slower AND wears the flash unnecessarily.
Post-format we re-read SMART attributes; the drive should report all
counters reset (life used + spare) and PASSED health.
"""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] nvme format -s 1 /dev/{devname}\n"
f"[NOTE] Cryptographic erase — destroys all data on /dev/{devname}.\n\n"
)
cmd = f"nvme format -s 1 --force /dev/{devname}"
try:
async with await ssh_client._connect() as conn:
r = await asyncio.wait_for(
conn.run(cmd, check=False), timeout=600
)
except Exception as exc:
await _append_stage_log(
job_id, "surface_validate", f"\n[SSH error] {exc}\n"
)
await _set_stage_error(
job_id, "surface_validate", f"NVMe format SSH error: {exc}"
)
return False
output = (r.stdout or "") + (r.stderr or "")
await _append_stage_log(job_id, "surface_validate", output + "\n")
if r.returncode != 0:
await _set_stage_error(
job_id, "surface_validate",
f"nvme format exited {r.returncode}: {output.strip()[:200]}"
)
return False
# Sanity-check post-format SMART health. Mirrors the surface_validate
# SSH path's check parity — fail on FAILED health, fail on real
# SMART attribute failures, log warnings but don't fail. A transport
# error here is treated as a soft pass (log + continue) so a single
# SSH blip after a successful format doesn't undo the work.
try:
attrs = await ssh_client.get_smart_attributes(devname)
ssh_only_failures = [
f for f in (attrs.get("failures") or []) if f.startswith("SSH error:")
]
real_failures = [
f for f in (attrs.get("failures") or []) if not f.startswith("SSH error:")
]
if attrs.get("health") == "FAILED":
await _set_stage_error(
job_id, "surface_validate",
"NVMe SMART health FAILED after format",
)
return False
if real_failures:
await _set_stage_error(
job_id, "surface_validate",
"NVMe SMART attribute failures after format: "
+ "; ".join(real_failures),
)
return False
if ssh_only_failures:
await _append_stage_log(
job_id, "surface_validate",
"[WARN] post-format SMART check had SSH errors "
"(soft-passing): " + "; ".join(ssh_only_failures) + "\n",
)
if attrs.get("warnings"):
await _append_stage_log(
job_id, "surface_validate",
"[WARN] " + "; ".join(attrs["warnings"]) + "\n",
)
except Exception as exc:
log.warning("Post-format SMART check error on %s: %s", devname, exc)
await _append_stage_log(
job_id, "surface_validate",
f"[WARN] post-format SMART check raised: {exc}\n",
)
await _update_stage_percent(job_id, "surface_validate", 100)
await _recalculate_progress(job_id)
_push_update()
return True
async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool:
"""Run badblocks over SSH, streaming output to stage log."""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] badblocks -wsv -b {settings.surface_validate_block_size} "
f"-c {settings.surface_validate_block_buffer} "
f"-p {settings.surface_validate_passes} /dev/{devname}\n"
f"[NOTE] This is a DESTRUCTIVE write test. "
f"All data on /dev/{devname} will be overwritten.\n\n"
)
def _is_cancelled_sync() -> bool:
# Synchronous version — we check the DB state flag set by cancel_job()
import asyncio
loop = asyncio.get_event_loop()
try:
return loop.run_until_complete(_is_cancelled(job_id))
except Exception:
return False
last_logged_pct = [-1]
def on_progress(pct: int, bad_blocks: int, line: str) -> None:
nonlocal last_logged_pct
# Write to log (fire-and-forget via asyncio.create_task from sync context)
# The log append is done in the async flush below
pass
accumulated_lines: list[str] = []
async def on_progress_async(pct: int, bad_blocks: int, line: str) -> None:
accumulated_lines.append(line)
# Flush to DB and update progress every ~25 lines to avoid excessive DB writes
if len(accumulated_lines) % 25 == 0:
await _append_stage_log(job_id, "surface_validate", "".join(accumulated_lines[-25:]))
await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks)
await _update_stage_percent(job_id, "surface_validate", pct)
await _recalculate_progress(job_id)
_push_update()
if await _is_cancelled(job_id):
raise asyncio.CancelledError
# Run badblocks — we adapt the callback pattern to async by collecting then flushing
result = {"bad_blocks": 0, "output": "", "aborted": False}
try:
# The actual streaming; we handle progress via the accumulated_lines pattern
bad_blocks_total = 0
output_lines: list[str] = []
async with await ssh_client._connect() as conn:
# Wrap in `sh -c 'echo PID:$$; exec ...'` so we get the remote
# PID on the first stdout line. asyncssh's proc.kill() sends an
# SSH signal request that OpenSSH's sshd ignores by default, so
# we need the PID to issue an out-of-band `kill -9` over a fresh
# session when we want to abort.
#
# Block geometry is operator-tunable (Settings → Burn-in):
# -b N block size in bytes (settings.surface_validate_block_size)
# -c N blocks held per IO (settings.surface_validate_block_buffer)
# -p N pass count (settings.surface_validate_passes)
# Defaults preserve original behavior (-b 4096 -c 64 -p 1).
bb_args = (
f"-wsv "
f"-b {settings.surface_validate_block_size} "
f"-c {settings.surface_validate_block_buffer} "
f"-p {settings.surface_validate_passes}"
)
cmd = (
f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'"
)
async with conn.create_process(cmd) as proc:
import re as _re
pid_seen = False
async def _drain(stream, is_stderr: bool):
nonlocal bad_blocks_total, pid_seen
async for raw in stream:
line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace")
# First stdout line is "PID:<n>" from the wrapping shell.
# Capture it and don't append it to the user-visible log.
if not is_stderr and not pid_seen and line.startswith("PID:"):
pid_seen = True
try:
kill.set_remote_pid(job_id, int(line[4:].strip()))
log.info(
"Captured remote PID %d for job %d (badblocks)",
kill.get_remote_pid(job_id), job_id,
)
except ValueError:
pass
continue
output_lines.append(line)
if is_stderr:
m = _re.search(r"([\d.]+)%\s+done", line)
if m:
pct = min(99, int(float(m.group(1))))
await _update_stage_percent(job_id, "surface_validate", pct)
await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total)
await _recalculate_progress(job_id)
_push_update()
else:
stripped = line.strip()
if stripped and stripped.isdigit():
bad_blocks_total += 1
# Append to DB log in chunks
if len(output_lines) % 20 == 0:
chunk = "".join(output_lines[-20:])
await _append_stage_log(job_id, "surface_validate", chunk)
# Abort on bad block threshold
if bad_blocks_total > settings.bad_block_threshold:
await kill.kill_remote_process(job_id)
output_lines.append(
f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded "
f"threshold ({settings.bad_block_threshold})\n"
)
return
if await _is_cancelled(job_id):
await kill.kill_remote_process(job_id)
return
await asyncio.gather(
_drain(proc.stdout, False),
_drain(proc.stderr, True),
return_exceptions=True,
)
# Bound proc.wait so a remote process that ignored our kill
# signal (or that we never managed to kill) can't pin this
# task in the semaphore forever. Closing the connection on
# exit will deliver SIGPIPE to the remote on its next write.
try:
await asyncio.wait_for(proc.wait(), timeout=15)
except asyncio.TimeoutError:
log.warning(
"proc.wait() timed out for job %d — abandoning channel",
job_id,
)
# Flush only lines we haven't already written in 20-line chunks.
# Previously we appended the FULL accumulated output here too,
# doubling the stored log_text size for every surface_validate
# stage and pushing app.db into hundreds of MB.
flushed_count = (len(output_lines) // 20) * 20
tail = "".join(output_lines[flushed_count:])
if tail:
await _append_stage_log(job_id, "surface_validate", tail)
result["bad_blocks"] = bad_blocks_total
result["output"] = "".join(output_lines) # in-memory only, not re-stored
result["aborted"] = bad_blocks_total > settings.bad_block_threshold
except asyncio.CancelledError:
return False
except Exception as exc:
await _append_stage_log(job_id, "surface_validate", f"\n[SSH error] {exc}\n")
await _set_stage_error(job_id, "surface_validate", f"SSH badblocks error: {exc}")
return False
await _update_stage_bad_blocks(job_id, "surface_validate", result["bad_blocks"])
if result["aborted"] or result["bad_blocks"] > settings.bad_block_threshold:
await _set_stage_error(
job_id, "surface_validate",
f"Surface validate FAILED: {result['bad_blocks']} bad block(s) found "
f"(threshold: {settings.bad_block_threshold})"
)
return False
return True
async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool:
"""
Surface validation via TrueNAS CORE disk.wipe REST API.
Used on FreeBSD (TrueNAS CORE) where badblocks is unavailable.
Sends a FULL write-zero pass across the entire disk, polls progress,
then runs a post-wipe SMART attribute check to catch reallocated sectors.
"""
from app import ssh_client
await _append_stage_log(
job_id, "surface_validate",
f"[START] TrueNAS disk.wipe FULL — {devname}\n"
f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n"
)
# Start the wipe job
try:
tn_job_id = await _get_client().wipe_disk(devname, "FULL")
except Exception as exc:
await _set_stage_error(job_id, "surface_validate", f"Failed to start disk.wipe: {exc}")
return False
await _append_stage_log(
job_id, "surface_validate",
f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n"
)
# Poll until complete
log_flush_counter = 0
while True:
if await _is_cancelled(job_id):
try:
await _get_client().abort_job(tn_job_id)
except Exception:
pass
return False
await asyncio.sleep(POLL_INTERVAL)
try:
job = await _get_client().get_job(tn_job_id)
except Exception as exc:
log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id})
await _append_stage_log(job_id, "surface_validate", f"[poll error] {exc}\n")
continue
if not job:
await _set_stage_error(job_id, "surface_validate", f"Wipe job {tn_job_id} not found")
return False
state = job.get("state", "")
pct = int(job.get("progress", {}).get("percent", 0) or 0)
desc = job.get("progress", {}).get("description", "")
await _update_stage_percent(job_id, "surface_validate", min(pct, 99))
await _recalculate_progress(job_id)
_push_update()
# Log progress description every ~5 polls to avoid DB spam
log_flush_counter += 1
if desc and log_flush_counter % 5 == 0:
await _append_stage_log(job_id, "surface_validate", f"[{pct}%] {desc}\n")
if state == "SUCCESS":
await _update_stage_percent(job_id, "surface_validate", 100)
await _append_stage_log(
job_id, "surface_validate",
f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n"
)
# Post-wipe SMART check — catch any sectors that failed under write stress
if ssh_client.is_configured() and drive_id is not None:
await _append_stage_log(
job_id, "surface_validate",
"[CHECK] Running post-wipe SMART attribute check...\n"
)
try:
attrs = await ssh_client.get_smart_attributes(devname)
await _store_smart_attrs(drive_id, attrs)
if attrs["failures"]:
error = "Post-wipe SMART check: " + "; ".join(attrs["failures"])
await _set_stage_error(job_id, "surface_validate", error)
return False
if attrs["warnings"]:
await _append_stage_log(
job_id, "surface_validate",
"[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
)
await _append_stage_log(
job_id, "surface_validate",
f"[CHECK] SMART health: {attrs['health']} — no critical attributes.\n"
)
except Exception as exc:
log.warning("Post-wipe SMART check failed: %s", exc)
await _append_stage_log(
job_id, "surface_validate",
f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n"
)
return True
elif state in ("FAILED", "ABORTED", "ERROR"):
error_msg = job.get("error") or f"Disk wipe failed (state={state})"
await _set_stage_error(
job_id, "surface_validate",
f"TrueNAS disk.wipe FAILED: {error_msg}"
)
return False
# RUNNING or WAITING — keep polling
async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool:
"""Simulate a timed stage with progress updates (mock / dev mode)."""
start = time.monotonic()
while True:
if await _is_cancelled(job_id):
return False
elapsed = time.monotonic() - start
pct = min(100, int(elapsed / duration_seconds * 100))
await _update_stage_percent(job_id, stage_name, pct)
await _recalculate_progress(job_id, None)
_push_update()
if pct >= 100:
return True
await asyncio.sleep(POLL_INTERVAL)
async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool:
"""
Verify drive passed all tests.
SSH mode: run smartctl -a and check critical attributes.
Mock mode: check SMART health field in DB.
A transient SSH connectivity failure here must NOT invalidate a prior
multi-day surface_validate. Retry SSH-only failures, then soft-pass.
"""
await asyncio.sleep(1)
from app import ssh_client
def _ssh_only(failures: list[str]) -> bool:
return bool(failures) and all(f.startswith("SSH error:") for f in failures)
if ssh_client.is_configured() and drive_id is not None:
try:
attrs = await ssh_client.get_smart_attributes(devname)
for attempt in range(2):
if not _ssh_only(attrs.get("failures") or []):
break
log.warning(
"final_check SSH unreachable (attempt %d/3); retrying in 30s",
attempt + 1,
extra={"job_id": job_id, "devname": devname},
)
await asyncio.sleep(30)
attrs = await ssh_client.get_smart_attributes(devname)
failures = attrs.get("failures") or []
if _ssh_only(failures):
log.warning(
"final_check soft-pass: SSH unreachable after retries; prior stages stand",
extra={"job_id": job_id, "devname": devname, "ssh_error": failures},
)
return True
await _store_smart_attrs(drive_id, attrs)
if attrs["health"] == "FAILED" or failures:
msg = failures or [f"SMART health: {attrs['health']}"]
await _set_stage_error(job_id, "final_check",
"Final check failed: " + "; ".join(msg))
return False
return True
except Exception as exc:
log.warning("SSH final_check raised, falling back to DB check: %s", exc)
# DB check (mock mode fallback)
async with _db() as db:
cur = await db.execute(
"SELECT smart_health FROM drives WHERE devname=?", (devname,)
)
row = await cur.fetchone()
if not row or row[0] == "FAILED":
await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in")
return False
return True

View file

@ -83,7 +83,7 @@ class Settings(BaseSettings):
ssh_key: str = "" # PEM private key content (paste full key including headers)
# Application version — used by the /api/v1/updates/check endpoint
app_version: str = "1.0.0-30"
app_version: str = "1.0.0-31"
# ---- Authentication (1.0.0-22) ----
# session_secret: HMAC key for signing session cookies. Empty = generate