add tcp daemon listener with token auth
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -20,11 +21,15 @@ import (
|
||||
)
|
||||
|
||||
type DaemonOptions struct {
|
||||
ProjectDir string
|
||||
SocketPath string
|
||||
PidPath string
|
||||
Cols uint16
|
||||
Rows uint16
|
||||
ProjectDir string
|
||||
SocketPath string
|
||||
PidPath string
|
||||
ListenAddr string
|
||||
Token string
|
||||
TokenOut io.Writer
|
||||
ListenReady chan string
|
||||
Cols uint16
|
||||
Rows uint16
|
||||
}
|
||||
|
||||
type DaemonStatus struct {
|
||||
@@ -113,28 +118,101 @@ func RunDaemon(ctx context.Context, opts DaemonOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var tcpLn net.Listener
|
||||
tcpToken := opts.Token
|
||||
if opts.ListenAddr != "" {
|
||||
addr := normalizeListenAddr(opts.ListenAddr)
|
||||
tcpToken, err = ensureDaemonToken(tcpToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpLn, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("daemon: listen tcp %s: %w", addr, err)
|
||||
}
|
||||
defer tcpLn.Close()
|
||||
if opts.ListenReady != nil {
|
||||
select {
|
||||
case opts.ListenReady <- tcpLn.Addr().String():
|
||||
default:
|
||||
}
|
||||
}
|
||||
out := opts.TokenOut
|
||||
if out == nil {
|
||||
out = os.Stderr
|
||||
}
|
||||
fmt.Fprintf(out, "patterm daemon listening on %s\npatterm token: %s\n", tcpLn.Addr().String(), tcpToken)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = ln.Close()
|
||||
if tcpLn != nil {
|
||||
_ = tcpLn.Close()
|
||||
}
|
||||
}()
|
||||
errCh := make(chan error, 2)
|
||||
go acceptDaemonLoop(ctx, &wg, ln, "", cancel, registry, errCh)
|
||||
if tcpLn != nil {
|
||||
go acceptDaemonLoop(ctx, &wg, tcpLn, tcpToken, cancel, registry, errCh)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case err := <-errCh:
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func acceptDaemonLoop(ctx context.Context, wg *sync.WaitGroup, ln net.Listener, authToken string, stop func(), registry *ProjectRegistry, errCh chan<- error) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
|
||||
wg.Wait()
|
||||
return nil
|
||||
return
|
||||
}
|
||||
continue
|
||||
select {
|
||||
case errCh <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
handleDaemonConn(ctx, cancel, registry, protocol.NewConnTransport(conn))
|
||||
handleDaemonConn(ctx, stop, registry, protocol.NewConnTransport(conn), authToken)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeListenAddr(addr string) string {
|
||||
addr = strings.TrimSpace(addr)
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(addr); err == nil {
|
||||
return addr
|
||||
}
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
return addr
|
||||
}
|
||||
if _, err := strconv.Atoi(addr); err == nil {
|
||||
return ":" + addr
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func ensureDaemonToken(token string) (string, error) {
|
||||
if strings.TrimSpace(token) != "" {
|
||||
return strings.TrimSpace(token), nil
|
||||
}
|
||||
return LoadOrCreateClientToken()
|
||||
}
|
||||
|
||||
func prepareDaemonSocket(socketPath, pidPath string) (string, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(socketPath), 0o700); err != nil {
|
||||
return "", err
|
||||
@@ -163,7 +241,7 @@ func syscallSignal0(pid int) error {
|
||||
return syscall.Kill(pid, 0)
|
||||
}
|
||||
|
||||
func handleDaemonConn(ctx context.Context, stop func(), registry *ProjectRegistry, t protocol.Transport) {
|
||||
func handleDaemonConn(ctx context.Context, stop func(), registry *ProjectRegistry, t protocol.Transport, authToken string) {
|
||||
defer t.Close()
|
||||
f, err := t.Recv()
|
||||
if err != nil {
|
||||
@@ -178,6 +256,17 @@ func handleDaemonConn(ctx context.Context, stop func(), registry *ProjectRegistr
|
||||
stop()
|
||||
return
|
||||
case protocol.FrameAttach:
|
||||
if authToken != "" {
|
||||
attach, err := protocol.Decode[protocol.Attach](f)
|
||||
if err != nil {
|
||||
_ = sendProtocolError(t, err.Error())
|
||||
return
|
||||
}
|
||||
if attach.Token != authToken {
|
||||
_ = sendProtocolError(t, "auth denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
handleDaemonAttach(ctx, registry, t, f)
|
||||
default:
|
||||
_ = sendProtocolError(t, fmt.Sprintf("first frame must be attach, list, or stop; got %q", f.Type))
|
||||
|
||||
Reference in New Issue
Block a user