265 lines
5.7 KiB
Go
265 lines
5.7 KiB
Go
package harness
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
|
|
pkgpty "github.com/hjbdev/patterm/internal/pty"
|
|
"github.com/hjbdev/patterm/internal/vt"
|
|
)
|
|
|
|
type Session struct {
|
|
pty *pkgpty.PTY
|
|
em *vt.GhosttyEmulator
|
|
mcp *MCPClient
|
|
env *testEnv
|
|
|
|
bytesMu sync.Mutex
|
|
bytes []byte
|
|
|
|
lastWriteNS atomic.Int64
|
|
readerDone chan struct{}
|
|
closeOnce sync.Once
|
|
closeErr error
|
|
}
|
|
|
|
func New(t testing.TB, opts Options) *Session {
|
|
t.Helper()
|
|
s, err := NewCLI(opts)
|
|
if err != nil {
|
|
t.Fatalf("harness New: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = s.Close() })
|
|
return s
|
|
}
|
|
|
|
func NewCLI(opts Options) (*Session, error) {
|
|
if opts.Scenario == nil {
|
|
return nil, fmt.Errorf("harness: Scenario required")
|
|
}
|
|
env, childEnv, err := prepareEnv(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
em, err := vt.NewGhosttyEmulator(env.Cols, env.Rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p, err := pkgpty.Start([]string{env.PattermBin, "--project", env.ProjectDir}, childEnv, env.Cols, env.Rows)
|
|
if err != nil {
|
|
_ = em.Close()
|
|
return nil, err
|
|
}
|
|
em.OnWritePTY(func(b []byte) { _, _ = p.Write(b) })
|
|
s := &Session{pty: p, em: em, env: env, readerDone: make(chan struct{})}
|
|
go s.readLoop()
|
|
if err := s.bootstrapMCP(2 * time.Second); err != nil {
|
|
_ = s.Close()
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Session) readLoop() {
|
|
defer close(s.readerDone)
|
|
buf := make([]byte, 64*1024)
|
|
for {
|
|
n, err := s.pty.Read(buf)
|
|
if n > 0 {
|
|
chunk := make([]byte, n)
|
|
copy(chunk, buf[:n])
|
|
_, _ = s.em.Write(chunk)
|
|
s.bytesMu.Lock()
|
|
s.bytes = append(s.bytes, chunk...)
|
|
s.bytesMu.Unlock()
|
|
s.lastWriteNS.Store(time.Now().UnixNano())
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) bootstrapMCP(timeout time.Duration) error {
|
|
socket := filepath.Join(s.env.RuntimeDir, "patterm", fmt.Sprintf("%d.sock", s.pty.Pid()))
|
|
deadline := time.Now().Add(timeout)
|
|
var last error
|
|
for time.Now().Before(deadline) {
|
|
if _, err := os.Stat(socket); err != nil {
|
|
last = err
|
|
time.Sleep(25 * time.Millisecond)
|
|
continue
|
|
}
|
|
c, err := DialMCP(socket)
|
|
if err != nil {
|
|
last = err
|
|
time.Sleep(25 * time.Millisecond)
|
|
continue
|
|
}
|
|
_, err = c.Call("whoami", map[string]any{})
|
|
if err == nil {
|
|
s.mcp = c
|
|
return nil
|
|
}
|
|
last = err
|
|
_ = c.Close()
|
|
if strings.Contains(err.Error(), "tool host not initialized") {
|
|
time.Sleep(25 * time.Millisecond)
|
|
continue
|
|
}
|
|
time.Sleep(25 * time.Millisecond)
|
|
}
|
|
raw := strings.TrimSpace(string(s.rawBytes()))
|
|
if raw != "" {
|
|
return fmt.Errorf("mcp bootstrap timed out: %w; child output: %s", last, raw)
|
|
}
|
|
return fmt.Errorf("mcp bootstrap timed out: %w", last)
|
|
}
|
|
|
|
func (s *Session) Close() error {
|
|
s.closeOnce.Do(func() {
|
|
if s.mcp != nil {
|
|
_ = s.mcp.Close()
|
|
}
|
|
pid := s.pty.Pid()
|
|
if pid > 0 {
|
|
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
|
|
_ = syscall.Kill(pid, syscall.SIGTERM)
|
|
}
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() { done <- s.pty.Wait() }()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
if pid > 0 {
|
|
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
|
_ = syscall.Kill(pid, syscall.SIGKILL)
|
|
}
|
|
}
|
|
select {
|
|
case <-done:
|
|
case <-time.After(500 * time.Millisecond):
|
|
}
|
|
}
|
|
select {
|
|
case <-s.readerDone:
|
|
case <-time.After(time.Second):
|
|
if err := s.pty.Close(); err != nil && !errors.Is(err, os.ErrClosed) {
|
|
s.closeErr = err
|
|
}
|
|
select {
|
|
case <-s.readerDone:
|
|
case <-time.After(500 * time.Millisecond):
|
|
}
|
|
}
|
|
_ = s.em.Close()
|
|
})
|
|
return s.closeErr
|
|
}
|
|
|
|
func (s *Session) SendChord(name string) error {
|
|
b, err := EncodeChord(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.pty.Write(b)
|
|
return err
|
|
}
|
|
|
|
func (s *Session) SendText(text string) error {
|
|
_, err := s.pty.Write([]byte(text))
|
|
return err
|
|
}
|
|
|
|
// Screen returns the emulator's current contents rendered as plain text.
func (s *Session) Screen() (string, error) {
	text, err := s.em.PlainText()
	return text, err
}
|
|
|
|
// Cursor reports the emulator's current cursor state.
func (s *Session) Cursor() (vt.CursorState, error) {
	state, err := s.em.Cursor()
	return state, err
}
|
|
|
|
func (s *Session) WaitForStable(timeout time.Duration) error {
|
|
deadline := time.Now().Add(timeout)
|
|
tick := time.NewTicker(25 * time.Millisecond)
|
|
defer tick.Stop()
|
|
confirmed := false
|
|
for {
|
|
last := s.lastWriteNS.Load()
|
|
idle := last == 0 || time.Since(time.Unix(0, last)) >= time.Second
|
|
if idle {
|
|
if confirmed {
|
|
return nil
|
|
}
|
|
confirmed = true
|
|
} else {
|
|
confirmed = false
|
|
}
|
|
if time.Now().After(deadline) {
|
|
return fmt.Errorf("screen did not stabilize within %s", timeout)
|
|
}
|
|
<-tick.C
|
|
}
|
|
}
|
|
|
|
func (s *Session) WaitForText(text string, timeout time.Duration) error {
|
|
deadline := time.Now().Add(timeout)
|
|
for time.Now().Before(deadline) {
|
|
screen, err := s.Screen()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if strings.Contains(screen, text) {
|
|
return nil
|
|
}
|
|
time.Sleep(25 * time.Millisecond)
|
|
}
|
|
screen, _ := s.Screen()
|
|
return fmt.Errorf("text %q not found before timeout; screen:\n%s", text, screen)
|
|
}
|
|
|
|
func (s *Session) WaitForRegex(pattern string, timeout time.Duration) error {
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
deadline := time.Now().Add(timeout)
|
|
for time.Now().Before(deadline) {
|
|
screen, err := s.Screen()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if re.MatchString(screen) {
|
|
return nil
|
|
}
|
|
time.Sleep(25 * time.Millisecond)
|
|
}
|
|
screen, _ := s.Screen()
|
|
return fmt.Errorf("regex %q not found before timeout; screen:\n%s", pattern, screen)
|
|
}
|
|
|
|
func (s *Session) MCPCall(method string, params json.RawMessage) (json.RawMessage, error) {
|
|
var v any = map[string]any{}
|
|
if len(params) > 0 {
|
|
if err := json.Unmarshal(params, &v); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return s.mcp.Call(method, v)
|
|
}
|
|
|
|
func (s *Session) rawBytes() []byte {
|
|
s.bytesMu.Lock()
|
|
defer s.bytesMu.Unlock()
|
|
out := make([]byte, len(s.bytes))
|
|
copy(out, s.bytes)
|
|
return out
|
|
}
|