""" SSH client for direct TrueNAS command execution (Stage 7). When ssh_host is configured, burn-in stages use SSH to run smartctl and badblocks directly on the TrueNAS host instead of going through the REST API. Falls back to REST API / simulation when SSH is not configured (dev/mock mode). TrueNAS CORE (FreeBSD) device paths: /dev/ada0, /dev/da0, etc. TrueNAS SCALE (Linux) device paths: /dev/sda, /dev/sdb, etc. The devname from the TrueNAS API is used as-is in /dev/{devname}. """ import asyncio import logging import re from typing import Callable log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Monitored SMART attributes # True → any non-zero raw value is a hard failure (drive rejected) # False → non-zero is a warning (flagged but test continues) # --------------------------------------------------------------------------- SMART_ATTRS: dict[int, tuple[str, bool]] = { 5: ("Reallocated_Sector_Ct", True), # reallocation = FAIL 10: ("Spin_Retry_Count", False), # mechanical stress = WARN 188: ("Command_Timeout", False), # drive not responding = WARN 197: ("Current_Pending_Sector", True), # pending reallocation = FAIL 198: ("Offline_Uncorrectable", True), # unrecoverable read error = FAIL 199: ("UDMA_CRC_Error_Count", False), # cable/controller issue = WARN } # --------------------------------------------------------------------------- # Configuration check # --------------------------------------------------------------------------- def is_configured() -> bool: """Returns True when SSH host + at least one auth method is available.""" import os from app.config import settings if not settings.ssh_host: return False has_creds = bool( settings.ssh_key or settings.ssh_password or os.path.exists(os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH)) ) return has_creds # --------------------------------------------------------------------------- # Low-level connection # --------------------------------------------------------------------------- _MOUNTED_KEY_PATH = "/run/secrets/ssh_key" async def _connect(): """Open a single-use SSH connection. Caller must use `async with`.""" import asyncssh from app.config import settings kwargs: dict = { "host": settings.ssh_host, "port": settings.ssh_port, "username": settings.ssh_user, "known_hosts": None, # trust all hosts (same spirit as TRUENAS_VERIFY_TLS=false) } if settings.ssh_key: # Key material provided via env var (base case) kwargs["client_keys"] = [asyncssh.import_private_key(settings.ssh_key)] elif settings.ssh_password: kwargs["password"] = settings.ssh_password else: # Fall back to mounted key file (preferred for production — no key in env vars) import os key_path = os.environ.get("SSH_KEY_FILE", _MOUNTED_KEY_PATH) if os.path.exists(key_path): kwargs["client_keys"] = [key_path] # If nothing is configured, asyncssh will attempt agent/default key lookup return asyncssh.connect(**kwargs) # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- async def test_connection() -> dict: """Test SSH connectivity. Returns {"ok": True} or {"ok": False, "error": str}.""" if not is_configured(): return {"ok": False, "error": "SSH not configured (ssh_host is empty)"} try: async with await _connect() as conn: result = await conn.run("echo ok", check=False) if "ok" in result.stdout: return {"ok": True} return {"ok": False, "error": result.stderr.strip() or "unexpected output"} except Exception as exc: return {"ok": False, "error": str(exc)} async def get_smart_attributes(devname: str) -> dict: """ Run `smartctl -a /dev/{devname}` and parse the output. Returns: health: str — "PASSED" | "FAILED" | "UNKNOWN" raw_output: str — full smartctl output attributes: dict[int, {"name": str, "raw": int}] warnings: list[str] — attribute names with non-zero raw (non-critical) failures: list[str] — attribute names with non-zero raw (critical) """ cmd = f"smartctl -a /dev/{devname}" try: async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr return _parse_smartctl(output) except Exception as exc: return { "health": "UNKNOWN", "raw_output": str(exc), "attributes": {}, "warnings": [], "failures": [f"SSH error: {exc}"], } async def start_smart_test(devname: str, test_type: str) -> str: """ Run `smartctl -t short|long /dev/{devname}`. Returns raw output. Raises RuntimeError on unrecoverable failure. test_type: "SHORT" or "LONG" """ arg = "short" if test_type.upper() == "SHORT" else "long" cmd = f"smartctl -t {arg} /dev/{devname}" async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr # smartctl exits 0 or 4 when the test is successfully started on most drives started = ("Testing has begun" in output or "test has begun" in output.lower() or result.returncode in (0, 4)) if not started: raise RuntimeError(f"smartctl returned exit {result.returncode}: {output[:400]}") return output async def poll_smart_progress(devname: str) -> dict: """ Run `smartctl -a /dev/{devname}` and extract self-test status. Returns: state: "running" | "passed" | "failed" | "unknown" percent_remaining: int (0 = complete when state != "running") output: str """ cmd = f"smartctl -a /dev/{devname}" async with await _connect() as conn: result = await conn.run(cmd, check=False) output = result.stdout + result.stderr return _parse_smart_progress(output) async def abort_smart_test(devname: str) -> None: """Send `smartctl -X /dev/{devname}` to abort an in-progress test.""" cmd = f"smartctl -X /dev/{devname}" async with await _connect() as conn: await conn.run(cmd, check=False) async def run_badblocks( devname: str, on_progress: Callable[[int, int, str], None], cancelled_fn: Callable[[], bool] | None = None, ) -> dict: """ Run `badblocks -wsv -b 4096 -p 1 /dev/{devname}` and stream output. on_progress(percent, bad_blocks, line) is called for each line of output. cancelled_fn() is polled to support mid-test cancellation. Returns: {"bad_blocks": int, "output": str, "aborted": bool} """ from app.config import settings cmd = f"badblocks -wsv -b 4096 -p 1 /dev/{devname}" lines: list[str] = [] bad_blocks = 0 aborted = False last_pct = 0 try: async with await _connect() as conn: async with conn.create_process(cmd) as proc: # badblocks writes progress to stderr, bad block numbers to stdout async def _read_stream(stream, is_stderr: bool): nonlocal bad_blocks, last_pct, aborted async for raw_line in stream: line = raw_line if isinstance(raw_line, str) else raw_line.decode("utf-8", errors="replace") lines.append(line) if is_stderr: m = re.search(r"([\d.]+)%\s+done", line) if m: last_pct = min(99, int(float(m.group(1)))) else: # Each non-empty stdout line during badblocks is a bad block number stripped = line.strip() if stripped and stripped.isdigit(): bad_blocks += 1 on_progress(last_pct, bad_blocks, line) # Abort if threshold exceeded if bad_blocks > settings.bad_block_threshold: aborted = True proc.kill() lines.append( f"\n[ABORTED] Bad block count ({bad_blocks}) exceeded " f"threshold ({settings.bad_block_threshold})\n" ) return # Abort on cancellation if cancelled_fn and cancelled_fn(): aborted = True proc.kill() return stdout_task = asyncio.create_task(_read_stream(proc.stdout, False)) stderr_task = asyncio.create_task(_read_stream(proc.stderr, True)) await asyncio.gather(stdout_task, stderr_task, return_exceptions=True) await proc.wait() except Exception as exc: lines.append(f"\n[SSH error] {exc}\n") if not aborted: last_pct = 100 return { "bad_blocks": bad_blocks, "output": "".join(lines), "aborted": aborted, } def _parse_zpool_list_output(stdout: str) -> dict: """Pure parser for `zpool list -vHP` stdout. Exposed for unit tests. See get_pool_membership() for output semantics. This function never raises — malformed lines are silently skipped. """ import re as _re def _strip_partition(name: str) -> str: m = _re.match(r"^(nvme\d+n\d+)", name) if m: return m.group(1) m = _re.match(r"^(sd[a-z]+)", name) if m: return m.group(1) return name SECTION_MARKERS = {"cache", "log", "logs", "spare", "spares", "special", "dedup"} SECTION_NORMALIZE = {"logs": "log", "spares": "spare"} out: dict = {} current_pool: str | None = None current_role: str = "data" for raw in stdout.splitlines(): if not raw.strip(): continue depth = 0 while depth < len(raw) and raw[depth] == "\t": depth += 1 first = raw[depth:].split("\t", 1)[0].strip() if depth == 0: current_pool = first current_role = "data" continue if depth == 1: if first in SECTION_MARKERS: current_role = SECTION_NORMALIZE.get(first, first) continue if first.startswith(("mirror", "raidz", "draid")): continue if first.startswith("/dev/") and current_pool: dn = _strip_partition(first[len("/dev/"):]) out[dn] = {"pool": current_pool, "role": current_role} continue if first.startswith("/dev/") and current_pool: dn = _strip_partition(first[len("/dev/"):]) out[dn] = {"pool": current_pool, "role": current_role} return out def _parse_lsblk_zfs_output(stdout: str) -> set: """Pure parser for `lsblk -no NAME,FSTYPE -l` stdout. Returns base devnames carrying ZFS labels (whole-disk OR via partition). Exposed for unit tests.""" import re as _re out: set = set() for line in stdout.splitlines(): parts = line.split() if len(parts) < 2: continue name, fstype = parts[0], parts[1] if fstype != "zfs_member": continue if name.startswith("nvme"): m = _re.match(r"^(nvme\d+n\d+)", name) if m: out.add(m.group(1)) else: m = _re.match(r"^(sd[a-z]+)", name) if m: out.add(m.group(1)) return out async def get_pool_membership() -> dict | None: """Return {devname: {"pool": str, "role": str}} for every drive in any zpool. Parses `zpool list -vHP` output. Tab-indent depth tells us structure: depth 0 pool name line depth 1 vdev type line (mirror-N, raidz*N, draid*) OR section marker (cache/log/spare/special/dedup/logs) OR a single-disk vdev that is itself a /dev/... entry depth 2 device line within a vdev — '/dev/sdX', '/dev/nvmeXnY', etc. may have a partition suffix that we strip back to the base devname so it matches what TrueNAS reports. Roles: data | cache | log | spare | special | dedup Returns: - {} when the SSH call succeeded and there are genuinely no pools - None on any failure (SSH down, parse error, non-zero exit, no stdout). Callers MUST treat None differently from {}: an empty dict is "definitely no pool members," None is "we couldn't tell." Treating None as "no pool members" is a fail-open security regression. """ import re as _re if not is_configured(): return {} cmd = "zpool list -vHP 2>/dev/null" try: async with await _connect() as conn: r = await conn.run(cmd, check=False) if r.returncode != 0: return None except Exception: return None if not r.stdout: # rc==0 with empty output = host has no pools. (`zpool list -H` # returns no rows when zero pools are imported.) That's a real # answer, not a failure. return {} return _parse_zpool_list_output(r.stdout) async def get_mounted_drives() -> set | None: """Return base devnames of every drive whose partitions are mounted anywhere right now. Defense-in-depth on top of pool detection — catches XFS/ext4/etc. scratch disks the operator forgot about. Returns None on any failure (caller treats that as 'preserve previous state').""" if not is_configured(): return set() cmd = "findmnt -no SOURCE 2>/dev/null" try: async with await _connect() as conn: r = await conn.run(cmd, check=False) if r.returncode != 0 or not r.stdout: # findmnt always has at least / mounted on a Linux host; # empty output is itself suspicious. Treat as failure. return None except Exception: return None return _parse_findmnt_sources(r.stdout) def _parse_findmnt_sources(stdout: str) -> set: """Pure parser for findmnt output. Strips partitions; ignores tmpfs, overlay, zfs (zfs is handled by pool detection).""" import re as _re out: set = set() for raw in stdout.splitlines(): s = raw.strip() if not s.startswith("/dev/"): continue # Skip ZFS filesystems (those are pool/exported drives, handled # separately and shouldn't double-lock as 'mounted'). if "/dev/zd" in s or "/dev/zvol" in s: continue name = s[len("/dev/"):].split("[")[0] # bind mounts can have [subdir] if name.startswith("nvme"): m = _re.match(r"^(nvme\d+n\d+)", name) if m: out.add(m.group(1)) else: m = _re.match(r"^(sd[a-z]+)", name) if m: out.add(m.group(1)) return out async def get_smart_health_map(devnames: list[str]) -> dict | None: """Return {devname: 'PASSED'|'FAILED'|'UNKNOWN'} for every devname. Runs `smartctl -H` for each disk in a single SSH session — much faster than one connection per disk. Returns None on any SSH failure so the poller can fall back to the previously-stored health value rather than silently overwriting everything as 'UNKNOWN'. `smartctl -H` is the cheap SMART self-assessment lookup (no full attribute scan) — milliseconds per drive. The output format is stable: SMART overall-health self-assessment test result: PASSED SMART overall-health self-assessment test result: FAILED! For drives that don't support the command at all, smartctl exits non-zero and we record UNKNOWN for that device specifically. """ if not is_configured() or not devnames: return {} if devnames else None # Build one shell pipeline that prefixes each result with "@@DEVNAME@@" # so we can split the combined stdout deterministically. parts = [] for d in devnames: # Reject anything that doesn't look like a basic devname so we # never inject shell metacharacters into the remote command. if not d.replace("nvme", "").replace("n", "").replace("p", "").replace("sd", "").isalnum(): continue parts.append(f"echo '@@{d}@@'; smartctl -H /dev/{d} 2>&1; echo '@@END@@'") if not parts: return {} cmd = "; ".join(parts) try: async with await _connect() as conn: r = await asyncio.wait_for(conn.run(cmd, check=False), timeout=30) except Exception: return None if not r.stdout: return None return _parse_smart_health_batch(r.stdout) def _parse_smart_health_batch(stdout: str) -> dict: """Pure parser for the batched smartctl -H output. Exposed for tests.""" result: dict[str, str] = {} current: str | None = None buf: list[str] = [] def _flush(): if current is None: return text = "\n".join(buf) if "PASSED" in text: result[current] = "PASSED" elif "FAILED" in text or "FAILURE" in text: result[current] = "FAILED" else: result[current] = "UNKNOWN" for raw in stdout.splitlines(): line = raw.strip() if line.startswith("@@") and line.endswith("@@"): inner = line[2:-2] if inner == "END": _flush() current = None buf = [] else: _flush() current = inner buf = [] else: buf.append(line) _flush() return result async def get_zfs_member_drives() -> set | None: """Return devnames of every drive whose partitions carry a ZFS label. Combined with get_pool_membership(): a drive in this set but NOT in the active-pool map carries ZFS data from a previously-imported pool that was exported (or imported on a different system). We treat those as locked too — wiping them would silently destroy a pool. Returns: - set() when lsblk succeeded and no drives carry ZFS labels - None on any failure. Same fail-closed semantics as get_pool_membership() — callers must NOT treat None as "no exported drives," that's a security regression. """ if not is_configured(): return set() cmd = "lsblk -no NAME,FSTYPE -l 2>/dev/null" try: async with await _connect() as conn: r = await conn.run(cmd, check=False) if r.returncode != 0: return None except Exception: return None if not r.stdout: # lsblk with rc==0 and no output is impossible on a normal Linux # host; treat as failure rather than "no drives at all." return None return _parse_lsblk_zfs_output(r.stdout) async def get_system_sensors() -> dict: """ Run `sensors -j` on TrueNAS and extract system-level temperatures. Returns {"cpu_c": int|None, "pch_c": int|None}. cpu_c = CPU package temp (coretemp chip) pch_c = PCH/chipset temp (pch_* chip) — proxy for storage I/O lane thermals Falls back gracefully if SSH is not configured or lm-sensors is unavailable. """ if not is_configured(): return {} try: async with await _connect() as conn: result = await conn.run("sensors -j 2>/dev/null", check=False) output = result.stdout.strip() if not output: return {} return _parse_sensors_json(output) except Exception as exc: log.debug("get_system_sensors failed: %s", exc) return {} def _parse_sensors_json(output: str) -> dict: import json as _json try: data = _json.loads(output) except Exception: return {} cpu_c: int | None = None pch_c: int | None = None for chip_name, chip_data in data.items(): if not isinstance(chip_data, dict): continue # CPU package temp — coretemp chip, "Package id N" sensor if chip_name.startswith("coretemp") and cpu_c is None: for sensor_name, sensor_vals in chip_data.items(): if not isinstance(sensor_vals, dict): continue if "package" in sensor_name.lower(): for k, v in sensor_vals.items(): if k.endswith("_input") and isinstance(v, (int, float)): cpu_c = int(round(v)) break if cpu_c is not None: break # PCH / chipset temp — manages PCIe lanes including HBA / storage I/O elif chip_name.startswith("pch_") and pch_c is None: for sensor_name, sensor_vals in chip_data.items(): if not isinstance(sensor_vals, dict): continue for k, v in sensor_vals.items(): if k.endswith("_input") and isinstance(v, (int, float)): pch_c = int(round(v)) break if pch_c is not None: break return {"cpu_c": cpu_c, "pch_c": pch_c} # --------------------------------------------------------------------------- # Parsers # --------------------------------------------------------------------------- def _parse_smartctl(output: str) -> dict: health = "UNKNOWN" attributes: dict[int, dict] = {} warnings: list[str] = [] failures: list[str] = [] m = re.search(r"self-assessment test result:\s+(\w+)", output, re.IGNORECASE) if m: health = m.group(1).upper() # Attribute table: ID# NAME FLAG VALUE WORST THRESH TYPE UPDATED WHEN_FAILED RAW_VALUE for line in output.splitlines(): am = re.match( r"\s*(\d+)\s+(\S+)\s+\S+\s+\d+\s+\d+\s+\d+\s+\S+\s+\S+\s+\S+\s+(\d+)", line, ) if not am: continue attr_id = int(am.group(1)) attr_name = am.group(2) raw_val = int(am.group(3)) attributes[attr_id] = {"name": attr_name, "raw": raw_val} if attr_id in SMART_ATTRS: _, is_critical = SMART_ATTRS[attr_id] if raw_val > 0: msg = f"{attr_name} = {raw_val}" if is_critical: failures.append(msg) else: warnings.append(msg) return { "health": health, "raw_output": output, "attributes": attributes, "warnings": warnings, "failures": failures, } def _parse_smart_progress(output: str) -> dict: state = "unknown" percent_remaining = None # None = "in progress but no % line parsed yet" lower = output.lower() if "self-test routine in progress" in lower or "self-test routine in progress" in output: state = "running" m = re.search(r"(\d+)%\s+of\s+test\s+remaining", output, re.IGNORECASE) if m: percent_remaining = int(m.group(1)) elif "completed without error" in lower: state = "passed" elif ( "completed: read failure" in lower or "completed: write failure" in lower or "aborted by host" in lower or ("completed" in lower and "failure" in lower) ): state = "failed" elif "in progress" in lower: state = "running" return { "state": state, "percent_remaining": percent_remaining, "output": output, }