Coverage for src/ai_shell/typeahead.py: 49%

61 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-05 22:06 +0000

1"""Capture stdin keystrokes during the slow startup phase so they can be replayed 

2into the interactive process once it attaches. 

3 

4Without this the user has to wait for the dev container to come up (~20 s for 

5image checks + tool freshness checks) before any typing reaches the inner shell. 

6Anything typed during that window is otherwise either lost or interpreted by the 

7parent shell after the CLI exits. 

8""" 

9 

10from __future__ import annotations 

11 

12import os 

13import sys 

14import threading 

15from contextlib import contextmanager 

16from typing import TYPE_CHECKING 

17 

18if TYPE_CHECKING: 

19 from collections.abc import Iterator 

20 

# Name of the environment variable that opts out of typeahead capture;
# checked by _capture_disabled().
_DISABLE_ENV = "AI_SHELL_NO_TYPEAHEAD"

22 

23 

class TypeaheadBuffer:
    """Collect raw stdin bytes in arrival order; safe to share across threads.

    One thread may call :meth:`append` while another calls :meth:`bytes`;
    a single lock guards the underlying list of chunks.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Chunks in the order they were appended; concatenated on demand.
        self._chunks: list[bytes] = []

    def append(self, data: bytes) -> None:
        """Record *data*. Empty chunks are silently ignored."""
        if data:
            with self._lock:
                self._chunks.append(data)

    def bytes(self) -> bytes:
        """Return everything recorded so far as one ``bytes`` object."""
        with self._lock:
            snapshot = tuple(self._chunks)
        # bytes are immutable, so joining the snapshot outside the lock
        # yields the same result while holding the lock for less time.
        return b"".join(snapshot)

40 

41 

def _capture_disabled() -> bool:
    """Return True when typeahead capture must be skipped.

    Capture is skipped when the opt-out environment variable named by
    ``_DISABLE_ENV`` is set to a non-empty value, or when stdin is missing,
    closed, or not a TTY.
    """
    if os.environ.get(_DISABLE_ENV):
        return True
    stdin = sys.stdin
    # sys.stdin can be None (e.g. under pythonw, or when the process was started
    # with fd 0 closed). Treat that like "not a TTY" rather than crashing with
    # AttributeError on .isatty().
    if stdin is None:
        return True
    try:
        return not stdin.isatty()
    except (ValueError, OSError):
        # ValueError: stdin has been closed; OSError: the underlying
        # descriptor could not be queried.
        return True

49 

50 

@contextmanager
def capture_typeahead() -> Iterator[TypeaheadBuffer]:
    """Drain stdin into an in-memory buffer until the context exits.

    No-op when stdin isn't a TTY or when ``AI_SHELL_NO_TYPEAHEAD=1`` is set; in
    those cases the yielded buffer stays empty and the caller falls back to the
    existing path. The same no-op fallback is used when the current terminal
    attributes cannot be read (``termios.error``).

    While active, the terminal is put into cbreak mode and a daemon thread
    polls stdin, appending every chunk it reads to the yielded buffer. On exit
    the thread is signalled to stop and joined for up to 0.5 s, and the
    original terminal attributes are restored.
    """
    buf = TypeaheadBuffer()

    if sys.platform == "win32" or _capture_disabled():
        yield buf
        return

    # Imports are guarded behind the platform check above because termios/tty
    # are POSIX-only.
    import select
    import termios
    import tty

    fd = sys.stdin.fileno()
    try:
        original = termios.tcgetattr(fd)
    except termios.error:
        # Without the original attributes we could not restore the terminal
        # afterwards, so skip capture instead of failing the whole CLI.
        original = None
    if original is None:
        yield buf
        return

    stop = threading.Event()

    def _drain() -> None:
        # Poll with a short timeout so the loop notices `stop` promptly.
        while not stop.is_set():
            try:
                ready, _, _ = select.select([fd], [], [], 0.05)
            except (OSError, ValueError):
                return
            if fd not in ready:
                continue
            try:
                chunk = os.read(fd, 4096)
            except BlockingIOError:
                # Spurious readiness (EAGAIN): just poll again.
                continue
            except OSError:
                # Persistent errors (e.g. EIO on a hung-up terminal) leave the
                # fd reported as readable; retrying would busy-spin the CPU
                # until the context exits, so stop draining instead.
                return
            if not chunk:
                # EOF: nothing more will arrive.
                return
            buf.append(chunk)

    try:
        tty.setcbreak(fd)
        thread = threading.Thread(target=_drain, daemon=True)
        thread.start()
        try:
            yield buf
        finally:
            stop.set()
            # The drain loop re-checks `stop` at least every 0.05 s, so this
            # normally returns well before the timeout.
            thread.join(timeout=0.5)
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, original)