nas-burnin/app/ssh_client.py
Brandon Walter 6c20e57fd8
Some checks are pending
Security scan / pip-audit (push) Waiting to run
Security scan / bandit (push) Waiting to run
Security scan / gitleaks (push) Waiting to run
fix: live pool re-check before start_job + drop dead run_badblocks (1.0.0-29)
Closes the last open Codex finding (#5) and removes one piece of dead
code Codex flagged in passing.

#5 — Live pool re-check before burn-in start:
  Before this change, _is_unlocked compared the operator's unlock grant
  against the cached drives.pool_* row. If a drive was imported into a
  pool, mounted, or had ZFS labels written between the operator's
  unlock click and the next ~12s poll, burn-in could still start
  against the stale identity and silently destroy the new pool.

  start_job now calls the new ssh_client.fresh_pool_check_for_drive()
  immediately after the cached gate. That helper re-runs the three
  detection probes (zpool list -vHP / lsblk zfs_member / findmnt), each
  over a fresh SSH session, and returns the live answer for one devname.
  If it differs from the cached state, we invalidate any existing unlock
  grant and raise PoolMemberError with the FRESH pool name so the UI
  reflects current reality. If the fresh check shows the drive free but
  the cached state said locked, the drive became free after the last
  poll — log it and allow.

  Cost: ~200ms per burn-in start. For batch starts of 12 drives, that's
  2.4s extra latency — cheap against destroying a freshly-imported pool.

Dead code removal:
  ssh_client.run_badblocks() — no callers since 1.0.0-13 when the SSH
  badblocks logic was inlined into burnin._stage_surface_validate_ssh
  (with the asyncssh-signal-doesn't-actually-kill workaround). Removing
  the dead function also lets us drop the now-unused
  `from typing import Callable` import.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-02 21:29:11 -04:00

627 lines
23 KiB
Python

"""
SSH client for direct TrueNAS command execution (Stage 7).
When ssh_host is configured, burn-in stages use SSH to run smartctl and
badblocks directly on the TrueNAS host instead of going through the REST API.
Falls back to REST API / simulation when SSH is not configured (dev/mock mode).
TrueNAS CORE (FreeBSD) device paths: /dev/ada0, /dev/da0, etc.
TrueNAS SCALE (Linux) device paths: /dev/sda, /dev/sdb, etc.
The devname from the TrueNAS API is used as-is in /dev/{devname}.
"""
import asyncio
import logging
import re
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Monitored SMART attributes
#   key   = SMART attribute ID
#   value = (attribute name, is_critical)
#   is_critical True  -> any non-zero raw value is a hard FAIL (drive rejected)
#   is_critical False -> a non-zero raw value is a WARN (flagged, test continues)
# ---------------------------------------------------------------------------
SMART_ATTRS: dict[int, tuple[str, bool]] = {
    5: ("Reallocated_Sector_Ct", True),     # FAIL: sectors have been reallocated
    10: ("Spin_Retry_Count", False),        # WARN: mechanical stress
    188: ("Command_Timeout", False),        # WARN: drive not responding
    197: ("Current_Pending_Sector", True),  # FAIL: reallocation pending
    198: ("Offline_Uncorrectable", True),   # FAIL: unrecoverable read error
    199: ("UDMA_CRC_Error_Count", False),   # WARN: cable/controller issue
}
# ---------------------------------------------------------------------------
# Configuration check
# ---------------------------------------------------------------------------
def is_configured() -> bool:
    """Report whether SSH command execution is usable.

    True only when ``settings.ssh_host`` is set AND at least one credential
    source is available, checked in this order:
      * ``settings.ssh_key``       (inline private key material)
      * ``settings.ssh_password``
      * a key file existing at ``$SSH_KEY_FILE`` (default ``_MOUNTED_KEY_PATH``)
    """
    import os
    from app.config import settings

    if not settings.ssh_host:
        return False
    if settings.ssh_key or settings.ssh_password:
        return True
    key_file = os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
    return os.path.exists(key_file)
# ---------------------------------------------------------------------------
# Low-level connection
# ---------------------------------------------------------------------------
_MOUNTED_KEY_PATH = "/run/secrets/ssh_key"


async def _connect():
    """Return a single-use ``asyncssh.connect(...)`` object for the NAS.

    Callers use it as ``async with await _connect() as conn:``.

    Credential precedence:
      1. ``settings.ssh_key``      - key material from the environment
      2. ``settings.ssh_password`` - password authentication
      3. key file at ``$SSH_KEY_FILE`` (default ``_MOUNTED_KEY_PATH``), if it
         exists (preferred for production: no key in env vars)
    When none of these apply, no credential options are passed and asyncssh
    attempts its own agent/default key lookup.
    """
    import os

    import asyncssh
    from app.config import settings

    options: dict = dict(
        host=settings.ssh_host,
        port=settings.ssh_port,
        username=settings.ssh_user,
        # Trust all hosts (same spirit as TRUENAS_VERIFY_TLS=false).
        # NOTE(review): this disables host-key verification entirely, so a
        # network man-in-the-middle would not be detected.
        known_hosts=None,
    )
    if settings.ssh_key:
        options["client_keys"] = [asyncssh.import_private_key(settings.ssh_key)]
    elif settings.ssh_password:
        options["password"] = settings.ssh_password
    else:
        key_file = os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)
        if os.path.exists(key_file):
            options["client_keys"] = [key_file]
    return asyncssh.connect(**options)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def test_connection(*, timeout: float = 30.0) -> dict:
    """Test SSH connectivity by running ``echo ok`` on the remote host.

    Args:
        timeout: seconds allowed for connect + command combined. Without a
            bound, an unreachable or unresponsive host could hang the caller
            indefinitely.

    Returns:
        {"ok": True} when the echo came back, otherwise
        {"ok": False, "error": str} with a human-readable reason.
        Never raises.
    """
    if not is_configured():
        # is_configured() is False both when ssh_host is empty AND when a host
        # is set but no credential source exists — say so accurately.
        return {
            "ok": False,
            "error": "SSH not configured (ssh_host is empty or no credentials available)",
        }

    async def _probe():
        async with await _connect() as conn:
            return await conn.run("echo ok", check=False)

    try:
        result = await asyncio.wait_for(_probe(), timeout=timeout)
    except asyncio.TimeoutError:
        return {"ok": False, "error": f"SSH connection timed out after {timeout:g}s"}
    except Exception as exc:
        # Some exceptions stringify to "" — fall back to the type name so the
        # UI never shows a blank error.
        return {"ok": False, "error": str(exc) or type(exc).__name__}

    if "ok" in (result.stdout or ""):
        return {"ok": True}
    return {"ok": False, "error": (result.stderr or "").strip() or "unexpected output"}
async def get_smart_attributes(devname: str) -> dict:
    """
    Run ``smartctl -a /dev/{devname}`` over SSH and parse the output.

    ``devname`` is shell-quoted before being placed in the remote command, so
    a malformed name can never inject extra shell syntax. Ordinary names such
    as ``sda`` or ``nvme0n1`` are left unchanged by quoting.

    Returns (see _parse_smartctl):
        health: str — "PASSED" | "FAILED" | "UNKNOWN"
        raw_output: str — full smartctl output (the error text on SSH failure)
        attributes: dict[int, {"name": str, "raw": int}]
        warnings: list[str] — monitored non-critical attributes with non-zero raw
        failures: list[str] — monitored critical attributes with non-zero raw,
                  or a single "SSH error: ..." entry when the SSH call failed
    Never raises.
    """
    import shlex

    cmd = f"smartctl -a /dev/{shlex.quote(devname)}"
    try:
        async with await _connect() as conn:
            result = await conn.run(cmd, check=False)
            output = result.stdout + result.stderr
            return _parse_smartctl(output)
    except Exception as exc:
        return {
            "health": "UNKNOWN",
            "raw_output": str(exc),
            "attributes": {},
            "warnings": [],
            "failures": [f"SSH error: {exc}"],
        }
async def start_smart_test(devname: str, test_type: str) -> str:
    """
    Start a SMART self-test with ``smartctl -t short|long /dev/{devname}``.

    Args:
        devname: device name under /dev (e.g. ``sda``). Shell-quoted before
            interpolation; ordinary names are unchanged by quoting.
        test_type: "SHORT" (case-insensitive) runs a short test; any other
            value runs a long test.

    Returns:
        smartctl's combined stdout+stderr.

    Raises:
        RuntimeError: smartctl said it cannot start the test because another
            self-test is still running, or neither the output nor the exit
            status indicates the test started.
        SSH-layer exceptions propagate unchanged.
    """
    import shlex

    arg = "short" if test_type.upper() == "SHORT" else "long"
    cmd = f"smartctl -t {arg} /dev/{shlex.quote(devname)}"
    async with await _connect() as conn:
        result = await conn.run(cmd, check=False)
    output = result.stdout + result.stderr
    lowered = output.lower()
    # smartctl refuses to start while a previous self-test is still running
    # ("Can't start self-test without aborting current test ..."). Detect that
    # from the text: the rc check below also accepts rc 4, so a refusal with
    # that status would otherwise look "started" and the caller would end up
    # polling the OLD test instead of the one it asked for.
    if "can't start self-test" in lowered:
        raise RuntimeError(
            f"smartctl refused to start a new self-test (one is already running): {output[:400]}"
        )
    # smartctl exits 0 or 4 when the test is successfully started on most drives
    started = (
        "Testing has begun" in output
        or "test has begun" in lowered
        or result.returncode in (0, 4)
    )
    if not started:
        raise RuntimeError(f"smartctl returned exit {result.returncode}: {output[:400]}")
    return output
async def poll_smart_progress(devname: str) -> dict:
    """
    Run ``smartctl -a /dev/{devname}`` and extract the self-test status.

    ``devname`` is shell-quoted before interpolation (a no-op for ordinary
    names like ``sda``), so a malformed name cannot inject shell syntax.

    Returns (see _parse_smart_progress):
        state: "running" | "passed" | "failed" | "unknown"
        percent_remaining: int | None — only set while running and a
            percentage could be parsed; None otherwise
        output: str — combined stdout+stderr
    SSH-layer exceptions propagate to the caller.
    """
    import shlex

    cmd = f"smartctl -a /dev/{shlex.quote(devname)}"
    async with await _connect() as conn:
        result = await conn.run(cmd, check=False)
    return _parse_smart_progress(result.stdout + result.stderr)
async def abort_smart_test(devname: str) -> None:
    """Abort an in-progress self-test via ``smartctl -X /dev/{devname}``.

    Best-effort: smartctl's exit status and output are ignored. ``devname``
    is shell-quoted before interpolation so it cannot inject shell syntax.
    SSH-layer exceptions propagate to the caller.
    """
    import shlex

    cmd = f"smartctl -X /dev/{shlex.quote(devname)}"
    async with await _connect() as conn:
        await conn.run(cmd, check=False)
def _parse_zpool_list_output(stdout: str) -> dict:
"""Pure parser for `zpool list -vHP` stdout. Exposed for unit tests.
See get_pool_membership() for output semantics. This function never
raises — malformed lines are silently skipped.
"""
import re as _re
def _strip_partition(name: str) -> str:
m = _re.match(r"^(nvme\d+n\d+)", name)
if m:
return m.group(1)
m = _re.match(r"^(sd[a-z]+)", name)
if m:
return m.group(1)
return name
SECTION_MARKERS = {"cache", "log", "logs", "spare", "spares",
"special", "dedup"}
SECTION_NORMALIZE = {"logs": "log", "spares": "spare"}
out: dict = {}
current_pool: str | None = None
current_role: str = "data"
for raw in stdout.splitlines():
if not raw.strip():
continue
depth = 0
while depth < len(raw) and raw[depth] == "\t":
depth += 1
first = raw[depth:].split("\t", 1)[0].strip()
if depth == 0:
current_pool = first
current_role = "data"
continue
if depth == 1:
if first in SECTION_MARKERS:
current_role = SECTION_NORMALIZE.get(first, first)
continue
if first.startswith(("mirror", "raidz", "draid")):
continue
if first.startswith("/dev/") and current_pool:
dn = _strip_partition(first[len("/dev/"):])
out[dn] = {"pool": current_pool, "role": current_role}
continue
if first.startswith("/dev/") and current_pool:
dn = _strip_partition(first[len("/dev/"):])
out[dn] = {"pool": current_pool, "role": current_role}
return out
def _parse_lsblk_zfs_output(stdout: str) -> set:
"""Pure parser for `lsblk -no NAME,FSTYPE -l` stdout. Returns base
devnames carrying ZFS labels (whole-disk OR via partition). Exposed
for unit tests."""
import re as _re
out: set = set()
for line in stdout.splitlines():
parts = line.split()
if len(parts) < 2:
continue
name, fstype = parts[0], parts[1]
if fstype != "zfs_member":
continue
if name.startswith("nvme"):
m = _re.match(r"^(nvme\d+n\d+)", name)
if m:
out.add(m.group(1))
else:
m = _re.match(r"^(sd[a-z]+)", name)
if m:
out.add(m.group(1))
return out
async def get_pool_membership() -> dict | None:
    """Return {devname: {"pool": str, "role": str}} for every drive in any zpool.

    Runs ``zpool list -vHPL`` over SSH and parses it with
    _parse_zpool_list_output(). Tab-indent depth gives the structure:
      depth 0   pool name line
      depth 1   vdev type line (mirror-N, raidz*N, draid*), a section marker
                (cache/log/logs/spare/spares/special/dedup), or a device line
      depth 2+  device line within a vdev ('/dev/sdX', '/dev/nvmeXnY', ...),
                possibly with a partition suffix that is stripped back to the
                base devname so it matches what TrueNAS reports.
    ``-L`` resolves symlinked vdev paths (e.g. ``/dev/disk/by-partuuid/...``)
    to the real device node. Without it, such vdevs are keyed by their link
    path, which the parser keeps verbatim and which never equals a devname.

    Roles: data | cache | log | spare | special | dedup

    Returns:
      - {} when SSH is not configured (dev/mock mode), or when the SSH call
        succeeded and there are genuinely no pools
      - None on any failure (SSH error, 30 s timeout, non-zero exit).
        Callers MUST treat None differently from {}: an empty dict is
        "definitely no pool members," None is "we couldn't tell." Treating
        None as "no pool members" is a fail-open security regression.
    """
    if not is_configured():
        return {}
    cmd = "zpool list -vHPL 2>/dev/null"
    try:
        async with await _connect() as conn:
            # Bounded so a hung host can't stall the poller (or start_job's
            # fresh pool check) forever; a timeout is reported as failure.
            r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30)
    except Exception:
        return None
    if r.returncode != 0:
        return None
    if not r.stdout:
        # rc==0 with empty output = host has no pools. (`zpool list -H`
        # returns no rows when zero pools are imported.) That's a real
        # answer, not a failure.
        return {}
    return _parse_zpool_list_output(r.stdout)
async def get_mounted_drives() -> set | None:
    """Return base devnames of every drive whose partitions are mounted now.

    Defense-in-depth on top of pool detection — catches XFS/ext4/etc. scratch
    disks the operator forgot about. Parses ``findmnt -no SOURCE`` output
    with _parse_findmnt_sources().

    Returns:
      - set() when SSH is not configured (dev/mock mode)
      - None on any failure: SSH error, 30 s timeout, non-zero exit, or empty
        output (findmnt always has at least / mounted on a Linux host, so
        empty output is itself suspicious). The caller treats None as
        "preserve previous state".
    """
    if not is_configured():
        return set()
    cmd = "findmnt -no SOURCE 2>/dev/null"
    try:
        async with await _connect() as conn:
            # Bounded so a hung host can't stall the caller indefinitely.
            r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30)
    except Exception:
        return None
    if r.returncode != 0 or not r.stdout:
        return None
    return _parse_findmnt_sources(r.stdout)
def _parse_findmnt_sources(stdout: str) -> set:
"""Pure parser for findmnt output. Strips partitions; ignores tmpfs,
overlay, zfs (zfs is handled by pool detection).
Recognised devnames (covers TrueNAS SCALE + CORE + LVM/MD stacks):
sd[a-z]+ — Linux SCSI/SATA (sda, sdb, ..., sdaa)
nvmeXnY[pZ] — Linux NVMe namespaces
mapper/<name> — LVM logical volumes (/dev/mapper/vg-lv)
dm-N — devicemapper short names
mdN — Linux MD RAID arrays
ada[0-9]+, da[0-9]+ — TrueNAS CORE (FreeBSD) SATA/SAS
"""
import re as _re
out: set = set()
for raw in stdout.splitlines():
s = raw.strip()
if not s.startswith("/dev/"):
continue
# Skip ZFS filesystems (those are pool/exported drives, handled
# separately and shouldn't double-lock as 'mounted').
if "/dev/zd" in s or "/dev/zvol" in s:
continue
name = s[len("/dev/"):].split("[")[0] # bind mounts can have [subdir]
# Try each recognised devname pattern in order. Mapper/dm-/md
# entries are kept whole because they represent a stack the
# operator should resolve manually before burn-in.
for pat in (
r"^(nvme\d+n\d+)", # NVMe (strip pN)
r"^(sd[a-z]+)", # Linux SCSI/SATA (strip number)
r"^(mapper/[^/]+)", # LVM logical volume
r"^(dm-\d+)", # devicemapper short name
r"^(md\d+)", # MD RAID
r"^(ada\d+)", # FreeBSD SATA
r"^(da\d+)", # FreeBSD SAS/SCSI
):
m = _re.match(pat, name)
if m:
out.add(m.group(1))
break
return out
async def fresh_pool_check_for_drive(devname: str) -> dict | None:
    """Live, on-demand re-detection of one drive's pool/mounted state.

    Re-runs the three detection probes — get_pool_membership() (zpool list),
    get_zfs_member_drives() (lsblk zfs_member) and get_mounted_drives()
    (findmnt) — each over its own fresh SSH connection, and reports what
    currently claims `devname`.

    Closes the poll-window gap between an operator unlock and the next
    cached state refresh — used as a final gate inside burnin.start_job
    so a drive that was imported into a pool after unlock but before the
    next poll can't slip through.

    Return shape: {"pool": str, "role": str} | None
      - active pool member -> that drive's entry from get_pool_membership()
      - ZFS labels only    -> {"pool": "(exported)", "role": "exported"}
      - mounted partition  -> {"pool": "(mounted)",  "role": "mounted"}
      - any probe failed   -> {"pool": "(unknown)",  "role": "unknown"}
      - None only when the drive is genuinely free right now, or when SSH is
        not configured / `devname` is empty.

    Fail-closed: a probe that cannot answer (SSH error, timeout, bad exit)
    yields the "(unknown)" entry instead of None. None is therefore never
    ambiguous — a flaky SSH session can no longer be mistaken for "this
    drive is free" by a caller that allows burn-in on None.
    """
    if not is_configured() or not devname:
        return None

    def _unknown() -> dict:
        return {"pool": "(unknown)", "role": "unknown"}

    pm = await get_pool_membership()
    if pm is None:
        return _unknown()
    if devname in pm:
        return pm[devname]

    zs = await get_zfs_member_drives()
    if zs is None:
        return _unknown()
    if devname in zs:
        return {"pool": "(exported)", "role": "exported"}

    ms = await get_mounted_drives()
    if ms is None:
        return _unknown()
    if devname in ms:
        return {"pool": "(mounted)", "role": "mounted"}
    return None
async def get_smart_health_map(devnames: list[str]) -> dict | None:
"""Return {devname: 'PASSED'|'FAILED'|'UNKNOWN'} for every devname.
Runs `smartctl -H` for each disk in a single SSH session — much faster
than one connection per disk. Returns None on any SSH failure so the
poller can fall back to the previously-stored health value rather than
silently overwriting everything as 'UNKNOWN'.
`smartctl -H` is the cheap SMART self-assessment lookup (no full
attribute scan) — milliseconds per drive. The output format is stable:
SMART overall-health self-assessment test result: PASSED
SMART overall-health self-assessment test result: FAILED!
For drives that don't support the command at all, smartctl exits
non-zero and we record UNKNOWN for that device specifically.
"""
if not is_configured() or not devnames:
return {} if devnames else None
# Build one shell pipeline that prefixes each result with "@@DEVNAME@@"
# so we can split the combined stdout deterministically.
parts = []
for d in devnames:
# Reject anything that doesn't look like a basic devname so we
# never inject shell metacharacters into the remote command.
if not d.replace("nvme", "").replace("n", "").replace("p", "").replace("sd", "").isalnum():
continue
parts.append(f"echo '@@{d}@@'; smartctl -H /dev/{d} 2>&1; echo '@@END@@'")
if not parts:
return {}
cmd = "; ".join(parts)
try:
async with await _connect() as conn:
r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30)
except Exception:
return None
if not r.stdout:
return None
return _parse_smart_health_batch(r.stdout)
def _parse_smart_health_batch(stdout: str) -> dict:
"""Pure parser for the batched smartctl -H output. Exposed for tests."""
result: dict[str, str] = {}
current: str | None = None
buf: list[str] = []
def _flush():
if current is None:
return
text = "\n".join(buf)
if "PASSED" in text:
result[current] = "PASSED"
elif "FAILED" in text or "FAILURE" in text:
result[current] = "FAILED"
else:
result[current] = "UNKNOWN"
for raw in stdout.splitlines():
line = raw.strip()
if line.startswith("@@") and line.endswith("@@"):
inner = line[2:-2]
if inner == "END":
_flush()
current = None
buf = []
else:
_flush()
current = inner
buf = []
else:
buf.append(line)
_flush()
return result
async def get_zfs_member_drives() -> set | None:
    """Return devnames of every drive whose partitions carry a ZFS label.

    Combined with get_pool_membership(): a drive in this set but NOT in the
    active-pool map carries ZFS data from a previously-imported pool that
    was exported (or imported on a different system). We treat those as
    locked too — wiping them would silently destroy a pool. Parses
    `lsblk -no NAME,FSTYPE -l` with _parse_lsblk_zfs_output().

    Returns:
      - set() when SSH is not configured (dev/mock mode), or when lsblk
        succeeded and no drives carry ZFS labels
      - None on any failure (SSH error, 30 s timeout, non-zero exit, empty
        output). Same fail-closed semantics as get_pool_membership() —
        callers must NOT treat None as "no exported drives," that's a
        security regression.
    """
    if not is_configured():
        return set()
    cmd = "lsblk -no NAME,FSTYPE -l 2>/dev/null"
    try:
        async with await _connect() as conn:
            # Bounded so a hung host can't stall the caller indefinitely.
            r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30)
    except Exception:
        return None
    if r.returncode != 0:
        return None
    if not r.stdout:
        # lsblk with rc==0 and no output is impossible on a normal Linux
        # host; treat as failure rather than "no drives at all."
        return None
    return _parse_lsblk_zfs_output(r.stdout)
async def get_system_sensors() -> dict:
    """
    Read system-level temperatures by running `sensors -j` on TrueNAS.

    On success returns whatever _parse_sensors_json() extracts, i.e.
    {"cpu_c": int|None, "pch_c": int|None}:
        cpu_c — CPU temperature
        pch_c — PCH/chipset temperature, a proxy for storage I/O lane thermals

    Degrades gracefully: returns {} when SSH is not configured, when
    `sensors` prints nothing (e.g. lm-sensors unavailable), or on any error
    (logged at DEBUG). Never raises.
    """
    if not is_configured():
        return {}
    try:
        async with await _connect() as conn:
            proc = await conn.run("sensors -j 2>/dev/null", check=False)
            text = proc.stdout.strip()
            return _parse_sensors_json(text) if text else {}
    except Exception as err:
        log.debug("get_system_sensors failed: %s", err)
        return {}
def _parse_sensors_json(output: str) -> dict:
import json as _json
try:
data = _json.loads(output)
except Exception:
return {}
cpu_c: int | None = None
pch_c: int | None = None
for chip_name, chip_data in data.items():
if not isinstance(chip_data, dict):
continue
# CPU package temp — coretemp chip, "Package id N" sensor
if chip_name.startswith("coretemp") and cpu_c is None:
for sensor_name, sensor_vals in chip_data.items():
if not isinstance(sensor_vals, dict):
continue
if "package" in sensor_name.lower():
for k, v in sensor_vals.items():
if k.endswith("_input") and isinstance(v, (int, float)):
cpu_c = int(round(v))
break
if cpu_c is not None:
break
# PCH / chipset temp — manages PCIe lanes including HBA / storage I/O
elif chip_name.startswith("pch_") and pch_c is None:
for sensor_name, sensor_vals in chip_data.items():
if not isinstance(sensor_vals, dict):
continue
for k, v in sensor_vals.items():
if k.endswith("_input") and isinstance(v, (int, float)):
pch_c = int(round(v))
break
if pch_c is not None:
break
return {"cpu_c": cpu_c, "pch_c": pch_c}
# ---------------------------------------------------------------------------
# Parsers
# ---------------------------------------------------------------------------
def _parse_smartctl(output: str) -> dict:
health = "UNKNOWN"
attributes: dict[int, dict] = {}
warnings: list[str] = []
failures: list[str] = []
m = re.search(r"self-assessment test result:\s+(\w+)", output, re.IGNORECASE)
if m:
health = m.group(1).upper()
# Attribute table: ID# NAME FLAG VALUE WORST THRESH TYPE UPDATED WHEN_FAILED RAW_VALUE
for line in output.splitlines():
am = re.match(
r"\s*(\d+)\s+(\S+)\s+\S+\s+\d+\s+\d+\s+\d+\s+\S+\s+\S+\s+\S+\s+(\d+)",
line,
)
if not am:
continue
attr_id = int(am.group(1))
attr_name = am.group(2)
raw_val = int(am.group(3))
attributes[attr_id] = {"name": attr_name, "raw": raw_val}
if attr_id in SMART_ATTRS:
_, is_critical = SMART_ATTRS[attr_id]
if raw_val > 0:
msg = f"{attr_name} = {raw_val}"
if is_critical:
failures.append(msg)
else:
warnings.append(msg)
return {
"health": health,
"raw_output": output,
"attributes": attributes,
"warnings": warnings,
"failures": failures,
}
def _parse_smart_progress(output: str) -> dict:
state = "unknown"
percent_remaining = None # None = "in progress but no % line parsed yet"
lower = output.lower()
if "self-test routine in progress" in lower or "self-test routine in progress" in output:
state = "running"
m = re.search(r"(\d+)%\s+of\s+test\s+remaining", output, re.IGNORECASE)
if m:
percent_remaining = int(m.group(1))
elif "completed without error" in lower:
state = "passed"
elif (
"completed: read failure" in lower
or "completed: write failure" in lower
or "aborted by host" in lower
or ("completed" in lower and "failure" in lower)
):
state = "failed"
elif "in progress" in lower:
state = "running"
return {
"state": state,
"percent_remaining": percent_remaining,
"output": output,
}