Coverage for src / ai_shell / config.py: 96%

136 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 22:12 +0000

1"""Configuration loading for ai-shell. 

2 

3Priority (highest wins): CLI flags > env vars > project config > global config > defaults. 

4 

5Config file lookup order (first match wins): 

6 .ai-shell.yaml > .ai-shell.yml > .ai-shell.toml > ai-shell.toml 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12import os 

13import tomllib 

14from dataclasses import dataclass, field 

15from pathlib import Path 

16 

17import yaml 

18 

19from ai_shell import __version__ 

20from ai_shell.defaults import ( 

21 DEFAULT_CONTEXT_SIZE, 

22 DEFAULT_DEV_PORTS, 

23 DEFAULT_FALLBACK_MODEL, 

24 DEFAULT_IMAGE, 

25 DEFAULT_OLLAMA_PORT, 

26 DEFAULT_PRIMARY_MODEL, 

27 DEFAULT_WEBUI_PORT, 

28) 

29 

# Module logger: config-file load failures are reported at WARNING level, and
# the files and AI_SHELL_* env-var overrides that get applied are traced at DEBUG.
logger = logging.getLogger(__name__)

31 

32 

@dataclass
class AiShellConfig:
    """All ai-shell settings, pre-populated with built-in defaults.

    ``load_config`` builds one instance and then overlays config files,
    ``AI_SHELL_*`` environment variables and CLI flags onto it.
    """

    # -- Container ---------------------------------------------------------
    image: str = DEFAULT_IMAGE
    image_tag: str = __version__  # tag defaults to the ai-shell package version
    project_name: str = ""  # empty: load_config derives it from project_dir
    project_dir: Path = field(default_factory=Path.cwd)  # cwd at construction time

    # -- LLM ---------------------------------------------------------------
    primary_model: str = DEFAULT_PRIMARY_MODEL
    fallback_model: str = DEFAULT_FALLBACK_MODEL
    context_size: int = DEFAULT_CONTEXT_SIZE
    ollama_port: int = DEFAULT_OLLAMA_PORT
    webui_port: int = DEFAULT_WEBUI_PORT

    # -- Aider -------------------------------------------------------------
    # Built once when the class is defined; it is NOT re-derived when
    # primary_model is overridden later.
    aider_model: str = f"ollama_chat/{DEFAULT_PRIMARY_MODEL}"

    # -- User-supplied extras (config files add to these, not replace) -----
    extra_env: dict[str, str] = field(default_factory=dict)
    extra_volumes: list[str] = field(default_factory=list)
    extra_ports: list[int] = field(default_factory=list)  # folded into dev_ports

    # -- AWS ---------------------------------------------------------------
    ai_profile: str = ""  # AWS profile for infra (sets AWS_PROFILE in container)
    aws_region: str = ""  # Override AWS_REGION
    bedrock_profile: str = ""  # AWS profile for Bedrock LLM API calls

    # -- Claude ------------------------------------------------------------
    local_chrome: bool = False  # Attach Chrome DevTools MCP to host Chrome debug port

    # -- Per-tool provider selection ("" = the tool's default) -------------
    claude_provider: str = ""  # "anthropic" (default) or "aws"
    opencode_provider: str = ""  # "local" (default, Ollama) or "aws" (Bedrock)
    codex_provider: str = ""  # "openai" (default) or "aws" (Bedrock)
    codex_openai_api_key: str = ""  # OpenAI API key (if set, overrides mounted SSO auth)
    codex_profile: str = ""  # AWS profile for Bedrock auth (when provider = "aws")

    @property
    def full_image(self) -> str:
        """Image reference in ``name:tag`` form."""
        name, tag = self.image, self.image_tag
        return f"{name}:{tag}"

    @property
    def dev_ports(self) -> list[int]:
        """Default dev ports plus ``extra_ports``, without duplicates, ascending."""
        combined = DEFAULT_DEV_PORTS + self.extra_ports
        unique_ports = set(combined)
        return sorted(unique_ports)

82 

83 

def load_config(
    project_override: str | None = None,
    project_dir: Path | None = None,
) -> AiShellConfig:
    """Load configuration from all sources.

    Priority (highest wins): CLI overrides > ``AI_SHELL_*`` env vars >
    project config file > global config file > built-in defaults.

    The global config is the first regular file among ``config.yaml``,
    ``config.yml`` and ``config.toml`` in ``~/.config/ai-shell``. The project
    config is the first of ``.ai-shell.yaml``, ``.ai-shell.yml``,
    ``.ai-shell.toml`` and ``ai-shell.toml`` in the project directory.
    Each source may be YAML or TOML.

    Args:
        project_override: Project name given on the CLI; beats every other source.
        project_dir: Project directory; when omitted, the dataclass default
            (the current working directory) is kept.

    Returns:
        The merged configuration. If no source sets ``project_name``, it is
        derived from ``project_dir`` with ``sanitize_project_name``.
    """
    config = AiShellConfig()

    if project_dir:
        # Path() also accepts a plain string from callers that ignore the hint.
        config.project_dir = Path(project_dir)

    # Global config (lowest-priority file source).
    global_file = _first_config_file(
        Path.home() / ".config" / "ai-shell",
        ("config.yaml", "config.yml", "config.toml"),
    )
    if global_file is not None:
        _apply_config(config, global_file)

    # Project config: applied after the global file so its values win.
    project_file = _first_config_file(
        config.project_dir,
        (".ai-shell.yaml", ".ai-shell.yml", ".ai-shell.toml", "ai-shell.toml"),
    )
    if project_file is not None:
        _apply_config(config, project_file)

    # Environment variables override anything read from files.
    _apply_env_vars(config)

    # CLI override beats every other source.
    if project_override:
        config.project_name = project_override

    # Auto-derive the project name from the project directory if still unset.
    if not config.project_name:
        from ai_shell.defaults import sanitize_project_name

        config.project_name = sanitize_project_name(config.project_dir)

    return config


def _first_config_file(directory: Path, names: tuple[str, ...]) -> Path | None:
    """Return the first ``directory / name`` that is a regular file, else None.

    ``is_file()`` (not ``exists()``) skips a directory that happens to carry a
    config-file name, so lookup falls through to the next candidate instead of
    picking something that cannot be opened.
    """
    for name in names:
        candidate = directory / name
        if candidate.is_file():
            return candidate
    return None

126 

127 

128def _load_config_file(path: Path) -> dict: 

129 """Load a YAML or TOML config file and return the parsed dict.""" 

130 suffix = path.suffix.lower() 

131 if suffix in (".yaml", ".yml"): 

132 with open(path, encoding="utf-8") as f: 

133 return yaml.safe_load(f) or {} 

134 with open(path, "rb") as f: 

135 return tomllib.load(f) 

136 

137 

def _config_bool(value: object) -> bool:
    """Coerce a config value to bool.

    Real booleans and numbers use normal truthiness. Strings are false only for
    "0", "false", "no" or "" (case-insensitive, surrounding whitespace ignored),
    the same words ``AI_SHELL_*`` boolean env vars treat as false. A quoted
    ``"false"`` in YAML therefore disables the option.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("0", "false", "no", "")
    return bool(value)


def _config_list(value: object) -> list:
    """Normalize a config value that should be a list.

    ``None`` (an empty YAML key) gives ``[]``. A lone scalar, such as a single
    volume string, is wrapped rather than iterated character by character.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _config_ports(value: object) -> list[int]:
    """Parse ``container.ports``: a list, a single port, or a comma-separated string."""
    if isinstance(value, str):
        parts = [part.strip() for part in value.split(",")]
        return [int(part) for part in parts if part]
    return [int(port) for port in _config_list(value)]


def _config_section(data: dict, name: str, path: Path) -> dict:
    """Return the ``[name]`` table of a parsed config file as a dict.

    A missing section, or an empty YAML key parsed as ``None``, yields ``{}``.
    Any other non-mapping value is logged and ignored. Otherwise it would be
    probed with ``in``, which is a substring test on a string, or would crash.
    """
    section = data.get(name)
    if section is None:
        return {}
    if not isinstance(section, dict):
        logger.warning(
            "Ignoring [%s] in %s: expected a table, got %s",
            name,
            path,
            type(section).__name__,
        )
        return {}
    return section


# Scalar settings read from each section:
#   section -> ((file key, AiShellConfig attribute, converter or None), ...)
_CONFIG_SCALARS = {
    "container": (
        ("image", "image", None),
        ("image_tag", "image_tag", None),
    ),
    "llm": (
        ("primary_model", "primary_model", None),
        ("fallback_model", "fallback_model", None),
        ("context_size", "context_size", int),
        ("ollama_port", "ollama_port", int),
        ("webui_port", "webui_port", int),
    ),
    "aider": (("model", "aider_model", None),),
    "aws": (
        ("ai_profile", "ai_profile", None),
        ("region", "aws_region", None),
        ("bedrock_profile", "bedrock_profile", None),
    ),
    "claude": (
        ("provider", "claude_provider", None),
        ("local_chrome", "local_chrome", _config_bool),
    ),
    "opencode": (("provider", "opencode_provider", None),),
    "codex": (
        ("provider", "codex_provider", None),
        ("openai_api_key", "codex_openai_api_key", None),
        ("profile", "codex_profile", None),
    ),
}


def _apply_config(config: AiShellConfig, path: Path) -> None:
    """Apply settings from a YAML or TOML config file onto *config*.

    A file that cannot be read or parsed, or whose top level is not a mapping,
    is logged at WARNING level and skipped.

    Scalar keys overwrite the current value; a key left empty (``None``) is
    treated as unset. ``container.extra_env`` is merged into
    ``config.extra_env``. ``container.extra_volumes`` and ``container.ports``
    extend the existing lists, so a project file adds to what the global file
    set.

    Raises:
        ValueError / TypeError: if an integer setting (``llm.context_size``,
            ``llm.ollama_port``, ``llm.webui_port``) or a port entry cannot be
            converted by ``int()``.
    """
    try:
        data = _load_config_file(path)
    except (OSError, tomllib.TOMLDecodeError, yaml.YAMLError) as e:
        logger.warning("Failed to load config from %s: %s", path, e)
        return

    # yaml.safe_load returns whatever the document holds (a list, a string, ...),
    # not necessarily a mapping.
    if not isinstance(data, dict):
        logger.warning(
            "Ignoring config %s: top level must be a mapping, got %s",
            path,
            type(data).__name__,
        )
        return

    logger.debug("Loading config from %s", path)

    sections = {name: _config_section(data, name, path) for name in _CONFIG_SCALARS}

    # [container] collections are merged/extended rather than replaced.
    container = sections["container"]
    extra_env = container.get("extra_env")
    if isinstance(extra_env, dict):
        config.extra_env.update(extra_env)
    elif extra_env is not None:
        logger.warning("Ignoring container.extra_env in %s: expected a mapping", path)
    if "extra_volumes" in container:
        config.extra_volumes.extend(_config_list(container["extra_volumes"]))
    if "ports" in container:
        config.extra_ports.extend(_config_ports(container["ports"]))

    # Plain scalar settings from every section.
    for name, keys in _CONFIG_SCALARS.items():
        section = sections[name]
        for key, attr, convert in keys:
            value = section.get(key)
            if value is None:
                continue  # absent, or an empty YAML key: keep the current value
            setattr(config, attr, value if convert is None else convert(value))

208 

209 

210def _apply_env_vars(config: AiShellConfig) -> None: 

211 """Apply AI_SHELL_* environment variable overrides.""" 

212 env_map: dict[str, tuple[str, type]] = { 

213 "AI_SHELL_IMAGE": ("image", str), 

214 "AI_SHELL_IMAGE_TAG": ("image_tag", str), 

215 "AI_SHELL_PROJECT": ("project_name", str), 

216 "AI_SHELL_PRIMARY_MODEL": ("primary_model", str), 

217 "AI_SHELL_FALLBACK_MODEL": ("fallback_model", str), 

218 "AI_SHELL_CONTEXT_SIZE": ("context_size", int), 

219 "AI_SHELL_OLLAMA_PORT": ("ollama_port", int), 

220 "AI_SHELL_WEBUI_PORT": ("webui_port", int), 

221 "AI_SHELL_AIDER_MODEL": ("aider_model", str), 

222 "AI_SHELL_AI_PROFILE": ("ai_profile", str), 

223 "AI_SHELL_AWS_REGION": ("aws_region", str), 

224 "AI_SHELL_BEDROCK_PROFILE": ("bedrock_profile", str), 

225 "AI_SHELL_CLAUDE_PROVIDER": ("claude_provider", str), 

226 "AI_SHELL_OPENCODE_PROVIDER": ("opencode_provider", str), 

227 "AI_SHELL_CODEX_PROVIDER": ("codex_provider", str), 

228 "AI_SHELL_CODEX_OPENAI_API_KEY": ("codex_openai_api_key", str), 

229 "AI_SHELL_CODEX_PROFILE": ("codex_profile", str), 

230 "AI_SHELL_LOCAL_CHROME": ("local_chrome", bool), 

231 } 

232 

233 for env_key, (attr, type_fn) in env_map.items(): 

234 value = os.environ.get(env_key) 

235 if value is not None: 

236 if type_fn is bool: 

237 coerced = value.lower() not in ("0", "false", "no", "") 

238 else: 

239 coerced = type_fn(value) 

240 setattr(config, attr, coerced) 

241 logger.debug("Config override from env: %s=%s", env_key, value) 

242 

243 # AI_SHELL_PORTS is comma-separated, extends extra_ports 

244 ports_value = os.environ.get("AI_SHELL_PORTS") 

245 if ports_value: 

246 config.extra_ports.extend(int(p.strip()) for p in ports_value.split(",") if p.strip())