Coverage for src / ai_shell / gpu.py: 80%
80 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 22:12 +0000
1"""NVIDIA GPU detection for Docker containers."""
3import logging
4import shutil
5import subprocess
7logger = logging.getLogger(__name__)
9_MIB = 1024 * 1024
def detect_gpu() -> bool:
    """Check if NVIDIA GPU and Docker GPU runtime are available.

    Returns True only when both nvidia-smi reports a GPU and Docker
    exposes GPU support; False (with debug logging in the helpers)
    as soon as either check fails.
    """
    return _check_nvidia_smi() and _check_docker_gpu_runtime()
def _check_nvidia_smi() -> bool:
    """Return True when nvidia-smi is on PATH and lists at least one GPU."""
    if shutil.which("nvidia-smi") is None:
        logger.debug("nvidia-smi not found in PATH")
        return False
    try:
        proc = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as exc:
        logger.debug("nvidia-smi check failed: %s", exc)
        return False
    names = proc.stdout.strip()
    if proc.returncode == 0 and names:
        # Log only the first GPU; nvidia-smi prints one name per line.
        logger.debug("GPU detected: %s", names.split("\n")[0])
        return True
    logger.debug("nvidia-smi returned no GPUs")
    return False
47def get_vram_info() -> dict[str, int] | None:
48 """Query current GPU VRAM usage.
50 Returns dict with keys total/free/used in bytes, or None if unavailable.
51 Uses the first GPU reported by nvidia-smi.
52 """
53 if not shutil.which("nvidia-smi"):
54 return None
55 try:
56 result = subprocess.run(
57 [
58 "nvidia-smi",
59 "--query-gpu=memory.total,memory.free,memory.used",
60 "--format=csv,noheader,nounits",
61 ],
62 capture_output=True,
63 text=True,
64 timeout=10,
65 )
66 if result.returncode != 0 or not result.stdout.strip():
67 return None
68 line = result.stdout.strip().split("\n")[0]
69 parts = line.split(",")
70 if len(parts) != 3:
71 return None
72 total_mb, free_mb, used_mb = [int(p.strip()) for p in parts]
73 return {
74 "total": total_mb * _MIB,
75 "free": free_mb * _MIB,
76 "used": used_mb * _MIB,
77 }
78 except (subprocess.TimeoutExpired, FileNotFoundError, OSError, ValueError) as e:
79 logger.debug("VRAM info query failed: %s", e)
80 return None
def get_vram_processes() -> list[tuple[int, int, str]]:
    """Query processes currently using GPU VRAM.

    Returns list of (pid, vram_mb, name) tuples, empty list if unavailable
    (nvidia-smi missing, erroring, timing out, or reporting no processes).
    """
    if not shutil.which("nvidia-smi"):
        return []
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-compute-apps=pid,used_gpu_memory,name",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if result.returncode != 0 or not result.stdout.strip():
            return []
        processes = []
        for line in result.stdout.strip().split("\n"):
            # Split at most twice so a process name (an executable path)
            # containing commas is kept intact instead of being dropped
            # by the field-count check below.
            parts = line.split(",", 2)
            if len(parts) != 3:
                continue
            try:
                processes.append((int(parts[0].strip()), int(parts[1].strip()), parts[2].strip()))
            except ValueError:
                # Malformed pid/memory field (e.g. "[N/A]"); skip the row.
                continue
        return processes
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
        logger.debug("VRAM process query failed: %s", e)
        return []
def _check_docker_gpu_runtime() -> bool:
    """Check if Docker supports GPU via nvidia runtime."""
    if shutil.which("docker") is None:
        logger.debug("docker not found in PATH")
        return False

    def _docker_info(fmt: str) -> subprocess.CompletedProcess:
        # Both probes are `docker info` with a different --format template.
        return subprocess.run(
            ["docker", "info", "--format", fmt],
            capture_output=True,
            text=True,
            timeout=10,
        )

    try:
        runtimes = _docker_info("{{.Runtimes}}")
        if runtimes.returncode == 0 and "nvidia" in runtimes.stdout.lower():
            logger.debug("Docker nvidia runtime available")
            return True
        # Also check for default GPU support (newer Docker versions)
        full_info = _docker_info("{{json .}}")
        if full_info.returncode == 0 and "nvidia" in full_info.stdout.lower():
            logger.debug("Docker nvidia support detected via docker info")
            return True
        logger.debug("Docker nvidia runtime not found")
        return False
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as exc:
        logger.debug("Docker GPU runtime check failed: %s", exc)
        return False