Coverage for src/ai_shell/container.py: 94%
217 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 22:12 +0000
1"""Docker container lifecycle management.
3Replaces docker-compose.yml by using Docker SDK to create and manage containers
4with the exact same configuration.
5"""
7from __future__ import annotations
9import logging
10import subprocess
11import sys
12import time
13from pathlib import Path
14from typing import TYPE_CHECKING, NoReturn
16from docker.errors import APIError, ImageNotFound, NotFound
17from docker.types import DeviceRequest, Mount
19import docker
20from ai_shell.defaults import (
21 LLM_NETWORK,
22 OLLAMA_CONTAINER,
23 OLLAMA_CPU_SHARES,
24 OLLAMA_DATA_VOLUME,
25 OLLAMA_IMAGE,
26 OLLAMA_VRAM_BUFFER_BYTES,
27 SHM_SIZE,
28 WEBUI_CONTAINER,
29 WEBUI_DATA_VOLUME,
30 WEBUI_IMAGE,
31 build_dev_environment,
32 build_dev_mounts,
33 dev_container_name,
34)
35from ai_shell.exceptions import (
36 ContainerNotFoundError,
37 DockerNotAvailableError,
38 ImagePullError,
39)
40from ai_shell.gpu import detect_gpu, get_vram_info
42if TYPE_CHECKING:
43 from docker.models.containers import Container
45 from ai_shell.config import AiShellConfig
47logger = logging.getLogger(__name__)
50def _exec_docker(args: list[str]) -> NoReturn:
51 """Execute a docker CLI command with cross-platform TTY support.
53 Uses subprocess.run instead of os.execvp for Windows compatibility.
54 On Windows, os.execvp doesn't truly replace the process, causing TTY issues.
55 """
56 logger.debug("exec: %s", " ".join(args))
57 sys.stdout.flush()
58 sys.stderr.flush()
59 result = subprocess.run(args)
60 sys.exit(result.returncode)
63def _run_docker(args: list[str]) -> tuple[int, float]:
64 """Run a docker CLI command and return (exit_code, elapsed_seconds).
66 Unlike _exec_docker, this does NOT call sys.exit().
67 """
68 logger.debug("run: %s", " ".join(args))
69 sys.stdout.flush()
70 sys.stderr.flush()
71 start = time.monotonic()
72 result = subprocess.run(args)
73 elapsed = time.monotonic() - start
74 return result.returncode, elapsed
class ContainerManager:
    """Manages Docker containers for ai-shell.

    Handles the dev container (per-project) and the LLM stack
    (host-level singletons: Ollama and Open WebUI).
    """

    def __init__(self, config: AiShellConfig) -> None:
        """Connect to the local Docker daemon and verify it is reachable.

        Raises:
            DockerNotAvailableError: if the daemon cannot be contacted.
        """
        self.config = config
        try:
            self.client = docker.from_env()  # type: ignore[attr-defined]
            self.client.ping()
        except docker.errors.DockerException as e:
            raise DockerNotAvailableError(
                f"Docker is not available. Is the Docker daemon running?\n Error: {e}"
            ) from e

    # =========================================================================
    # Dev container (per-project)
    # =========================================================================

    def resolve_dev_container(self) -> tuple[str, Container | None]:
        """Resolve the dev container, checking both current and legacy names.

        Returns ``(name, container)`` where *container* is ``None`` when no
        matching container exists. When no container is found under either
        name, the current hash-based name is returned so callers can use it
        for creation.
        """
        name = dev_container_name(self.config.project_name, self.config.project_dir)
        container = self._get_container(name)
        if container is not None:
            return name, container

        # Fall back to the legacy (project-name only) container, but only if
        # its project mount actually points at this project directory.
        legacy_name = dev_container_name(self.config.project_name)
        legacy_container = self._get_container(legacy_name)
        if legacy_container is not None and self._container_matches_project(
            legacy_container, self.config.project_dir
        ):
            return legacy_name, legacy_container

        return name, None

    def ensure_dev_container(self) -> str:
        """Get or create the dev container for the current project.

        If the container exists but is stopped, it is started.
        If it doesn't exist, it is created with the full configuration.

        Returns the container name.
        """
        name, container = self.resolve_dev_container()

        if container is not None:
            if container.status != "running":
                logger.info("Starting existing container: %s", name)
                container.start()
            return name

        logger.info("Creating dev container: %s", name)
        self._pull_image_if_needed(self.config.full_image)
        self._create_dev_container(name)
        return name

    def _create_dev_container(self, name: str) -> Container:
        """Create the dev container with all docker-compose config.

        Mounts, environment, working dir, shm size, init process, published
        ports and the host-gateway alias mirror the retired docker-compose.yml.
        """
        mounts = build_dev_mounts(self.config.project_dir, self.config.project_name)
        environment = build_dev_environment(
            self.config.extra_env,
            self.config.project_dir,
            project_name=self.config.project_name,
            aws_profile=self.config.ai_profile,
            aws_region=self.config.aws_region,
        )

        # Add any extra volumes from config. Specs are "source:target[:ro]".
        # NOTE(review): split(":") would misparse Windows drive-letter sources
        # like "C:\\x:/y" — confirm extra_volumes is POSIX-path only.
        for vol_spec in self.config.extra_volumes:
            parts = vol_spec.split(":")
            if len(parts) >= 2:
                source, target = parts[0], parts[1]
                read_only = len(parts) > 2 and parts[2] == "ro"
                mounts.append(
                    Mount(
                        target=target,
                        source=source,
                        type="bind",
                        read_only=read_only,
                    )
                )

        container: Container = self.client.containers.run(
            image=self.config.full_image,
            name=name,
            mounts=mounts,
            environment=environment,
            working_dir=f"/root/projects/{self.config.project_name}",
            # Keep the container alive indefinitely; work happens via exec.
            command="tail -f /dev/null",
            stdin_open=True,
            tty=True,
            shm_size=SHM_SIZE,
            init=True,
            extra_hosts={"host.docker.internal": "host-gateway"},
            # None binding lets Docker pick ephemeral host ports.
            ports={f"{port}/tcp": None for port in self.config.dev_ports},
            detach=True,
        )
        logger.info("Container created: %s", name)
        return container

    def _build_exec_args(
        self,
        container_name: str,
        command: list[str],
        extra_env: dict[str, str] | None,
        workdir: str | None,
    ) -> list[str]:
        """Build the ``docker exec`` argv shared by exec_/run_interactive.

        Adds ``-it`` only when stdin is a TTY, ``-w`` when *workdir* is given,
        and one ``-e KEY=VALUE`` per *extra_env* entry.
        """
        args = ["docker", "exec"]

        if sys.stdin.isatty():
            args.append("-it")

        if workdir:
            args.extend(["-w", workdir])

        if extra_env:
            for key, value in extra_env.items():
                args.extend(["-e", f"{key}={value}"])

        args.append(container_name)
        args.extend(command)
        return args

    def exec_interactive(
        self,
        container_name: str,
        command: list[str],
        extra_env: dict[str, str] | None = None,
        workdir: str | None = None,
    ) -> NoReturn:
        """Execute an interactive command in a container and exit with its code.

        Uses subprocess.run for cross-platform TTY compatibility.
        Detects whether stdin is a TTY to decide on -i/-t flags.
        If *workdir* is given it is passed as ``-w`` to ``docker exec``.
        """
        _exec_docker(
            self._build_exec_args(container_name, command, extra_env, workdir)
        )

    def run_interactive(
        self,
        container_name: str,
        command: list[str],
        extra_env: dict[str, str] | None = None,
        workdir: str | None = None,
    ) -> tuple[int, float]:
        """Execute an interactive command, returning (exit_code, elapsed_seconds).

        Same as exec_interactive but does not call sys.exit().
        Used for retry logic (e.g., claude -c fallback).
        If *workdir* is given it is passed as ``-w`` to ``docker exec``.
        """
        return _run_docker(
            self._build_exec_args(container_name, command, extra_env, workdir)
        )

    # =========================================================================
    # LLM stack (host-level singletons)
    # =========================================================================

    def _ensure_llm_network(self) -> str:
        """Get or create the shared Docker bridge network for the LLM stack."""
        try:
            self.client.networks.get(LLM_NETWORK)
        except NotFound:
            logger.info("Creating LLM network: %s", LLM_NETWORK)
            self.client.networks.create(LLM_NETWORK, driver="bridge")
        return LLM_NETWORK

    def ensure_ollama(self) -> str:
        """Get or create the Ollama container with GPU auto-detection.

        When a GPU is detected, one GPU is requested and OLLAMA_GPU_OVERHEAD
        is set from current VRAM usage plus a safety buffer; otherwise Ollama
        runs on CPU.

        Returns the container name.
        """
        container = self._get_container(OLLAMA_CONTAINER)

        if container is not None:
            if container.status != "running":
                logger.info("Starting existing Ollama container")
                container.start()
            return OLLAMA_CONTAINER

        logger.info("Creating Ollama container")
        self._pull_image_if_needed(OLLAMA_IMAGE)
        network_name = self._ensure_llm_network()

        # GPU auto-detection
        gpu_available = detect_gpu()
        device_requests = None
        env: dict[str, str] = {}
        if gpu_available:
            device_requests = [DeviceRequest(count=1, capabilities=[["gpu"]])]
            vram = get_vram_info()
            if vram:
                # Reserve already-used VRAM plus a buffer so Ollama does not
                # oversubscribe the GPU.
                overhead = vram["used"] + OLLAMA_VRAM_BUFFER_BYTES
                env["OLLAMA_GPU_OVERHEAD"] = str(overhead)
                logger.info(
                    "VRAM: %.1f GiB total, %.1f GiB free. Reserving %.1f GiB overhead for Ollama.",
                    vram["total"] / 1024**3,
                    vram["free"] / 1024**3,
                    overhead / 1024**3,
                )
            else:
                logger.info("GPU detected - Ollama will use NVIDIA GPU")
        else:
            logger.warning("No GPU detected - Ollama will run on CPU (slower inference)")

        kwargs: dict = {
            "image": OLLAMA_IMAGE,
            "name": OLLAMA_CONTAINER,
            "ports": {"11434/tcp": ("0.0.0.0", self.config.ollama_port)},  # nosec B104
            "mounts": [
                Mount(
                    target="/root/.ollama",
                    source=OLLAMA_DATA_VOLUME,
                    type="volume",
                )
            ],
            "restart_policy": {"Name": "unless-stopped"},
            "detach": True,
            "network": network_name,
            "cpu_shares": OLLAMA_CPU_SHARES,
        }

        if device_requests:
            kwargs["device_requests"] = device_requests
        if env:
            kwargs["environment"] = env

        self.client.containers.run(**kwargs)
        logger.info("Ollama container created on port %d", self.config.ollama_port)
        return OLLAMA_CONTAINER

    def ensure_webui(self) -> str:
        """Get or create the Open WebUI container.

        The container joins the LLM network and is pointed at the Ollama
        container by name; authentication is disabled.

        Returns the container name.
        """
        container = self._get_container(WEBUI_CONTAINER)

        if container is not None:
            if container.status != "running":
                logger.info("Starting existing WebUI container")
                container.start()
            return WEBUI_CONTAINER

        logger.info("Creating Open WebUI container")
        self._pull_image_if_needed(WEBUI_IMAGE)
        network_name = self._ensure_llm_network()

        self.client.containers.run(
            image=WEBUI_IMAGE,
            name=WEBUI_CONTAINER,
            ports={"8080/tcp": ("0.0.0.0", self.config.webui_port)},  # nosec B104
            environment={
                "OLLAMA_BASE_URL": f"http://{OLLAMA_CONTAINER}:11434",
                "WEBUI_AUTH": "false",
            },
            mounts=[
                Mount(
                    target="/app/backend/data",
                    source=WEBUI_DATA_VOLUME,
                    type="volume",
                )
            ],
            restart_policy={"Name": "unless-stopped"},
            detach=True,
            network=network_name,
        )

        logger.info("Open WebUI container created on port %d", self.config.webui_port)
        return WEBUI_CONTAINER

    def exec_in_ollama(self, command: list[str]) -> str:
        """Run a command in the Ollama container and return combined output.

        Used for: ollama pull, ollama list, ollama create.

        Raises:
            ContainerNotFoundError: if the Ollama container is absent or stopped.
        """
        container = self._get_container(OLLAMA_CONTAINER)
        if container is None or container.status != "running":
            raise ContainerNotFoundError(OLLAMA_CONTAINER)

        exit_code, output = container.exec_run(
            cmd=command,
            stdout=True,
            stderr=True,
        )
        decoded: str = output.decode("utf-8", errors="replace")
        # Failures are logged but the output is still returned so callers can
        # surface it.
        if exit_code != 0:
            logger.error("Command failed in ollama: %s\n%s", " ".join(command), decoded)
        return decoded

    # =========================================================================
    # Container lifecycle
    # =========================================================================

    def stop_container(self, name: str) -> None:
        """Stop a container by name.

        Raises:
            ContainerNotFoundError: if no container has this name.
        """
        container = self._get_container(name)
        if container is None:
            raise ContainerNotFoundError(name)
        if container.status == "running":
            container.stop()
            logger.info("Stopped container: %s", name)

    def remove_container(self, name: str) -> None:
        """Remove a container by name, stopping it first if running.

        Raises:
            ContainerNotFoundError: if no container has this name.
        """
        container = self._get_container(name)
        if container is None:
            raise ContainerNotFoundError(name)
        if container.status == "running":
            container.stop()
            logger.info("Stopped container: %s", name)
        container.remove()
        logger.info("Removed container: %s", name)

    def container_ports(self, name: str) -> dict[str, str] | None:
        """Get the port mappings for a container.

        Returns a dict mapping container ports (e.g. '3000/tcp') to host
        addresses (e.g. '0.0.0.0:49152'), or None if the container doesn't exist.
        """
        container = self._get_container(name)
        if container is None:
            return None
        # Reload so published ephemeral ports reflect current daemon state.
        container.reload()
        ports_data = container.attrs.get("NetworkSettings", {}).get("Ports") or {}
        result: dict[str, str] = {}
        for container_port, bindings in sorted(ports_data.items()):
            if bindings:
                # Only the first binding per container port is reported.
                binding = bindings[0]
                result[container_port] = f"{binding['HostIp']}:{binding['HostPort']}"
        return result

    def container_status(self, name: str) -> str | None:
        """Get the status of a container, or None if it doesn't exist."""
        container = self._get_container(name)
        if container is None:
            return None
        return container.status  # type: ignore[no-any-return]

    def container_logs(self, name: str, follow: bool = False, tail: int = 100) -> None:
        """Print container logs. If follow=True, streams via docker CLI.

        The follow path never returns (it exits with the CLI's status).

        Raises:
            ContainerNotFoundError: if follow=False and the container is absent.
        """
        if follow:
            # Use docker CLI for streaming
            args = ["docker", "logs", "-f", name]
            _exec_docker(args)
        else:
            container = self._get_container(name)
            if container is None:
                raise ContainerNotFoundError(name)
            logs = container.logs(tail=tail).decode("utf-8", errors="replace")
            print(logs)

    # =========================================================================
    # Internal helpers
    # =========================================================================

    def _get_container(self, name: str) -> Container | None:
        """Get a container by name, or None if it doesn't exist."""
        try:
            return self.client.containers.get(name)
        except NotFound:
            return None

    def _container_matches_project(self, container: Container, project_dir: Path) -> bool:
        """Check whether a container's project mount matches *project_dir*."""
        resolved_project_dir = str(project_dir.resolve())
        mounts = container.attrs.get("Mounts", [])
        for mount in mounts:
            if mount.get("Source") == resolved_project_dir:
                return True
        return False

    def _pull_image_if_needed(self, image: str) -> None:
        """Pull a Docker image if not available locally.

        Raises:
            ImagePullError: if the pull fails.
        """
        try:
            self.client.images.get(image)
            logger.debug("Image already available: %s", image)
        except ImageNotFound:
            logger.info("Pulling image: %s (this may take a while)...", image)
            try:
                # Split repository/tag only when the text after the last ":"
                # is really a tag (contains no "/"). A bare rsplit(":", 1)
                # would mangle untagged images on a registry with a port,
                # e.g. "registry:5000/foo" -> repo "registry", tag "5000/foo".
                repository, sep, tag = image.rpartition(":")
                if sep and "/" not in tag:
                    self.client.images.pull(repository, tag=tag)
                else:
                    self.client.images.pull(image)
                logger.info("Image pulled: %s", image)
            except APIError as e:
                raise ImagePullError(image, str(e)) from e