fix: live pool re-check before start_job + drop dead run_badblocks (1.0.0-29)
Some checks are pending
Security scan / pip-audit (push) Waiting to run
Security scan / bandit (push) Waiting to run
Security scan / gitleaks (push) Waiting to run

Closes the last open Codex finding (#5) and removes one piece of dead
code Codex flagged in passing.

#5 — Live pool re-check before burn-in start:
  Before this change, _is_unlocked compared the operator's unlock grant
  against the cached drives.pool_* row. If a drive was imported into a
  pool, mounted, or had ZFS labels written between the operator's
  unlock click and the next ~12s poll, burn-in could still start
  against the stale identity and silently destroy the new pool.

  start_job now calls a fresh ssh_client.fresh_pool_check_for_drive()
  immediately after the cached gate. That helper re-runs the three
  detection probes (zpool list -vHP / lsblk zfs_member / findmnt) over
  a fresh SSH session and returns the live answer for one devname.
  If it differs from cached state we invalidate any existing unlock
  grant and raise PoolMemberError with the FRESH pool name so the UI
  reflects current reality. If fresh shows free but cached said locked
  the drive came back to free since last poll — log it and allow.

  Cost: ~200ms per burn-in start. For batch starts of 12 drives, that's
  2.4s extra latency — cheap against destroying a freshly-imported pool.

Dead code removal:
  ssh_client.run_badblocks() — no callers since 1.0.0-13 when the SSH
  badblocks logic was inlined into burnin._stage_surface_validate_ssh
  (with the asyncssh-signal-doesn't-actually-kill workaround). Removing
  the dead function also lets us drop the now-unused
  `from typing import Callable` import.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Brandon Walter 2026-05-02 21:29:11 -04:00
parent 066fbbc403
commit 6c20e57fd8
3 changed files with 69 additions and 79 deletions

View file

@ -389,7 +389,7 @@ async def start_job(drive_id: int, profile: str, operator: str,
# exported drive doesn't carry over if the drive turns out to be
# in an active pool on the next poll.
cur = await db.execute(
"SELECT pool_name, pool_role FROM drives WHERE id=?", (drive_id,)
"SELECT pool_name, pool_role, devname FROM drives WHERE id=?", (drive_id,)
)
drow = await cur.fetchone()
if drow and drow["pool_name"] and not _is_unlocked(
@ -397,6 +397,41 @@ async def start_job(drive_id: int, profile: str, operator: str,
):
raise PoolMemberError(drive_id, drow["pool_name"], drow["pool_role"])
# Closes Codex finding #5: re-check pool state OVER SSH right now,
# not against cached row. Defends against the 12s poll window
# where a drive could have been imported into a pool, mounted, or
# had ZFS labels written between when the operator unlocked it
# and when they clicked Start. Adds ~200ms per start; cheap
# against the cost of destroying a freshly-imported pool.
if drow:
from app import ssh_client as _ssh
if _ssh.is_configured():
fresh = await _ssh.fresh_pool_check_for_drive(drow["devname"])
cached = (
{"pool": drow["pool_name"], "role": drow["pool_role"]}
if drow["pool_name"] else None
)
if fresh != cached:
# State changed since the last poll. Invalidate any
# unlock grant (it was bound to stale identity) and
# refuse with a descriptive error so the operator
# knows to wait for the next poll cycle.
_unlock_grants.pop(drive_id, None)
fresh_pool = fresh["pool"] if fresh else None
fresh_role = fresh["role"] if fresh else None
if fresh_pool:
raise PoolMemberError(drive_id, fresh_pool, fresh_role)
# If the FRESH check shows free but cached said
# locked, the drive was just removed from a pool —
# safe to start, but invalidate any stale grant so
# the operator doesn't reuse old confirmations.
log.warning(
"Live pool check for drive_id=%d (%s): cached=%s "
"fresh=None — drive came free since last poll, "
"allowing burn-in",
drive_id, drow["devname"], cached,
)
# Create job. The partial unique index uniq_active_burnin_per_drive
# (database.py) is the actual race-stopper here: if two concurrent
# /api/v1/burnin/start calls both pass the SELECT-COUNT check above,

View file

@ -83,7 +83,7 @@ class Settings(BaseSettings):
ssh_key: str = "" # PEM private key content (paste full key including headers)
# Application version — used by the /api/v1/updates/check endpoint
app_version: str = "1.0.0-28"
app_version: str = "1.0.0-29"
# ---- Authentication (1.0.0-22) ----
# session_secret: HMAC key for signing session cookies. Empty = generate

View file

@ -13,7 +13,6 @@ The devname from the TrueNAS API is used as-is in /dev/{devname}.
import asyncio
import logging
import re
from typing import Callable
log = logging.getLogger(__name__)
@ -171,82 +170,6 @@ async def abort_smart_test(devname: str) -> None:
await conn.run(cmd, check=False)
async def run_badblocks(
devname: str,
on_progress: Callable[[int, int, str], None],
cancelled_fn: Callable[[], bool] | None = None,
) -> dict:
"""
Run `badblocks -wsv -b 4096 -p 1 /dev/{devname}` and stream output.
on_progress(percent, bad_blocks, line) is called for each line of output.
cancelled_fn() is polled to support mid-test cancellation.
Returns: {"bad_blocks": int, "output": str, "aborted": bool}
"""
from app.config import settings
cmd = f"badblocks -wsv -b 4096 -p 1 /dev/{devname}"
lines: list[str] = []
bad_blocks = 0
aborted = False
last_pct = 0
try:
async with await _connect() as conn:
async with conn.create_process(cmd) as proc:
# badblocks writes progress to stderr, bad block numbers to stdout
async def _read_stream(stream, is_stderr: bool):
nonlocal bad_blocks, last_pct, aborted
async for raw_line in stream:
line = raw_line if isinstance(raw_line, str) else raw_line.decode("utf-8", errors="replace")
lines.append(line)
if is_stderr:
m = re.search(r"([\d.]+)%\s+done", line)
if m:
last_pct = min(99, int(float(m.group(1))))
else:
# Each non-empty stdout line during badblocks is a bad block number
stripped = line.strip()
if stripped and stripped.isdigit():
bad_blocks += 1
on_progress(last_pct, bad_blocks, line)
# Abort if threshold exceeded
if bad_blocks > settings.bad_block_threshold:
aborted = True
proc.kill()
lines.append(
f"\n[ABORTED] Bad block count ({bad_blocks}) exceeded "
f"threshold ({settings.bad_block_threshold})\n"
)
return
# Abort on cancellation
if cancelled_fn and cancelled_fn():
aborted = True
proc.kill()
return
stdout_task = asyncio.create_task(_read_stream(proc.stdout, False))
stderr_task = asyncio.create_task(_read_stream(proc.stderr, True))
await asyncio.gather(stdout_task, stderr_task, return_exceptions=True)
await proc.wait()
except Exception as exc:
lines.append(f"\n[SSH error] {exc}\n")
if not aborted:
last_pct = 100
return {
"bad_blocks": bad_blocks,
"output": "".join(lines),
"aborted": aborted,
}
def _parse_zpool_list_output(stdout: str) -> dict:
"""Pure parser for `zpool list -vHP` stdout. Exposed for unit tests.
@ -428,6 +351,38 @@ def _parse_findmnt_sources(stdout: str) -> set:
return out
async def fresh_pool_check_for_drive(devname: str) -> dict | None:
"""Live, on-demand re-detection of one drive's pool/mounted state.
Re-runs `zpool list -vHP`, `lsblk` (zfs_member), and `findmnt` over a
fresh SSH session and returns whichever entry matches `devname`,
falling back to None if the drive is genuinely free right now.
Closes the poll-window gap between an operator unlock and the next
cached state refresh used as a final gate inside burnin.start_job
so a drive that was imported into a pool after unlock but before the
next poll can't slip through.
Return shape: {"pool": str, "role": str} | None.
Returns None on SSH failure too caller should treat None
skeptically and only act on it if cached state ALSO says None.
"""
if not is_configured() or not devname:
return None
pm = await get_pool_membership()
if pm is None:
return None
if devname in pm:
return pm[devname]
zs = await get_zfs_member_drives()
if zs is not None and devname in zs:
return {"pool": "(exported)", "role": "exported"}
ms = await get_mounted_drives()
if ms is not None and devname in ms:
return {"pool": "(mounted)", "role": "mounted"}
return None
async def get_smart_health_map(devnames: list[str]) -> dict | None:
"""Return {devname: 'PASSED'|'FAILED'|'UNKNOWN'} for every devname.