Coverage for src / ai_shell / config.py: 96%
136 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 22:12 +0000
"""Configuration loading for ai-shell.

Priority (highest wins): CLI flags > env vars > project config > global config > defaults.

Config file lookup order (first match wins):
    .ai-shell.yaml > .ai-shell.yml > .ai-shell.toml > ai-shell.toml
"""
9from __future__ import annotations
11import logging
12import os
13import tomllib
14from dataclasses import dataclass, field
15from pathlib import Path
17import yaml
19from ai_shell import __version__
20from ai_shell.defaults import (
21 DEFAULT_CONTEXT_SIZE,
22 DEFAULT_DEV_PORTS,
23 DEFAULT_FALLBACK_MODEL,
24 DEFAULT_IMAGE,
25 DEFAULT_OLLAMA_PORT,
26 DEFAULT_PRIMARY_MODEL,
27 DEFAULT_WEBUI_PORT,
28)
30logger = logging.getLogger(__name__)
@dataclass
class AiShellConfig:
    """Effective ai-shell configuration, merged from defaults, files, env, and CLI."""

    # --- Container ---
    image: str = DEFAULT_IMAGE
    image_tag: str = __version__
    project_name: str = ""
    project_dir: Path = field(default_factory=Path.cwd)

    # --- LLM ---
    primary_model: str = DEFAULT_PRIMARY_MODEL
    fallback_model: str = DEFAULT_FALLBACK_MODEL
    context_size: int = DEFAULT_CONTEXT_SIZE
    ollama_port: int = DEFAULT_OLLAMA_PORT
    webui_port: int = DEFAULT_WEBUI_PORT

    # --- Aider ---
    aider_model: str = f"ollama_chat/{DEFAULT_PRIMARY_MODEL}"

    # --- Extra container configuration ---
    extra_env: dict[str, str] = field(default_factory=dict)
    extra_volumes: list[str] = field(default_factory=list)
    extra_ports: list[int] = field(default_factory=list)

    # --- AWS ---
    ai_profile: str = ""  # AWS profile for infra (sets AWS_PROFILE in container)
    aws_region: str = ""  # Override AWS_REGION
    bedrock_profile: str = ""  # AWS profile for Bedrock LLM API calls

    # --- Claude options ---
    local_chrome: bool = False  # Attach Chrome DevTools MCP to host Chrome debug port

    # --- Per-tool provider selection ---
    claude_provider: str = ""  # "anthropic" (default) or "aws"
    opencode_provider: str = ""  # "local" (default, Ollama) or "aws" (Bedrock)
    codex_provider: str = ""  # "openai" (default) or "aws" (Bedrock)
    codex_openai_api_key: str = ""  # OpenAI API key (if set, overrides mounted SSO auth)
    codex_profile: str = ""  # AWS profile for Bedrock auth (when provider = "aws")

    @property
    def full_image(self) -> str:
        """Full image reference including the tag, e.g. ``repo/image:1.2.3``."""
        return f"{self.image}:{self.image_tag}"

    @property
    def dev_ports(self) -> list[int]:
        """All dev container ports to expose: defaults plus extras, deduplicated and sorted."""
        merged = set(DEFAULT_DEV_PORTS)
        merged.update(self.extra_ports)
        return sorted(merged)
def load_config(
    project_override: str | None = None,
    project_dir: Path | None = None,
) -> AiShellConfig:
    """Build the effective configuration by merging every source.

    Priority: CLI overrides > env vars > project toml > global toml > defaults.
    """
    config = AiShellConfig()

    if project_dir:
        config.project_dir = project_dir

    # Global config lives under ~/.config/ai-shell; first existing file wins.
    global_candidates = (
        Path.home() / ".config" / "ai-shell" / name
        for name in ("config.yaml", "config.yml", "config.toml")
    )
    for candidate in global_candidates:
        if candidate.exists():
            _apply_config(config, candidate)
            break

    # Project-local config (first match wins); applied after global so it overrides.
    project_candidates = (
        config.project_dir / name
        for name in (".ai-shell.yaml", ".ai-shell.yml", ".ai-shell.toml", "ai-shell.toml")
    )
    for candidate in project_candidates:
        if candidate.exists():
            _apply_config(config, candidate)
            break

    # Environment variables beat anything read from files.
    _apply_env_vars(config)

    # An explicit CLI flag beats everything else.
    if project_override:
        config.project_name = project_override

    # Fall back to a name derived from the project directory.
    if not config.project_name:
        from ai_shell.defaults import sanitize_project_name

        config.project_name = sanitize_project_name(config.project_dir)

    return config
128def _load_config_file(path: Path) -> dict:
129 """Load a YAML or TOML config file and return the parsed dict."""
130 suffix = path.suffix.lower()
131 if suffix in (".yaml", ".yml"):
132 with open(path, encoding="utf-8") as f:
133 return yaml.safe_load(f) or {}
134 with open(path, "rb") as f:
135 return tomllib.load(f)
def _apply_config(config: AiShellConfig, path: Path) -> None:
    """Apply settings from a YAML or TOML config file."""
    try:
        data = _load_config_file(path)
    except (OSError, tomllib.TOMLDecodeError, yaml.YAMLError) as exc:
        logger.warning("Failed to load config from %s: %s", path, exc)
        return

    logger.debug("Loading config from %s", path)

    def copy_key(section: dict, key: str, attr: str, cast=None) -> None:
        # Copy section[key] onto config.<attr> when present, optionally coercing.
        if key in section:
            value = section[key]
            setattr(config, attr, cast(value) if cast else value)

    # [container] section
    container = data.get("container", {})
    copy_key(container, "image", "image")
    copy_key(container, "image_tag", "image_tag")
    if "extra_env" in container:
        config.extra_env.update(container["extra_env"])
    if "extra_volumes" in container:
        config.extra_volumes.extend(container["extra_volumes"])
    if "ports" in container:
        config.extra_ports.extend(int(p) for p in container["ports"])

    # [llm] section
    llm = data.get("llm", {})
    copy_key(llm, "primary_model", "primary_model")
    copy_key(llm, "fallback_model", "fallback_model")
    copy_key(llm, "context_size", "context_size", int)
    copy_key(llm, "ollama_port", "ollama_port", int)
    copy_key(llm, "webui_port", "webui_port", int)

    # [aider] section
    copy_key(data.get("aider", {}), "model", "aider_model")

    # [aws] section
    aws = data.get("aws", {})
    copy_key(aws, "ai_profile", "ai_profile")
    copy_key(aws, "region", "aws_region")
    copy_key(aws, "bedrock_profile", "bedrock_profile")

    # [claude] section
    claude_sec = data.get("claude", {})
    copy_key(claude_sec, "provider", "claude_provider")
    copy_key(claude_sec, "local_chrome", "local_chrome", bool)

    # [opencode] section
    copy_key(data.get("opencode", {}), "provider", "opencode_provider")

    # [codex] section
    codex_sec = data.get("codex", {})
    copy_key(codex_sec, "provider", "codex_provider")
    copy_key(codex_sec, "openai_api_key", "codex_openai_api_key")
    copy_key(codex_sec, "profile", "codex_profile")
210def _apply_env_vars(config: AiShellConfig) -> None:
211 """Apply AI_SHELL_* environment variable overrides."""
212 env_map: dict[str, tuple[str, type]] = {
213 "AI_SHELL_IMAGE": ("image", str),
214 "AI_SHELL_IMAGE_TAG": ("image_tag", str),
215 "AI_SHELL_PROJECT": ("project_name", str),
216 "AI_SHELL_PRIMARY_MODEL": ("primary_model", str),
217 "AI_SHELL_FALLBACK_MODEL": ("fallback_model", str),
218 "AI_SHELL_CONTEXT_SIZE": ("context_size", int),
219 "AI_SHELL_OLLAMA_PORT": ("ollama_port", int),
220 "AI_SHELL_WEBUI_PORT": ("webui_port", int),
221 "AI_SHELL_AIDER_MODEL": ("aider_model", str),
222 "AI_SHELL_AI_PROFILE": ("ai_profile", str),
223 "AI_SHELL_AWS_REGION": ("aws_region", str),
224 "AI_SHELL_BEDROCK_PROFILE": ("bedrock_profile", str),
225 "AI_SHELL_CLAUDE_PROVIDER": ("claude_provider", str),
226 "AI_SHELL_OPENCODE_PROVIDER": ("opencode_provider", str),
227 "AI_SHELL_CODEX_PROVIDER": ("codex_provider", str),
228 "AI_SHELL_CODEX_OPENAI_API_KEY": ("codex_openai_api_key", str),
229 "AI_SHELL_CODEX_PROFILE": ("codex_profile", str),
230 "AI_SHELL_LOCAL_CHROME": ("local_chrome", bool),
231 }
233 for env_key, (attr, type_fn) in env_map.items():
234 value = os.environ.get(env_key)
235 if value is not None:
236 if type_fn is bool:
237 coerced = value.lower() not in ("0", "false", "no", "")
238 else:
239 coerced = type_fn(value)
240 setattr(config, attr, coerced)
241 logger.debug("Config override from env: %s=%s", env_key, value)
243 # AI_SHELL_PORTS is comma-separated, extends extra_ports
244 ports_value = os.environ.get("AI_SHELL_PORTS")
245 if ports_value:
246 config.extra_ports.extend(int(p.strip()) for p in ports_value.split(",") if p.strip())