Files
patterm/internal/harness/session.go

265 lines
5.7 KiB
Go

package harness
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
pkgpty "github.com/hjbdev/patterm/internal/pty"
"github.com/hjbdev/patterm/internal/vt"
)
// Session drives one patterm child process running on a pseudo-terminal.
// Everything the child prints is fed into a Ghostty VT emulator (for screen
// and cursor inspection) and also kept verbatim. An MCP client connected to
// the child's per-pid socket is available for tool calls.
type Session struct {
	// pty is the terminal the child runs on; SendChord/SendText write here.
	pty *pkgpty.PTY
	// em is the screen model fed by readLoop. Bytes it emits via OnWritePTY
	// are written back to pty.
	em *vt.GhosttyEmulator
	// mcp is set by bootstrapMCP once a "whoami" call succeeds.
	mcp *MCPClient
	// env is the environment built by prepareEnv (binary, project dir,
	// runtime dir, terminal size).
	env *testEnv

	// bytesMu guards bytes, the complete raw output read from pty so far.
	bytesMu sync.Mutex
	bytes   []byte

	// lastWriteNS is the UnixNano timestamp of the latest non-empty PTY read.
	// 0 means nothing has been read yet. WaitForStable reads it.
	lastWriteNS atomic.Int64

	// readerDone is closed when readLoop returns.
	readerDone chan struct{}

	// closeOnce makes Close run its teardown once; closeErr is what every
	// Close call returns.
	closeOnce sync.Once
	closeErr  error
}
// New is the test-friendly constructor: it starts a Session with NewCLI,
// arranges for it to be closed via t.Cleanup, and aborts the test with
// t.Fatalf if startup fails.
func New(t testing.TB, opts Options) *Session {
	t.Helper()
	sess, startErr := NewCLI(opts)
	if startErr == nil {
		// Teardown errors are deliberately not turned into test failures.
		t.Cleanup(func() { _ = sess.Close() })
		return sess
	}
	t.Fatalf("harness New: %v", startErr)
	return nil // not reached: Fatalf stops the calling goroutine
}
// NewCLI starts the patterm binary from the prepared environment inside a new
// PTY (`patterm --project <dir>`). It mirrors the child's output into a Ghostty
// emulator and waits up to two seconds for the child's MCP socket to answer.
// The caller owns the returned Session and must Close it.
//
// Each failing step is wrapped with %w so callers can still inspect the cause.
// On failure the emulator and (once started) the child are released.
// NOTE(review): nothing prepareEnv created is cleaned up here; confirm whether
// it needs to be.
func NewCLI(opts Options) (*Session, error) {
	if opts.Scenario == nil {
		return nil, errors.New("harness: Scenario required")
	}
	env, childEnv, err := prepareEnv(opts)
	if err != nil {
		return nil, fmt.Errorf("harness: prepare env: %w", err)
	}
	em, err := vt.NewGhosttyEmulator(env.Cols, env.Rows)
	if err != nil {
		return nil, fmt.Errorf("harness: create emulator: %w", err)
	}
	argv := []string{env.PattermBin, "--project", env.ProjectDir}
	p, err := pkgpty.Start(argv, childEnv, env.Cols, env.Rows)
	if err != nil {
		_ = em.Close()
		return nil, fmt.Errorf("harness: start %s: %w", env.PattermBin, err)
	}
	// Bytes the emulator wants to send to the application (presumably replies
	// to terminal queries) are written back to the child.
	em.OnWritePTY(func(b []byte) { _, _ = p.Write(b) })
	s := &Session{pty: p, em: em, env: env, readerDone: make(chan struct{})}
	go s.readLoop()
	if err := s.bootstrapMCP(2 * time.Second); err != nil {
		_ = s.Close()
		// bootstrapMCP already prefixes its own context.
		return nil, err
	}
	return s, nil
}
// readLoop copies PTY output until the first read error (presumably the child
// exiting or the PTY being closed). For every non-empty read it:
//   - feeds the chunk to the emulator,
//   - appends it to the raw capture under bytesMu,
//   - records the time in lastWriteNS.
//
// readerDone is closed when the loop ends.
func (s *Session) readLoop() {
	defer close(s.readerDone)
	scratch := make([]byte, 64*1024)
	var readErr error
	for readErr == nil {
		var n int
		n, readErr = s.pty.Read(scratch)
		if n == 0 {
			continue
		}
		// Detach from scratch: the next Read reuses its backing array.
		data := append([]byte(nil), scratch[:n]...)
		_, _ = s.em.Write(data)
		s.bytesMu.Lock()
		s.bytes = append(s.bytes, data...)
		s.bytesMu.Unlock()
		s.lastWriteNS.Store(time.Now().UnixNano())
	}
}
// bootstrapMCP waits for the child's MCP socket at
// <RuntimeDir>/patterm/<child pid>.sock and connects to it. It succeeds once a
// "whoami" call on the connection returns without error, and keeps that client
// in s.mcp.
//
// Stat, dial and call failures are retried every 25ms until timeout elapses.
// It gives up early if the PTY reader has stopped: the child is gone and the
// socket will never appear.
//
// The returned error wraps the last failure and includes any child output
// captured so far.
func (s *Session) bootstrapMCP(timeout time.Duration) error {
	const retryEvery = 25 * time.Millisecond
	socket := filepath.Join(s.env.RuntimeDir, "patterm", fmt.Sprintf("%d.sock", s.pty.Pid()))

	fail := func(reason string, cause error) error {
		if raw := strings.TrimSpace(string(s.rawBytes())); raw != "" {
			return fmt.Errorf("mcp bootstrap %s: %w; child output: %s", reason, cause, raw)
		}
		return fmt.Errorf("mcp bootstrap %s: %w", reason, cause)
	}

	// Seeded so the error never wraps nil, even if the loop body never runs.
	last := fmt.Errorf("no connection attempt made to %s", socket)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		select {
		case <-s.readerDone:
			return fail("aborted: child exited", last)
		default:
		}
		if _, err := os.Stat(socket); err != nil {
			last = err
		} else if c, err := DialMCP(socket); err != nil {
			last = err
		} else if _, err := c.Call("whoami", map[string]any{}); err != nil {
			// Covers e.g. "tool host not initialized" while the child is
			// still starting up; the connection is dropped and retried.
			last = err
			_ = c.Close()
		} else {
			s.mcp = c
			return nil
		}
		time.Sleep(retryEvery)
	}
	return fail("timed out", last)
}
// Close terminates the child and releases the Session's resources, in order:
//  1. Close the MCP client.
//  2. SIGTERM the child's process group, or the child alone if that fails.
//  3. Wait up to 2s for exit; otherwise SIGKILL and wait up to 500ms more.
//  4. Give the PTY reader up to 1s to drain, then close the PTY master.
//  5. Close the emulator once the reader has stopped.
//
// Only the first call does any work; every call returns the same error. That
// error reports a failed PTY close (os.ErrClosed is ignored) or a reader that
// never stopped. In the latter case the emulator is intentionally left open
// so it is not closed while readLoop may still be writing into it.
func (s *Session) Close() error {
	s.closeOnce.Do(func() {
		if s.mcp != nil {
			_ = s.mcp.Close()
		}

		pid := s.pty.Pid()
		sendSignal := func(sig syscall.Signal) {
			if pid <= 0 {
				return
			}
			// A negative pid addresses the whole process group; fall back to
			// the single process if that is rejected.
			if err := syscall.Kill(-pid, sig); err != nil {
				_ = syscall.Kill(pid, sig)
			}
		}
		sendSignal(syscall.SIGTERM)
		exited := make(chan error, 1)
		go func() { exited <- s.pty.Wait() }()
		select {
		case <-exited:
		case <-time.After(2 * time.Second):
			sendSignal(syscall.SIGKILL)
			select {
			case <-exited:
			case <-time.After(500 * time.Millisecond):
			}
		}

		readerStopped := func(grace time.Duration) bool {
			select {
			case <-s.readerDone:
				return true
			case <-time.After(grace):
				return false
			}
		}
		// Let the reader drain whatever the child wrote before exiting.
		stopped := readerStopped(time.Second)
		// Close the master unconditionally so its fd is released even when
		// the reader stopped on its own (unless pkgpty already closed it, hence
		// the ErrClosed check). This also unblocks a reader stuck in Read.
		if err := s.pty.Close(); err != nil && !errors.Is(err, os.ErrClosed) {
			s.closeErr = err
		}
		if !stopped {
			stopped = readerStopped(500 * time.Millisecond)
		}
		if !stopped {
			if s.closeErr == nil {
				s.closeErr = errors.New("harness: pty reader did not exit; emulator left open")
			}
			return
		}
		_ = s.em.Close()
	})
	return s.closeErr
}
// SendChord turns the named key chord into bytes with EncodeChord and writes
// them to the child's PTY. If encoding fails, nothing is written and the
// encoding error is returned.
func (s *Session) SendChord(name string) error {
	seq, err := EncodeChord(name)
	if err == nil {
		_, err = s.pty.Write(seq)
	}
	return err
}
// SendText writes text to the child's PTY exactly as given, with no chord
// encoding, and returns any write error.
func (s *Session) SendText(text string) error {
	payload := []byte(text)
	if _, err := s.pty.Write(payload); err != nil {
		return err
	}
	return nil
}
// Screen returns the emulator's current screen contents as plain text.
func (s *Session) Screen() (string, error) {
	text, err := s.em.PlainText()
	return text, err
}

// Cursor returns the emulator's current cursor state.
func (s *Session) Cursor() (vt.CursorState, error) {
	state, err := s.em.Cursor()
	return state, err
}
// WaitForStable blocks until the PTY has been quiet for at least one second.
// Quiet also counts if nothing has ever been read. The condition must hold on
// two consecutive 25ms polls.
//
// Once timeout has elapsed without that happening, it returns an error. The
// quiet window is fixed at one second regardless of timeout.
func (s *Session) WaitForStable(timeout time.Duration) error {
	const quiet = time.Second
	deadline := time.Now().Add(timeout)
	poll := time.NewTicker(25 * time.Millisecond)
	defer poll.Stop()

	idleStreak := 0
	for {
		lastNS := s.lastWriteNS.Load()
		if lastNS == 0 || time.Since(time.Unix(0, lastNS)) >= quiet {
			idleStreak++
		} else {
			idleStreak = 0
		}
		if idleStreak >= 2 {
			return nil
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("screen did not stabilize within %s", timeout)
		}
		<-poll.C
	}
}
// WaitForText polls the emulator screen every 25ms until it contains text.
// Once timeout has elapsed it returns an error that includes the last screen
// examined. A Screen error is returned immediately.
//
// The screen is always checked at least once, and again after the final sleep.
// So a non-positive timeout still performs one check, and text that appears
// right at the deadline is not reported as missing.
func (s *Session) WaitForText(text string, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for {
		screen, err := s.Screen()
		if err != nil {
			return err
		}
		if strings.Contains(screen, text) {
			return nil
		}
		if !time.Now().Before(deadline) {
			return fmt.Errorf("text %q not found before timeout; screen:\n%s", text, screen)
		}
		time.Sleep(25 * time.Millisecond)
	}
}
// WaitForRegex compiles pattern, then polls the emulator screen every 25ms
// until the expression matches it. An invalid pattern or a Screen error is
// returned immediately. Once timeout has elapsed it returns an error that
// includes the last screen examined.
//
// As with WaitForText, the screen is always checked at least once and again
// after the final sleep, so a match at the deadline is not missed.
func (s *Session) WaitForRegex(pattern string, timeout time.Duration) error {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}
	deadline := time.Now().Add(timeout)
	for {
		screen, err := s.Screen()
		if err != nil {
			return err
		}
		if re.MatchString(screen) {
			return nil
		}
		if !time.Now().Before(deadline) {
			return fmt.Errorf("regex %q not found before timeout; screen:\n%s", pattern, screen)
		}
		time.Sleep(25 * time.Millisecond)
	}
}
// MCPCall calls method on the child's MCP client. params holds the JSON
// arguments; they are decoded and passed to the client's Call.
//
// Empty params and a JSON null are both sent as an empty object, the same
// shape bootstrapMCP uses for "whoami". Before this, null decoded to a nil
// argument.
//
// It returns an error instead of panicking if the client was never connected.
// A params decode failure is wrapped with the method name.
func (s *Session) MCPCall(method string, params json.RawMessage) (json.RawMessage, error) {
	if s.mcp == nil {
		return nil, fmt.Errorf("harness: MCPCall %q: mcp client not connected", method)
	}
	var args any
	if len(params) > 0 {
		if err := json.Unmarshal(params, &args); err != nil {
			return nil, fmt.Errorf("harness: MCPCall %q: decode params: %w", method, err)
		}
	}
	if args == nil {
		args = map[string]any{}
	}
	return s.mcp.Call(method, args)
}
// rawBytes returns a private copy of all output read from the PTY so far.
// It is taken under bytesMu, so it is safe to call while readLoop is still
// appending.
func (s *Session) rawBytes() []byte {
	s.bytesMu.Lock()
	snapshot := append(make([]byte, 0, len(s.bytes)), s.bytes...)
	s.bytesMu.Unlock()
	return snapshot
}