Coverage for src / ai_shell / container.py: 94%

217 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 22:12 +0000

1"""Docker container lifecycle management. 

2 

3Replaces docker-compose.yml by using Docker SDK to create and manage containers 

4with the exact same configuration. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import subprocess 

11import sys 

12import time 

13from pathlib import Path 

14from typing import TYPE_CHECKING, NoReturn 

15 

16from docker.errors import APIError, ImageNotFound, NotFound 

17from docker.types import DeviceRequest, Mount 

18 

19import docker 

20from ai_shell.defaults import ( 

21 LLM_NETWORK, 

22 OLLAMA_CONTAINER, 

23 OLLAMA_CPU_SHARES, 

24 OLLAMA_DATA_VOLUME, 

25 OLLAMA_IMAGE, 

26 OLLAMA_VRAM_BUFFER_BYTES, 

27 SHM_SIZE, 

28 WEBUI_CONTAINER, 

29 WEBUI_DATA_VOLUME, 

30 WEBUI_IMAGE, 

31 build_dev_environment, 

32 build_dev_mounts, 

33 dev_container_name, 

34) 

35from ai_shell.exceptions import ( 

36 ContainerNotFoundError, 

37 DockerNotAvailableError, 

38 ImagePullError, 

39) 

40from ai_shell.gpu import detect_gpu, get_vram_info 

41 

42if TYPE_CHECKING: 

43 from docker.models.containers import Container 

44 

45 from ai_shell.config import AiShellConfig 

46 

47logger = logging.getLogger(__name__) 

48 

49 

def _exec_docker(args: list[str]) -> NoReturn:
    """Run a docker CLI command and terminate this process with its exit code.

    subprocess.run is used rather than os.execvp because on Windows execvp
    does not genuinely replace the current process, which breaks TTY handling.
    """
    logger.debug("exec: %s", " ".join(args))
    # Flush both standard streams so buffered output lands before docker's.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    completed = subprocess.run(args)
    sys.exit(completed.returncode)

61 

62 

def _run_docker(args: list[str]) -> tuple[int, float]:
    """Invoke a docker CLI command and report ``(exit_code, elapsed_seconds)``.

    In contrast to _exec_docker, control returns to the caller instead of
    the process exiting via sys.exit().
    """
    logger.debug("run: %s", " ".join(args))
    # Flush first so any pending output is ordered before the child's.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    started_at = time.monotonic()
    completed = subprocess.run(args)
    return completed.returncode, time.monotonic() - started_at

75 

76 

77class ContainerManager: 

78 """Manages Docker containers for ai-shell. 

79 

80 Handles the dev container (per-project) and LLM stack (host-level singletons). 

81 """ 

82 

83 def __init__(self, config: AiShellConfig) -> None: 

84 self.config = config 

85 try: 

86 self.client = docker.from_env() # type: ignore[attr-defined] 

87 self.client.ping() 

88 except docker.errors.DockerException as e: 

89 raise DockerNotAvailableError( 

90 f"Docker is not available. Is the Docker daemon running?\n Error: {e}" 

91 ) from e 

92 

93 # ========================================================================= 

94 # Dev container (per-project) 

95 # ========================================================================= 

96 

97 def resolve_dev_container(self) -> tuple[str, Container | None]: 

98 """Resolve the dev container, checking both current and legacy names. 

99 

100 Returns ``(name, container)`` where *container* is ``None`` when no 

101 matching container exists. When no container is found under either 

102 name, the current hash-based name is returned so callers can use it 

103 for creation. 

104 """ 

105 name = dev_container_name(self.config.project_name, self.config.project_dir) 

106 container = self._get_container(name) 

107 if container is not None: 

108 return name, container 

109 

110 legacy_name = dev_container_name(self.config.project_name) 

111 legacy_container = self._get_container(legacy_name) 

112 if legacy_container is not None and self._container_matches_project( 

113 legacy_container, self.config.project_dir 

114 ): 

115 return legacy_name, legacy_container 

116 

117 return name, None 

118 

119 def ensure_dev_container(self) -> str: 

120 """Get or create the dev container for the current project. 

121 

122 If the container exists but is stopped, it is started. 

123 If it doesn't exist, it is created with the full configuration. 

124 

125 Returns the container name. 

126 """ 

127 name, container = self.resolve_dev_container() 

128 

129 if container is not None: 

130 if container.status != "running": 

131 logger.info("Starting existing container: %s", name) 

132 container.start() 

133 return name 

134 

135 logger.info("Creating dev container: %s", name) 

136 self._pull_image_if_needed(self.config.full_image) 

137 self._create_dev_container(name) 

138 return name 

139 

140 def _create_dev_container(self, name: str) -> Container: 

141 """Create the dev container with all docker-compose config.""" 

142 mounts = build_dev_mounts(self.config.project_dir, self.config.project_name) 

143 environment = build_dev_environment( 

144 self.config.extra_env, 

145 self.config.project_dir, 

146 project_name=self.config.project_name, 

147 aws_profile=self.config.ai_profile, 

148 aws_region=self.config.aws_region, 

149 ) 

150 

151 # Add any extra volumes from config 

152 for vol_spec in self.config.extra_volumes: 

153 parts = vol_spec.split(":") 

154 if len(parts) >= 2: 

155 source, target = parts[0], parts[1] 

156 read_only = len(parts) > 2 and parts[2] == "ro" 

157 mounts.append( 

158 Mount( 

159 target=target, 

160 source=source, 

161 type="bind", 

162 read_only=read_only, 

163 ) 

164 ) 

165 

166 container: Container = self.client.containers.run( 

167 image=self.config.full_image, 

168 name=name, 

169 mounts=mounts, 

170 environment=environment, 

171 working_dir=f"/root/projects/{self.config.project_name}", 

172 command="tail -f /dev/null", 

173 stdin_open=True, 

174 tty=True, 

175 shm_size=SHM_SIZE, 

176 init=True, 

177 extra_hosts={"host.docker.internal": "host-gateway"}, 

178 ports={f"{port}/tcp": None for port in self.config.dev_ports}, 

179 detach=True, 

180 ) 

181 logger.info("Container created: %s", name) 

182 return container 

183 

184 def exec_interactive( 

185 self, 

186 container_name: str, 

187 command: list[str], 

188 extra_env: dict[str, str] | None = None, 

189 workdir: str | None = None, 

190 ) -> NoReturn: 

191 """Execute an interactive command in a container. 

192 

193 Uses subprocess.run for cross-platform TTY compatibility. 

194 Detects whether stdin is a TTY to decide on -i/-t flags. 

195 If *workdir* is given it is passed as ``-w`` to ``docker exec``. 

196 """ 

197 args = ["docker", "exec"] 

198 

199 if sys.stdin.isatty(): 

200 args.append("-it") 

201 

202 if workdir: 

203 args.extend(["-w", workdir]) 

204 

205 if extra_env: 

206 for key, value in extra_env.items(): 

207 args.extend(["-e", f"{key}={value}"]) 

208 

209 args.append(container_name) 

210 args.extend(command) 

211 

212 _exec_docker(args) 

213 

214 def run_interactive( 

215 self, 

216 container_name: str, 

217 command: list[str], 

218 extra_env: dict[str, str] | None = None, 

219 workdir: str | None = None, 

220 ) -> tuple[int, float]: 

221 """Execute an interactive command, returning (exit_code, elapsed_seconds). 

222 

223 Same as exec_interactive but does not call sys.exit(). 

224 Used for retry logic (e.g., claude -c fallback). 

225 If *workdir* is given it is passed as ``-w`` to ``docker exec``. 

226 """ 

227 args = ["docker", "exec"] 

228 

229 if sys.stdin.isatty(): 

230 args.append("-it") 

231 

232 if workdir: 

233 args.extend(["-w", workdir]) 

234 

235 if extra_env: 

236 for key, value in extra_env.items(): 

237 args.extend(["-e", f"{key}={value}"]) 

238 

239 args.append(container_name) 

240 args.extend(command) 

241 

242 return _run_docker(args) 

243 

244 # ========================================================================= 

245 # LLM stack (host-level singletons) 

246 # ========================================================================= 

247 

248 def _ensure_llm_network(self) -> str: 

249 """Get or create the shared Docker network for the LLM stack.""" 

250 try: 

251 self.client.networks.get(LLM_NETWORK) 

252 except NotFound: 

253 logger.info("Creating LLM network: %s", LLM_NETWORK) 

254 self.client.networks.create(LLM_NETWORK, driver="bridge") 

255 return LLM_NETWORK 

256 

257 def ensure_ollama(self) -> str: 

258 """Get or create the Ollama container with GPU auto-detection. 

259 

260 Returns the container name. 

261 """ 

262 container = self._get_container(OLLAMA_CONTAINER) 

263 

264 if container is not None: 

265 if container.status != "running": 

266 logger.info("Starting existing Ollama container") 

267 container.start() 

268 return OLLAMA_CONTAINER 

269 

270 logger.info("Creating Ollama container") 

271 self._pull_image_if_needed(OLLAMA_IMAGE) 

272 network_name = self._ensure_llm_network() 

273 

274 # GPU auto-detection 

275 gpu_available = detect_gpu() 

276 device_requests = None 

277 env: dict[str, str] = {} 

278 if gpu_available: 

279 device_requests = [DeviceRequest(count=1, capabilities=[["gpu"]])] 

280 vram = get_vram_info() 

281 if vram: 

282 overhead = vram["used"] + OLLAMA_VRAM_BUFFER_BYTES 

283 env["OLLAMA_GPU_OVERHEAD"] = str(overhead) 

284 logger.info( 

285 "VRAM: %.1f GiB total, %.1f GiB free. Reserving %.1f GiB overhead for Ollama.", 

286 vram["total"] / 1024**3, 

287 vram["free"] / 1024**3, 

288 overhead / 1024**3, 

289 ) 

290 else: 

291 logger.info("GPU detected - Ollama will use NVIDIA GPU") 

292 else: 

293 logger.warning("No GPU detected - Ollama will run on CPU (slower inference)") 

294 

295 kwargs: dict = { 

296 "image": OLLAMA_IMAGE, 

297 "name": OLLAMA_CONTAINER, 

298 "ports": {"11434/tcp": ("0.0.0.0", self.config.ollama_port)}, # nosec B104 

299 "mounts": [ 

300 Mount( 

301 target="/root/.ollama", 

302 source=OLLAMA_DATA_VOLUME, 

303 type="volume", 

304 ) 

305 ], 

306 "restart_policy": {"Name": "unless-stopped"}, 

307 "detach": True, 

308 "network": network_name, 

309 "cpu_shares": OLLAMA_CPU_SHARES, 

310 } 

311 

312 if device_requests: 

313 kwargs["device_requests"] = device_requests 

314 if env: 

315 kwargs["environment"] = env 

316 

317 self.client.containers.run(**kwargs) 

318 logger.info("Ollama container created on port %d", self.config.ollama_port) 

319 return OLLAMA_CONTAINER 

320 

321 def ensure_webui(self) -> str: 

322 """Get or create the Open WebUI container. 

323 

324 Returns the container name. 

325 """ 

326 container = self._get_container(WEBUI_CONTAINER) 

327 

328 if container is not None: 

329 if container.status != "running": 

330 logger.info("Starting existing WebUI container") 

331 container.start() 

332 return WEBUI_CONTAINER 

333 

334 logger.info("Creating Open WebUI container") 

335 self._pull_image_if_needed(WEBUI_IMAGE) 

336 network_name = self._ensure_llm_network() 

337 

338 self.client.containers.run( 

339 image=WEBUI_IMAGE, 

340 name=WEBUI_CONTAINER, 

341 ports={"8080/tcp": ("0.0.0.0", self.config.webui_port)}, # nosec B104 

342 environment={ 

343 "OLLAMA_BASE_URL": f"http://{OLLAMA_CONTAINER}:11434", 

344 "WEBUI_AUTH": "false", 

345 }, 

346 mounts=[ 

347 Mount( 

348 target="/app/backend/data", 

349 source=WEBUI_DATA_VOLUME, 

350 type="volume", 

351 ) 

352 ], 

353 restart_policy={"Name": "unless-stopped"}, 

354 detach=True, 

355 network=network_name, 

356 ) 

357 

358 logger.info("Open WebUI container created on port %d", self.config.webui_port) 

359 return WEBUI_CONTAINER 

360 

361 def exec_in_ollama(self, command: list[str]) -> str: 

362 """Run a command in the Ollama container and return stdout. 

363 

364 Used for: ollama pull, ollama list, ollama create. 

365 """ 

366 container = self._get_container(OLLAMA_CONTAINER) 

367 if container is None or container.status != "running": 

368 raise ContainerNotFoundError(OLLAMA_CONTAINER) 

369 

370 exit_code, output = container.exec_run( 

371 cmd=command, 

372 stdout=True, 

373 stderr=True, 

374 ) 

375 decoded: str = output.decode("utf-8", errors="replace") 

376 if exit_code != 0: 

377 logger.error("Command failed in ollama: %s\n%s", " ".join(command), decoded) 

378 return decoded 

379 

380 # ========================================================================= 

381 # Container lifecycle 

382 # ========================================================================= 

383 

384 def stop_container(self, name: str) -> None: 

385 """Stop a container by name.""" 

386 container = self._get_container(name) 

387 if container is None: 

388 raise ContainerNotFoundError(name) 

389 if container.status == "running": 

390 container.stop() 

391 logger.info("Stopped container: %s", name) 

392 

393 def remove_container(self, name: str) -> None: 

394 """Remove a container by name, stopping it first if running.""" 

395 container = self._get_container(name) 

396 if container is None: 

397 raise ContainerNotFoundError(name) 

398 if container.status == "running": 

399 container.stop() 

400 logger.info("Stopped container: %s", name) 

401 container.remove() 

402 logger.info("Removed container: %s", name) 

403 

404 def container_ports(self, name: str) -> dict[str, str] | None: 

405 """Get the port mappings for a container. 

406 

407 Returns a dict mapping container ports (e.g. '3000/tcp') to host 

408 addresses (e.g. '0.0.0.0:49152'), or None if the container doesn't exist. 

409 """ 

410 container = self._get_container(name) 

411 if container is None: 

412 return None 

413 container.reload() 

414 ports_data = container.attrs.get("NetworkSettings", {}).get("Ports") or {} 

415 result: dict[str, str] = {} 

416 for container_port, bindings in sorted(ports_data.items()): 

417 if bindings: 

418 binding = bindings[0] 

419 result[container_port] = f"{binding['HostIp']}:{binding['HostPort']}" 

420 return result 

421 

422 def container_status(self, name: str) -> str | None: 

423 """Get the status of a container, or None if it doesn't exist.""" 

424 container = self._get_container(name) 

425 if container is None: 

426 return None 

427 return container.status # type: ignore[no-any-return] 

428 

429 def container_logs(self, name: str, follow: bool = False, tail: int = 100) -> None: 

430 """Print container logs. If follow=True, streams via docker CLI.""" 

431 if follow: 

432 # Use docker CLI for streaming 

433 args = ["docker", "logs", "-f", name] 

434 _exec_docker(args) 

435 else: 

436 container = self._get_container(name) 

437 if container is None: 

438 raise ContainerNotFoundError(name) 

439 logs = container.logs(tail=tail).decode("utf-8", errors="replace") 

440 print(logs) 

441 

442 # ========================================================================= 

443 # Internal helpers 

444 # ========================================================================= 

445 

446 def _get_container(self, name: str) -> Container | None: 

447 """Get a container by name, or None if it doesn't exist.""" 

448 try: 

449 return self.client.containers.get(name) 

450 except NotFound: 

451 return None 

452 

453 def _container_matches_project(self, container: Container, project_dir: Path) -> bool: 

454 """Check whether a container's project mount matches *project_dir*.""" 

455 resolved_project_dir = str(project_dir.resolve()) 

456 mounts = container.attrs.get("Mounts", []) 

457 for mount in mounts: 

458 if mount.get("Source") == resolved_project_dir: 

459 return True 

460 return False 

461 

462 def _pull_image_if_needed(self, image: str) -> None: 

463 """Pull a Docker image if not available locally.""" 

464 try: 

465 self.client.images.get(image) 

466 logger.debug("Image already available: %s", image) 

467 except ImageNotFound: 

468 logger.info("Pulling image: %s (this may take a while)...", image) 

469 try: 

470 self.client.images.pull(*image.rsplit(":", 1)) 

471 logger.info("Image pulled: %s", image) 

472 except APIError as e: 

473 raise ImagePullError(image, str(e)) from e