Coverage for src/ai_shell/typeahead.py: 49%
61 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-05 22:06 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-05 22:06 +0000
1"""Capture stdin keystrokes during the slow startup phase so they can be replayed
2into the interactive process once it attaches.
4Without this the user has to wait for the dev container to come up (~20 s for
5image checks + tool freshness checks) before any typing reaches the inner shell.
6Anything typed during that window is otherwise either lost or interpreted by the
7parent shell after the CLI exits.
8"""
10from __future__ import annotations
12import os
13import sys
14import threading
15from contextlib import contextmanager
16from typing import TYPE_CHECKING
18if TYPE_CHECKING:
19 from collections.abc import Iterator
# Name of the environment variable that turns typeahead capture off; any
# non-empty value disables it (see ``_capture_disabled``).
_DISABLE_ENV: str = "AI_SHELL_NO_TYPEAHEAD"
class TypeaheadBuffer:
    """Collect raw stdin bytes so they can be handed over in one piece later.

    All access to the internal chunk list goes through a single
    ``threading.Lock``, so ``append`` and ``bytes`` may be called from
    different threads.
    """

    def __init__(self) -> None:
        # Chunks are kept in arrival order; concatenation is deferred to read time.
        self._chunks: list[bytes] = []
        self._lock = threading.Lock()

    def append(self, data: bytes) -> None:
        """Record *data*; empty input is ignored and never creates a chunk."""
        if data:
            with self._lock:
                self._chunks.append(data)

    def bytes(self) -> bytes:
        """Return everything appended so far, concatenated into one object."""
        with self._lock:
            snapshot = tuple(self._chunks)
        return b"".join(snapshot)
def _capture_disabled() -> bool:
    """Return True when typeahead capture should be skipped.

    Capture is skipped when the ``_DISABLE_ENV`` environment variable is set
    to any non-empty value, when ``sys.stdin`` is missing (``None``), or when
    stdin is not an interactive TTY. Any error while probing stdin is treated
    the same as "not a TTY" so the caller falls back instead of crashing.
    """
    if os.environ.get(_DISABLE_ENV):
        return True
    stdin = sys.stdin
    if stdin is None:
        # CPython sets sys.stdin to None when fd 0 is unusable at startup
        # (e.g. the process was launched with stdin closed); calling
        # .isatty() on it would raise AttributeError.
        return True
    try:
        return not stdin.isatty()
    except (AttributeError, ValueError, OSError):
        # AttributeError: replacement stdin object without isatty();
        # ValueError: isatty() on a closed file; OSError: fd-level failure.
        return True
@contextmanager
def capture_typeahead() -> Iterator[TypeaheadBuffer]:
    """Drain stdin into an in-memory buffer until the context exits.

    While the context is active the terminal is switched to cbreak mode and a
    daemon thread copies every byte that becomes readable on stdin into the
    yielded ``TypeaheadBuffer``. On exit the thread is told to stop (and
    joined for at most 0.5 s), then the original terminal attributes are
    restored.

    No-op on Windows, when stdin isn't a TTY, or when
    ``AI_SHELL_NO_TYPEAHEAD=1`` is set; in those cases the yielded buffer
    stays empty and the caller falls back to the existing path.
    """
    buf = TypeaheadBuffer()

    if sys.platform == "win32" or _capture_disabled():
        yield buf
        return

    # Imports are guarded behind the platform check above because termios/tty
    # are POSIX-only.
    import select
    import termios
    import tty

    fd = sys.stdin.fileno()
    original = termios.tcgetattr(fd)
    stop = threading.Event()

    def _drain() -> None:
        # Poll with a short timeout so a stop request is noticed within ~50 ms.
        while not stop.is_set():
            try:
                ready, _, _ = select.select([fd], [], [], 0.05)
            except (OSError, ValueError):
                return
            if fd not in ready:
                continue
            try:
                chunk = os.read(fd, 4096)
            except BlockingIOError:
                # Spurious readiness: nothing to read after all; poll again.
                continue
            except OSError:
                # Persistent failures (e.g. EIO after a terminal hangup) keep
                # select() reporting the fd as ready; give up rather than
                # spinning on the error until the context exits.
                return
            if not chunk:
                # EOF on stdin.
                return
            buf.append(chunk)

    try:
        # tty.setcbreak defaults to TCSAFLUSH, which discards input received
        # but not yet read -- i.e. exactly the keystrokes typed before this
        # point. TCSANOW keeps them so the drain thread picks them up.
        tty.setcbreak(fd, termios.TCSANOW)
        thread = threading.Thread(target=_drain, daemon=True)
        thread.start()
        try:
            yield buf
        finally:
            stop.set()
            thread.join(timeout=0.5)
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, original)