Coverage for src / ai_shell / gpu.py: 80%

80 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 22:12 +0000

1"""NVIDIA GPU detection for Docker containers.""" 

2 

3import logging 

4import shutil 

5import subprocess 

6 

7logger = logging.getLogger(__name__) 

8 

9_MIB = 1024 * 1024 

10 

11 

def detect_gpu() -> bool:
    """Check if NVIDIA GPU and Docker GPU runtime are available.

    Returns True if both nvidia-smi succeeds and Docker has GPU support.
    Falls back to False (with debug logging in the helpers) if either
    check fails.
    """
    # Both probes must pass; `and` short-circuits so the Docker check is
    # skipped entirely when no GPU is visible to nvidia-smi.
    return _check_nvidia_smi() and _check_docker_gpu_runtime()

23 

24 

def _check_nvidia_smi() -> bool:
    """Check if nvidia-smi is available and reports a GPU."""
    # Cheap PATH probe first; avoids spawning a process that cannot exist.
    if shutil.which("nvidia-smi") is None:
        logger.debug("nvidia-smi not found in PATH")
        return False
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as exc:
        logger.debug("nvidia-smi check failed: %s", exc)
        return False
    names = proc.stdout.strip()
    if proc.returncode == 0 and names:
        # One line per GPU; log only the first one.
        logger.debug("GPU detected: %s", names.split("\n")[0])
        return True
    logger.debug("nvidia-smi returned no GPUs")
    return False

45 

46 

47def get_vram_info() -> dict[str, int] | None: 

48 """Query current GPU VRAM usage. 

49 

50 Returns dict with keys total/free/used in bytes, or None if unavailable. 

51 Uses the first GPU reported by nvidia-smi. 

52 """ 

53 if not shutil.which("nvidia-smi"): 

54 return None 

55 try: 

56 result = subprocess.run( 

57 [ 

58 "nvidia-smi", 

59 "--query-gpu=memory.total,memory.free,memory.used", 

60 "--format=csv,noheader,nounits", 

61 ], 

62 capture_output=True, 

63 text=True, 

64 timeout=10, 

65 ) 

66 if result.returncode != 0 or not result.stdout.strip(): 

67 return None 

68 line = result.stdout.strip().split("\n")[0] 

69 parts = line.split(",") 

70 if len(parts) != 3: 

71 return None 

72 total_mb, free_mb, used_mb = [int(p.strip()) for p in parts] 

73 return { 

74 "total": total_mb * _MIB, 

75 "free": free_mb * _MIB, 

76 "used": used_mb * _MIB, 

77 } 

78 except (subprocess.TimeoutExpired, FileNotFoundError, OSError, ValueError) as e: 

79 logger.debug("VRAM info query failed: %s", e) 

80 return None 

81 

82 

def get_vram_processes() -> list[tuple[int, int, str]]:
    """Query processes currently using GPU VRAM.

    Returns list of (pid, vram_mb, name) tuples, empty list if unavailable.
    """
    if shutil.which("nvidia-smi") is None:
        return []
    query = [
        "nvidia-smi",
        "--query-compute-apps=pid,used_gpu_memory,name",
        "--format=csv,noheader,nounits",
    ]
    try:
        proc = subprocess.run(query, capture_output=True, text=True, timeout=10)
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as exc:
        logger.debug("VRAM process query failed: %s", exc)
        return []
    output = proc.stdout.strip()
    if proc.returncode != 0 or not output:
        return []
    entries: list[tuple[int, int, str]] = []
    for row in output.split("\n"):
        cols = [c.strip() for c in row.split(",")]
        # Skip malformed rows (wrong field count or non-numeric pid/memory).
        if len(cols) != 3:
            continue
        try:
            entries.append((int(cols[0]), int(cols[1]), cols[2]))
        except ValueError:
            continue
    return entries

116 

117 

def _check_docker_gpu_runtime() -> bool:
    """Check if Docker supports GPU via nvidia runtime.

    Runs a single ``docker info --format '{{json .}}'`` query: the JSON
    serialization includes the Runtimes map, so the previous separate
    ``{{.Runtimes}}`` query was redundant (its positive result was always
    a subset of this one). Returns False on any failure.
    """
    docker_path = shutil.which("docker")
    if not docker_path:
        logger.debug("docker not found in PATH")
        return False
    try:
        # Use the resolved path so we invoke exactly the binary we probed,
        # instead of re-searching PATH with a bare "docker".
        result = subprocess.run(
            [docker_path, "info", "--format", "{{json .}}"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        # Matches either a "nvidia" runtime entry or any other nvidia
        # support surfaced by newer Docker versions in the info struct.
        if result.returncode == 0 and "nvidia" in result.stdout.lower():
            logger.debug("Docker nvidia support detected via docker info")
            return True
        logger.debug("Docker nvidia runtime not found")
        return False
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
        logger.debug("Docker GPU runtime check failed: %s", e)
        return False