""" SSH client for direct TrueNAS command execution (Stage 7). When ssh_host is configured, burn-in stages use SSH to run smartctl and badblocks directly on the TrueNAS host instead of going through the REST API. Falls back to REST API / simulation when SSH is not configured (dev/mock mode). TrueNAS CORE (FreeBSD) device paths: /dev/ada0, /dev/da0, etc. TrueNAS SCALE (Linux) device paths: /dev/sda, /dev/sdb, etc. The devname from the TrueNAS API is used as-is in /dev/{devname}. """ import asyncio import logging import re from typing import Callable log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Monitored SMART attributes # True → any non-zero raw value is a hard failure (drive rejected) # False → non-zero is a warning (flagged but test continues) # --------------------------------------------------------------------------- SMART_ATTRS: dict[int, tuple[str, bool]] = { 5: ("Reallocated_Sector_Ct", True), # reallocation = FAIL 10: ("Spin_Retry_Count", False), # mechanical stress = WARN 188: ("Command_Timeout", False), # drive not responding = WARN 197: ("Current_Pending_Sector", True), # pending reallocation = FAIL 198: ("Offline_Uncorrectable", True), # unrecoverable read error = FAIL 199: ("UDMA_CRC_Error_Count", False), # cable/controller issue = WARN } # --------------------------------------------------------------------------- # Configuration check # --------------------------------------------------------------------------- def is_configured() -> bool: """Returns True when SSH credentials are present and usable.""" from app.config import settings return bool(settings.ssh_host and (settings.ssh_password or settings.ssh_key)) # --------------------------------------------------------------------------- # Low-level connection # --------------------------------------------------------------------------- async def _connect(): """Open a single-use SSH connection. Caller must use `async with`.""" import asyncssh from app.config import settings kwargs: dict = { "host": settings.ssh_host, "port": settings.ssh_port, "username": settings.ssh_user, "known_hosts": None, # trust all hosts (same spirit as TRUENAS_VERIFY_TLS=false) } if settings.ssh_key: kwargs["client_keys"] = [asyncssh.import_private_key(settings.ssh_key)] if settings.ssh_password: kwargs["password"] = settings.ssh_password return asyncssh.connect(**kwargs) # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- async def test_connection() -> dict: """Test SSH connectivity. Returns {"ok": True} or {"ok": False, "error": str}.""" if not is_configured(): return {"ok": False, "error": "SSH not configured (ssh_host is empty)"} try: async with await _connect() as conn: result = await conn.run("echo ok", check=False) if "ok" in result.stdout: return {"ok": True} return {"ok": False, "error": result.stderr.strip() or "unexpected output"} except Exception as exc: return {"ok": False, "error": str(exc)} async def get_smart_attributes(devname: str) -> dict: """ Run `smartctl -a /dev/{devname}` and parse the output. Returns: health: str — "PASSED" | "FAILED" | "UNKNOWN" raw_output: str — full smartctl output attributes: dict[int, {"name": str, "raw": int}] warnings: list[str] — attribute names with non-zero raw (non-critical) failures: list[str] — attribute names with non-zero raw (critical) """ cmd = f"smartctl -a /dev/{devname}" try: async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr return _parse_smartctl(output) except Exception as exc: return { "health": "UNKNOWN", "raw_output": str(exc), "attributes": {}, "warnings": [], "failures": [f"SSH error: {exc}"], } async def start_smart_test(devname: str, test_type: str) -> str: """ Run `smartctl -t short|long /dev/{devname}`. Returns raw output. Raises RuntimeError on unrecoverable failure. test_type: "SHORT" or "LONG" """ arg = "short" if test_type.upper() == "SHORT" else "long" cmd = f"smartctl -t {arg} /dev/{devname}" async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr # smartctl exits 0 or 4 when the test is successfully started on most drives started = ("Testing has begun" in output or "test has begun" in output.lower() or result.returncode in (0, 4)) if not started: raise RuntimeError(f"smartctl returned exit {result.returncode}: {output[:400]}") return output async def poll_smart_progress(devname: str) -> dict: """ Run `smartctl -a /dev/{devname}` and extract self-test status. Returns: state: "running" | "passed" | "failed" | "unknown" percent_remaining: int (0 = complete when state != "running") output: str """ cmd = f"smartctl -a /dev/{devname}" async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr return _parse_smart_progress(output) async def abort_smart_test(devname: str) -> None: """Send `smartctl -X /dev/{devname}` to abort an in-progress test.""" cmd = f"smartctl -X /dev/{devname}" async with await _connect() as conn: await conn.run(cmd, check=False) async def run_badblocks( devname: str, on_progress: Callable[[int, int, str], None], cancelled_fn: Callable[[], bool] | None = None, ) -> dict: """ Run `badblocks -wsv -b 4096 -p 1 /dev/{devname}` and stream output. on_progress(percent, bad_blocks, line) is called for each line of output. cancelled_fn() is polled to support mid-test cancellation. Returns: {"bad_blocks": int, "output": str, "aborted": bool} """ from app.config import settings cmd = f"badblocks -wsv -b 4096 -p 1 /dev/{devname}" lines: list[str] = [] bad_blocks = 0 aborted = False last_pct = 0 try: async with await _connect() as conn: async with conn.create_process(cmd) as proc: # badblocks writes progress to stderr, bad block numbers to stdout async def _read_stream(stream, is_stderr: bool): nonlocal bad_blocks, last_pct, aborted async for raw_line in stream: line = raw_line if isinstance(raw_line, str) else raw_line.decode("utf-8", errors="replace") lines.append(line) if is_stderr: m = re.search(r"([\d.]+)%\s+done", line) if m: last_pct = min(99, int(float(m.group(1)))) else: # Each non-empty stdout line during badblocks is a bad block number stripped = line.strip() if stripped and stripped.isdigit(): bad_blocks += 1 on_progress(last_pct, bad_blocks, line) # Abort if threshold exceeded if bad_blocks > settings.bad_block_threshold: aborted = True proc.kill() lines.append( f"\n[ABORTED] Bad block count ({bad_blocks}) exceeded " f"threshold ({settings.bad_block_threshold})\n" ) return # Abort on cancellation if cancelled_fn and cancelled_fn(): aborted = True proc.kill() return stdout_task = asyncio.create_task(_read_stream(proc.stdout, False)) stderr_task = asyncio.create_task(_read_stream(proc.stderr, True)) await asyncio.gather(stdout_task, stderr_task, return_exceptions=True) await proc.wait() except Exception as exc: lines.append(f"\n[SSH error] {exc}\n") if not aborted: last_pct = 100 return { "bad_blocks": bad_blocks, "output": "".join(lines), "aborted": aborted, } # --------------------------------------------------------------------------- # Parsers # --------------------------------------------------------------------------- def _parse_smartctl(output: str) -> dict: health = "UNKNOWN" attributes: dict[int, dict] = {} warnings: list[str] = [] failures: list[str] = [] m = re.search(r"self-assessment test result:\s+(\w+)", output, re.IGNORECASE) if m: health = m.group(1).upper() # Attribute table: ID# NAME FLAG VALUE WORST THRESH TYPE UPDATED WHEN_FAILED RAW_VALUE for line in output.splitlines(): am = re.match( r"\s*(\d+)\s+(\S+)\s+\S+\s+\d+\s+\d+\s+\d+\s+\S+\s+\S+\s+\S+\s+(\d+)", line, ) if not am: continue attr_id = int(am.group(1)) attr_name = am.group(2) raw_val = int(am.group(3)) attributes[attr_id] = {"name": attr_name, "raw": raw_val} if attr_id in SMART_ATTRS: _, is_critical = SMART_ATTRS[attr_id] if raw_val > 0: msg = f"{attr_name} = {raw_val}" if is_critical: failures.append(msg) else: warnings.append(msg) return { "health": health, "raw_output": output, "attributes": attributes, "warnings": warnings, "failures": failures, } def _parse_smart_progress(output: str) -> dict: state = "unknown" percent_remaining = 0 lower = output.lower() if "self-test routine in progress" in lower or "self-test routine in progress" in output: state = "running" m = re.search(r"(\d+)%\s+of\s+test\s+remaining", output, re.IGNORECASE) if m: percent_remaining = int(m.group(1)) elif "completed without error" in lower: state = "passed" elif ( "completed: read failure" in lower or "completed: write failure" in lower or "aborted by host" in lower or ("completed" in lower and "failure" in lower) ): state = "failed" elif "in progress" in lower: state = "running" return { "state": state, "percent_remaining": percent_remaining, "output": output, }