Coverage for src/ai_shell/config.py: 87%
325 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-05 22:06 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-05 22:06 +0000
1"""Configuration loading for ai-shell.
3Priority (highest wins): CLI flags > env vars > project config > global config > defaults.
5Global config lookup order (first match wins):
6 ~/.augint/.ai-shell.yaml > ~/.ai-shell.yaml > ~/.ai-shell.yml > ~/.ai-shell.toml
7 > ~/.config/ai-shell/config.yaml > ~/.config/ai-shell/config.yml > ~/.config/ai-shell/config.toml
9Project config lookup order (first match wins):
10 .ai-shell.yaml > .ai-shell.yml > .ai-shell.toml > ai-shell.toml
11"""
13from __future__ import annotations
15import logging
16import os
17import tomllib
18from dataclasses import dataclass, field
19from pathlib import Path
21import yaml
23from ai_shell.defaults import (
24 DEFAULT_COMFYUI_PORT,
25 DEFAULT_CONTEXT_SIZE,
26 DEFAULT_DEV_PORTS,
27 DEFAULT_EXTRA_MODELS,
28 DEFAULT_IMAGE,
29 DEFAULT_KOKORO_PORT,
30 DEFAULT_KOKORO_VOICE,
31 DEFAULT_N8N_PORT,
32 DEFAULT_OLLAMA_PORT,
33 DEFAULT_PRIMARY_CHAT_MODEL,
34 DEFAULT_PRIMARY_CODING_MODEL,
35 DEFAULT_SECONDARY_CHAT_MODEL,
36 DEFAULT_SECONDARY_CODING_MODEL,
37 DEFAULT_VOICE_AGENT_PORT,
38 DEFAULT_WEBUI_PORT,
39 DEFAULT_WHISPER_MODEL,
40 DEFAULT_WHISPER_PORT,
41)
43logger = logging.getLogger(__name__)
@dataclass
class VoiceAgentModelProfile:
    """Named model pairing for the voice agent: one primary and one secondary chat model."""

    # Both slots default to the empty string until a config supplies them.
    primary: str = ""
    secondary: str = ""
@dataclass
class VoiceAgentVadConfig:
    """Settings for Silero VAD and barge-in behavior."""

    # NOTE(review): the `_ms` suffix suggests milliseconds; confirm with the consumer.
    silence_timeout_ms: int = 2500
    barge_in: bool = True
@dataclass
class VoiceAgentFilesystemConfig:
    """Scoping rules for the filesystem tool (consumed by Phase 4)."""

    root: str = "~/gigachad"
    # Each instance gets its own list objects via default_factory, so mutating
    # one config never leaks into another.
    read: list[str] = field(default_factory=lambda: ["~/gigachad"])
    write: list[str] = field(default_factory=lambda: ["~/gigachad"])
    deny_glob: list[str] = field(default_factory=lambda: ["**/.env*", "**/.git/**"])
@dataclass
class VoiceAgentMemoryConfig:
    """Behavior of the sqlite-backed memory (consumed by Phase 5)."""

    enabled: bool = True
    summarize_after_turns: int = 20
@dataclass
class VoiceAgentAuthConfig:
    """Application-level session auth settings (consumed by Phase 3)."""

    # All three values default to empty strings.
    username: str = ""
    password_bcrypt: str = ""
    session_secret: str = ""
@dataclass
class VoiceAgentProvidersConfig:
    """Which LLM providers are selectable and which is the default (consumed by Phase 6)."""

    default: str = "ollama"
    # default_factory gives every instance an independent list.
    available: list[str] = field(default_factory=lambda: ["ollama"])
@dataclass
class VoiceAgentToolConfig:
    """One entry beneath `voice_agent.tools`."""

    # Tools are off unless a config turns them on.
    enabled: bool = False
    provider: str = ""
@dataclass
class VoiceAgentToolsConfig:
    """Registry of voice-agent tools (consumed by Phase 4)."""

    filesystem: VoiceAgentToolConfig = field(default_factory=VoiceAgentToolConfig)
    # web_search is the only tool that ships with a preset provider.
    web_search: VoiceAgentToolConfig = field(
        default_factory=lambda: VoiceAgentToolConfig(provider="brave")
    )
    github: VoiceAgentToolConfig = field(default_factory=VoiceAgentToolConfig)
@dataclass
class VoiceAgentWakeWordConfig:
    """Wake-word gating settings (consumed by Phase 3)."""

    # Gating is disabled by default.
    enabled: bool = False
    name: str = "hey_jarvis"
def _default_voice_agent_profiles() -> dict[str, VoiceAgentModelProfile]:
    """Return a fresh copy of the built-in ``resident`` and ``swap`` profile table."""
    resident = VoiceAgentModelProfile(
        primary="qwen3.5:9b",
        secondary="huihui_ai/qwen3.5-abliterated:9b",
    )
    swap = VoiceAgentModelProfile(
        primary="qwen3.5:27b",
        secondary="dolphin3:8b",
    )
    return {"resident": resident, "swap": swap}


@dataclass
class VoiceAgentConfig:
    """Complete configuration tree for the voice agent.

    Only ``port`` is wired at the container layer in Phase 2. Every other
    field is a schema placeholder for Phases 3-6, given a reasonable default
    so early adopters can see the layout without the CLI rejecting unknown
    keys.
    """

    port: int = DEFAULT_VOICE_AGENT_PORT
    domain: str = ""
    # Name of the active entry in `profiles`.
    profile: str = "resident"
    profiles: dict[str, VoiceAgentModelProfile] = field(
        default_factory=_default_voice_agent_profiles
    )

    # Nested sections; each instance receives its own sub-config objects.
    vad: VoiceAgentVadConfig = field(default_factory=VoiceAgentVadConfig)
    filesystem: VoiceAgentFilesystemConfig = field(default_factory=VoiceAgentFilesystemConfig)
    memory: VoiceAgentMemoryConfig = field(default_factory=VoiceAgentMemoryConfig)
    auth: VoiceAgentAuthConfig = field(default_factory=VoiceAgentAuthConfig)
    providers: VoiceAgentProvidersConfig = field(default_factory=VoiceAgentProvidersConfig)
    tools: VoiceAgentToolsConfig = field(default_factory=VoiceAgentToolsConfig)
    wake_word: VoiceAgentWakeWordConfig = field(default_factory=VoiceAgentWakeWordConfig)
@dataclass
class AiShellConfig:
    """Top-level settings object for ai-shell."""

    # --- Container -------------------------------------------------------
    image: str = DEFAULT_IMAGE
    image_tag: str = "latest"
    project_name: str = ""
    project_dir: Path = field(default_factory=Path.cwd)

    # --- LLM model slots -------------------------------------------------
    # "Primary" is the best-available model; "secondary" is the best
    # uncensored alternative. Chat slots are routed to Open WebUI and coding
    # slots to OpenCode. `extra_models` is a free-form list of additional
    # Ollama tags pulled alongside the four slots (deduped).
    primary_chat_model: str = DEFAULT_PRIMARY_CHAT_MODEL
    secondary_chat_model: str = DEFAULT_SECONDARY_CHAT_MODEL
    primary_coding_model: str = DEFAULT_PRIMARY_CODING_MODEL
    secondary_coding_model: str = DEFAULT_SECONDARY_CODING_MODEL
    extra_models: list[str] = field(default_factory=lambda: list(DEFAULT_EXTRA_MODELS))
    context_size: int = DEFAULT_CONTEXT_SIZE
    ollama_port: int = DEFAULT_OLLAMA_PORT
    webui_port: int = DEFAULT_WEBUI_PORT
    kokoro_port: int = DEFAULT_KOKORO_PORT
    kokoro_voice: str = DEFAULT_KOKORO_VOICE
    n8n_port: int = DEFAULT_N8N_PORT
    whisper_port: int = DEFAULT_WHISPER_PORT
    whisper_model: str = DEFAULT_WHISPER_MODEL
    comfyui_port: int = DEFAULT_COMFYUI_PORT

    # --- Voice agent -----------------------------------------------------
    # Phase 2 wires `port`; the other fields are schema placeholders that
    # Phases 3-6 consume (see VoiceAgentConfig).
    voice_agent: VoiceAgentConfig = field(default_factory=VoiceAgentConfig)

    # --- Extra configuration ---------------------------------------------
    extra_env: dict[str, str] = field(default_factory=dict)
    extra_volumes: list[str] = field(default_factory=list)
    extra_ports: list[int] = field(default_factory=list)
    # Glob patterns (relative to project_dir) for monorepo workspace
    # node_modules directories. Each match gets an isolated named volume.
    node_modules_paths: list[str] = field(default_factory=list)

    # --- AWS -------------------------------------------------------------
    ai_profile: str = ""  # AWS profile for infra (sets AWS_PROFILE in container)
    aws_region: str = ""  # Override AWS_REGION
    bedrock_profile: str = ""  # AWS profile for Bedrock LLM API calls
    bedrock_region: str = ""  # AWS region for Bedrock (falls back to aws_region)
    bedrock_model: str = "us.meta.llama3-3-70b-instruct-v1:0"

    # --- OpenAI ----------------------------------------------------------
    openai_profile: str = ""  # Suffixed .env key name for multi-account switching

    # --- Claude options --------------------------------------------------
    local_chrome: bool = False  # Attach Chrome DevTools MCP to project-scoped host Chrome
    skip_updates: bool = False  # When True, skip pre-launch tool freshness checks

    # --- Pre-launch cache TTLs (seconds); 0 disables ----------------------
    image_pull_cache_ttl: int = 900  # 15 min: skip docker pull if checked recently
    bedrock_check_cache_ttl: int = 86400  # 24 h: skip Bedrock preflight if checked recently

    # --- Per-tool provider -----------------------------------------------
    claude_provider: str = ""  # "anthropic" (default) or "aws"

    @property
    def full_image(self) -> str:
        """Image reference in ``image:tag`` form."""
        return ":".join((f"{self.image}", f"{self.image_tag}"))

    @property
    def dev_ports(self) -> list[int]:
        """Dev container ports to expose: defaults plus extras, deduplicated and sorted."""
        unique_ports = set(DEFAULT_DEV_PORTS + self.extra_ports)
        return sorted(unique_ports)

    @property
    def models_to_pull(self) -> list[str]:
        """Every Ollama model tag that should be pulled, without duplicates.

        The four role slots come first, in declaration order, followed by
        ``extra_models``. Empty entries are skipped and only the first
        occurrence of each tag is kept.
        """
        candidates = (
            self.primary_chat_model,
            self.secondary_chat_model,
            self.primary_coding_model,
            self.secondary_coding_model,
            *self.extra_models,
        )
        # dict preserves insertion order, so fromkeys dedupes while keeping
        # the first occurrence of each tag.
        return list(dict.fromkeys(tag for tag in candidates if tag))
def load_config(
    project_override: str | None = None,
    project_dir: Path | None = None,
) -> AiShellConfig:
    """Assemble the effective configuration from every source.

    Sources are applied from lowest to highest priority so that later ones
    win: dataclass defaults, the first global config file found, the first
    project config file found, ``AI_SHELL_*`` environment variables, and
    finally the *project_override* argument.

    Args:
        project_override: Project name that overrides every other source
            when truthy.
        project_dir: Directory searched for project config files. When
            falsy, the ``AiShellConfig`` default is kept.

    Returns:
        The populated ``AiShellConfig``; ``project_name`` is derived from
        ``project_dir`` when no source supplied one.
    """
    cfg = AiShellConfig()
    if project_dir:
        cfg.project_dir = project_dir

    # Global config: ~/.augint/ is canonical, ~/.ai-shell.yaml and friends
    # are legacy fallbacks. Only the first existing file is applied.
    home_dir = Path.home()
    xdg_dir = home_dir / ".config" / "ai-shell"
    global_candidates = (
        home_dir / ".augint" / ".ai-shell.yaml",
        home_dir / ".ai-shell.yaml",
        home_dir / ".ai-shell.yml",
        home_dir / ".ai-shell.toml",
        xdg_dir / "config.yaml",
        xdg_dir / "config.yml",
        xdg_dir / "config.toml",
    )
    global_file = next((p for p in global_candidates if p.exists()), None)
    if global_file is not None:
        _apply_config(cfg, global_file)

    # Project config: again, only the first existing file is applied.
    project_names = (".ai-shell.yaml", ".ai-shell.yml", ".ai-shell.toml", "ai-shell.toml")
    project_file = next(
        (p for p in (cfg.project_dir / name for name in project_names) if p.exists()),
        None,
    )
    if project_file is not None:
        _apply_config(cfg, project_file)

    # Environment variables outrank both files.
    _apply_env_vars(cfg)

    # The CLI override outranks everything.
    if project_override:
        cfg.project_name = project_override

    # Nothing named the project: fall back to a name derived from its directory.
    if not cfg.project_name:
        from ai_shell.defaults import sanitize_project_name

        cfg.project_name = sanitize_project_name(cfg.project_dir)

    return cfg
303def _load_config_file(path: Path) -> dict:
304 """Load a YAML or TOML config file and return the parsed dict."""
305 suffix = path.suffix.lower()
306 if suffix in (".yaml", ".yml"):
307 with open(path, encoding="utf-8") as f:
308 return yaml.safe_load(f) or {}
309 with open(path, "rb") as f:
310 return tomllib.load(f)
313_LEGACY_LLM_KEY_HINT = {
314 "primary_model": (
315 "renamed to `primary_coding_model` (coding) or `primary_chat_model` "
316 "(chat). The new config uses 4 role-specific slots; pick the one "
317 "that matches your intent. See the generated .ai-shell.yaml for the "
318 "full layout."
319 ),
320 "fallback_model": (
321 "removed. The previous `fallback_model` was role-ambiguous. Use "
322 "`secondary_chat_model` and `secondary_coding_model` instead "
323 "(both default to the best uncensored variants). See the generated "
324 ".ai-shell.yaml for the full layout."
325 ),
326}
def _reject_legacy_llm_keys(llm_section: dict, path: Path) -> None:
    """Fail loudly when *llm_section* still uses `primary_model` / `fallback_model`.

    Those keys went away when the llm config was split into four
    role-specific slots (primary/secondary x chat/coding). Quietly aliasing
    them would misrepresent intent, because the old `fallback_model` meant
    different things to chat and coding users, so the user is given
    migration guidance instead.

    Raises:
        ValueError: If any deprecated key is present; the message names
            *path* and lists one hint line per offending key.
    """
    offending = [key for key in _LEGACY_LLM_KEY_HINT if key in llm_section]
    if offending:
        header = f"\nDeprecated llm key(s) found in {path}:"
        details = [f" - `{key}`: {_LEGACY_LLM_KEY_HINT[key]}" for key in offending]
        raise ValueError("\n".join([header, *details]))
346def _apply_voice_agent_config(va: VoiceAgentConfig, data: dict) -> None:
347 """Merge a parsed ``voice_agent:`` section into a VoiceAgentConfig.
349 Only keys present in *data* override defaults; everything else keeps
350 the dataclass default. Nested sections are merged field-by-field so
351 partial user configs work.
352 """
353 if "port" in data:
354 va.port = int(data["port"])
355 if "domain" in data:
356 va.domain = str(data["domain"])
357 if "profile" in data:
358 va.profile = str(data["profile"])
359 if "profiles" in data and isinstance(data["profiles"], dict):
360 for name, entry in data["profiles"].items():
361 profile = va.profiles.get(name, VoiceAgentModelProfile())
362 if isinstance(entry, dict):
363 if "primary" in entry:
364 profile.primary = str(entry["primary"])
365 if "secondary" in entry:
366 profile.secondary = str(entry["secondary"])
367 va.profiles[name] = profile
368 if "vad" in data and isinstance(data["vad"], dict):
369 vad = data["vad"]
370 if "silence_timeout_ms" in vad:
371 va.vad.silence_timeout_ms = int(vad["silence_timeout_ms"])
372 if "barge_in" in vad:
373 va.vad.barge_in = bool(vad["barge_in"])
374 if "filesystem" in data and isinstance(data["filesystem"], dict):
375 fs = data["filesystem"]
376 if "root" in fs:
377 va.filesystem.root = str(fs["root"])
378 if "read" in fs:
379 va.filesystem.read = [str(p) for p in fs["read"]]
380 if "write" in fs:
381 va.filesystem.write = [str(p) for p in fs["write"]]
382 if "deny_glob" in fs:
383 va.filesystem.deny_glob = [str(p) for p in fs["deny_glob"]]
384 if "memory" in data and isinstance(data["memory"], dict):
385 mem = data["memory"]
386 if "enabled" in mem:
387 va.memory.enabled = bool(mem["enabled"])
388 if "summarize_after_turns" in mem:
389 va.memory.summarize_after_turns = int(mem["summarize_after_turns"])
390 if "auth" in data and isinstance(data["auth"], dict):
391 auth = data["auth"]
392 if "username" in auth:
393 va.auth.username = str(auth["username"])
394 if "password_bcrypt" in auth:
395 va.auth.password_bcrypt = str(auth["password_bcrypt"])
396 if "session_secret" in auth:
397 va.auth.session_secret = str(auth["session_secret"])
398 if "providers" in data and isinstance(data["providers"], dict):
399 providers = data["providers"]
400 if "default" in providers:
401 va.providers.default = str(providers["default"])
402 if "available" in providers:
403 va.providers.available = [str(p) for p in providers["available"]]
404 if "tools" in data and isinstance(data["tools"], dict):
405 tools = data["tools"]
406 for tool_name in ("filesystem", "web_search", "github"):
407 entry = tools.get(tool_name)
408 if isinstance(entry, dict):
409 tool = getattr(va.tools, tool_name)
410 if "enabled" in entry:
411 tool.enabled = bool(entry["enabled"])
412 if "provider" in entry:
413 tool.provider = str(entry["provider"])
414 if "wake_word" in data and isinstance(data["wake_word"], dict):
415 wake = data["wake_word"]
416 if "enabled" in wake:
417 va.wake_word.enabled = bool(wake["enabled"])
418 if "name" in wake:
419 va.wake_word.name = str(wake["name"])
def _apply_config(config: AiShellConfig, path: Path) -> None:
    """Apply settings from a YAML or TOML config file onto *config*.

    Loading is best-effort: a file that cannot be read or parsed, or whose
    top level is not a mapping, is logged as a warning and skipped. Within a
    loaded file only keys that are present override the current values.
    List-valued container keys extend the existing lists, while
    ``llm.extra_models`` replaces the list.

    An empty section (for example ``llm:`` with every key commented out,
    which YAML parses as ``None``) is treated as having no keys instead of
    crashing.

    Args:
        config: Config object updated in place.
        path: Config file to load.

    Raises:
        ValueError: If the ``llm`` section contains deprecated keys, or a
            numeric setting cannot be converted with ``int``.
    """
    try:
        data = _load_config_file(path)
    except (OSError, tomllib.TOMLDecodeError, yaml.YAMLError) as e:
        logger.warning("Failed to load config from %s: %s", path, e)
        return

    if not isinstance(data, dict):
        # e.g. a YAML file whose top level is a list or a bare scalar.
        logger.warning(
            "Ignoring config %s: top level must be a mapping, got %s",
            path,
            type(data).__name__,
        )
        return

    logger.debug("Loading config from %s", path)

    def section(name: str) -> dict:
        """Return the mapping stored under *name*, or ``{}`` when absent, empty or malformed."""
        value = data.get(name)
        if isinstance(value, dict):
            return value
        if value is not None:
            logger.warning(
                "Ignoring `%s` in %s: expected a mapping, got %s",
                name,
                path,
                type(value).__name__,
            )
        return {}

    def items(value: object) -> list:
        """Normalize a list-valued setting: None -> [], scalar -> [scalar]."""
        if value is None:
            return []
        if isinstance(value, (list, tuple, set, frozenset)):
            return list(value)
        return [value]

    # [container] section
    container = section("container")
    if "image" in container:
        config.image = container["image"]
    if "image_tag" in container:
        config.image_tag = container["image_tag"]
    if "extra_env" in container:
        config.extra_env.update(container["extra_env"] or {})
    if "extra_volumes" in container:
        config.extra_volumes.extend(items(container["extra_volumes"]))
    if "ports" in container:
        config.extra_ports.extend([int(p) for p in items(container["ports"])])
    if "node_modules_paths" in container:
        config.node_modules_paths.extend(
            [str(p) for p in items(container["node_modules_paths"])]
        )
    if "skip_updates" in container:
        config.skip_updates = bool(container["skip_updates"])
    if "image_pull_cache_ttl" in container:
        config.image_pull_cache_ttl = int(container["image_pull_cache_ttl"])

    # [llm] section
    llm = section("llm")
    _reject_legacy_llm_keys(llm, path)
    if "primary_chat_model" in llm:
        config.primary_chat_model = llm["primary_chat_model"]
    if "secondary_chat_model" in llm:
        config.secondary_chat_model = llm["secondary_chat_model"]
    if "primary_coding_model" in llm:
        config.primary_coding_model = llm["primary_coding_model"]
    if "secondary_coding_model" in llm:
        config.secondary_coding_model = llm["secondary_coding_model"]
    if "extra_models" in llm:
        config.extra_models = [str(m) for m in items(llm["extra_models"])]
    if "context_size" in llm:
        config.context_size = int(llm["context_size"])
    if "ollama_port" in llm:
        config.ollama_port = int(llm["ollama_port"])
    if "webui_port" in llm:
        config.webui_port = int(llm["webui_port"])
    if "kokoro_port" in llm:
        config.kokoro_port = int(llm["kokoro_port"])
    if "kokoro_voice" in llm:
        config.kokoro_voice = str(llm["kokoro_voice"])
    if "n8n_port" in llm:
        config.n8n_port = int(llm["n8n_port"])
    if "whisper_port" in llm:
        config.whisper_port = int(llm["whisper_port"])
    if "whisper_model" in llm:
        config.whisper_model = str(llm["whisper_model"])
    if "comfyui_port" in llm:
        config.comfyui_port = int(llm["comfyui_port"])

    # [voice_agent] section (top-level, not under llm). Only a non-empty
    # mapping is forwarded; an empty one has nothing to merge.
    voice_agent = section("voice_agent")
    if voice_agent:
        _apply_voice_agent_config(config.voice_agent, voice_agent)

    # [aws] section
    aws = section("aws")
    if "ai_profile" in aws:
        config.ai_profile = aws["ai_profile"]
    if "region" in aws:
        config.aws_region = aws["region"]
    if "bedrock_profile" in aws:
        config.bedrock_profile = aws["bedrock_profile"]
    if "bedrock_region" in aws:
        config.bedrock_region = aws["bedrock_region"]
    if "bedrock_model" in aws:
        config.bedrock_model = aws["bedrock_model"]
    if "bedrock_check_cache_ttl" in aws:
        config.bedrock_check_cache_ttl = int(aws["bedrock_check_cache_ttl"])

    # [openai] section
    openai = section("openai")
    if "profile" in openai:
        config.openai_profile = openai["profile"]

    # [claude] section
    claude_sec = section("claude")
    if "provider" in claude_sec:
        config.claude_provider = claude_sec["provider"]
    if "local_chrome" in claude_sec:
        config.local_chrome = bool(claude_sec["local_chrome"])
515_LEGACY_ENV_VARS = {
516 "AI_SHELL_PRIMARY_MODEL": ("AI_SHELL_PRIMARY_CODING_MODEL or AI_SHELL_PRIMARY_CHAT_MODEL"),
517 "AI_SHELL_FALLBACK_MODEL": ("AI_SHELL_SECONDARY_CHAT_MODEL or AI_SHELL_SECONDARY_CODING_MODEL"),
518}
def _apply_env_vars(config: AiShellConfig) -> None:
    """Apply ``AI_SHELL_*`` environment variable overrides onto *config*.

    Each recognized variable that is set (even to an empty string) replaces
    the matching attribute. Booleans are false for ``0``, ``false``, ``no``
    or an empty value (case-insensitive, surrounding whitespace ignored) and
    true for anything else. ``AI_SHELL_PORTS`` is a comma-separated list
    that extends ``extra_ports``. The ``AI_SHELL_VOICE_AGENT_*`` variables
    map onto the nested ``voice_agent`` fields.

    Raises:
        ValueError: If a deprecated variable is set, or an integer-valued
            variable does not hold a valid integer (the message names the
            variable).
    """
    bad_env = [k for k in _LEGACY_ENV_VARS if os.environ.get(k) is not None]
    if bad_env:
        lines = ["\nDeprecated AI_SHELL_* env var(s) set:"]
        for key in bad_env:
            lines.append(f" - {key}: use {_LEGACY_ENV_VARS[key]} instead")
        raise ValueError("\n".join(lines))

    def parse_int(env_key: str, raw: str) -> int:
        """Convert *raw* to int, naming the variable in the error on failure."""
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(
                f"Invalid value for {env_key}: expected an integer, got {raw!r}"
            ) from exc

    env_map: dict[str, tuple[str, type]] = {
        "AI_SHELL_IMAGE": ("image", str),
        "AI_SHELL_IMAGE_TAG": ("image_tag", str),
        "AI_SHELL_PROJECT": ("project_name", str),
        "AI_SHELL_PRIMARY_CHAT_MODEL": ("primary_chat_model", str),
        "AI_SHELL_SECONDARY_CHAT_MODEL": ("secondary_chat_model", str),
        "AI_SHELL_PRIMARY_CODING_MODEL": ("primary_coding_model", str),
        "AI_SHELL_SECONDARY_CODING_MODEL": ("secondary_coding_model", str),
        "AI_SHELL_CONTEXT_SIZE": ("context_size", int),
        "AI_SHELL_OLLAMA_PORT": ("ollama_port", int),
        "AI_SHELL_WEBUI_PORT": ("webui_port", int),
        "AI_SHELL_KOKORO_PORT": ("kokoro_port", int),
        "AI_SHELL_KOKORO_VOICE": ("kokoro_voice", str),
        "AI_SHELL_N8N_PORT": ("n8n_port", int),
        "AI_SHELL_WHISPER_PORT": ("whisper_port", int),
        "AI_SHELL_WHISPER_MODEL": ("whisper_model", str),
        "AI_SHELL_COMFYUI_PORT": ("comfyui_port", int),
        "AI_SHELL_AI_PROFILE": ("ai_profile", str),
        "AI_SHELL_AWS_REGION": ("aws_region", str),
        "AI_SHELL_BEDROCK_PROFILE": ("bedrock_profile", str),
        "AI_SHELL_BEDROCK_REGION": ("bedrock_region", str),
        "AI_SHELL_BEDROCK_MODEL": ("bedrock_model", str),
        "AI_SHELL_OPENAI_PROFILE": ("openai_profile", str),
        "AI_SHELL_CLAUDE_PROVIDER": ("claude_provider", str),
        "AI_SHELL_LOCAL_CHROME": ("local_chrome", bool),
        "AI_SHELL_SKIP_UPDATES": ("skip_updates", bool),
        "AI_SHELL_IMAGE_PULL_CACHE_TTL": ("image_pull_cache_ttl", int),
        "AI_SHELL_BEDROCK_CHECK_CACHE_TTL": ("bedrock_check_cache_ttl", int),
    }

    for env_key, (attr, type_fn) in env_map.items():
        value = os.environ.get(env_key)
        if value is None:
            continue
        if type_fn is bool:
            # bool("false") would be True, so booleans are parsed by hand.
            coerced = value.strip().lower() not in ("0", "false", "no", "")
        elif type_fn is int:
            coerced = parse_int(env_key, value)
        else:
            coerced = type_fn(value)
        setattr(config, attr, coerced)
        logger.debug("Config override from env: %s=%s", env_key, value)

    # AI_SHELL_PORTS is comma-separated and extends extra_ports. All entries
    # are parsed before extending so one bad entry cannot leave the list
    # half-updated.
    ports_value = os.environ.get("AI_SHELL_PORTS")
    if ports_value:
        parsed_ports = [
            parse_int("AI_SHELL_PORTS", p.strip())
            for p in ports_value.split(",")
            if p.strip()
        ]
        config.extra_ports.extend(parsed_ports)

    # Nested voice_agent overrides (flat env vars map to nested fields)
    voice_agent_port = os.environ.get("AI_SHELL_VOICE_AGENT_PORT")
    if voice_agent_port is not None:
        config.voice_agent.port = parse_int("AI_SHELL_VOICE_AGENT_PORT", voice_agent_port)
    voice_agent_domain = os.environ.get("AI_SHELL_VOICE_AGENT_DOMAIN")
    if voice_agent_domain is not None:
        config.voice_agent.domain = voice_agent_domain
    voice_agent_profile = os.environ.get("AI_SHELL_VOICE_AGENT_PROFILE")
    if voice_agent_profile is not None:
        config.voice_agent.profile = voice_agent_profile