From 3d8b8415f82e389f6160d449eb1cfab9ea05c96d Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 24 Sep 2025 19:39:27 -0400 Subject: [PATCH] fix: properly support windows npipe connections and detached process --- internal/client/client.go | 58 ++++---- internal/cmd/root.go | 89 ++++++++----- internal/cmd/root_other.go | 2 +- internal/cmd/root_windows.go | 18 +-- internal/cmd/run.go | 7 +- internal/cmd/server.go | 7 +- internal/gorust/pipeline.go | 233 +++++++++++++++++++++++++++++++++ internal/server/net_windows.go | 5 +- internal/server/server.go | 30 ++++- 9 files changed, 372 insertions(+), 77 deletions(-) create mode 100644 internal/gorust/pipeline.go diff --git a/internal/client/client.go b/internal/client/client.go index 5e5887892004fb1c26957f60a4628a5568cb2525..4d7706b267d877cf1c577b3aac45729a91ae7ada 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -7,8 +7,6 @@ import ( "net" "net/http" "path/filepath" - "runtime" - "strings" "time" "github.com/charmbracelet/crush/internal/config" @@ -16,39 +14,49 @@ import ( "github.com/charmbracelet/crush/internal/server" ) +// DummyHost is used to satisfy the http.Client's requirement for a URL. +const DummyHost = "api.crush.localhost" + // Client represents an RPC client connected to a Crush server. type Client struct { - h *http.Client - id string - path string + h *http.Client + id string + path string + proto string + addr string } // DefaultClient creates a new [Client] connected to the default server address. func DefaultClient(path string) (*Client, error) { - proto, addr, ok := strings.Cut(server.DefaultHost(), "://") - if !ok { - return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS) + host, err := server.ParseHostURL(server.DefaultHost()) + if err != nil { + return nil, err } - return NewClient(path, proto, addr) + return NewClient(path, host.Scheme, host.Host) } // NewClient creates a new [Client] connected to the server at the given // network and address. func NewClient(path, network, address string) (*Client, error) { - var p http.Protocols + c := new(Client) + c.path = filepath.Clean(path) + c.proto = network + c.addr = address + p := &http.Protocols{} p.SetHTTP1(true) p.SetUnencryptedHTTP2(true) tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Protocols = &p - tr.DialContext = dialer - h := &http.Client{ + tr.Protocols = p + tr.DialContext = c.dialer + if c.proto == "npipe" || c.proto == "unix" { + // We don't need compression for local connections. + tr.DisableCompression = true + } + c.h = &http.Client{ Transport: tr, Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams } - return &Client{ - h: h, - path: filepath.Clean(path), - }, nil + return c, nil } // ID returns the client's instance unique identifier. @@ -126,17 +134,19 @@ func (c *Client) ShutdownServer() error { return nil } -func dialer(ctx context.Context, network, address string) (net.Conn, error) { - switch network { +func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + switch c.proto { case "npipe": ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - return dialPipeContext(ctx, address) + return dialPipeContext(ctx, c.addr) + case "unix": + return d.DialContext(ctx, "unix", c.addr) default: - d := net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } return d.DialContext(ctx, network, address) } } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 42253c3ffcc6911033b473a66cc7e4d29beb2883..7b919589623c2af2953efd12e4cc9d2ce5fa3052 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -7,6 +7,7 @@ import ( "io" "io/fs" "log/slog" + "net/url" "os" "os/exec" "path/filepath" @@ -71,15 +72,62 @@ crush run "Explain the use of context in Go" crush -y `, RunE: func(cmd *cobra.Command, args []string) error { - if err := ensureServerRunning(cmd); err != nil { - return err + hostURL, err := server.ParseHostURL(clientHost) + if err != nil { + return fmt.Errorf("invalid host URL: %v", err) } - c, err := setupApp(cmd) + switch hostURL.Scheme { + case "unix", "npipe": + _, err := os.Stat(hostURL.Host) + if err != nil && errors.Is(err, fs.ErrNotExist) { + slog.Info("Starting server...", "host", clientHost) + if err := startDetachedServer(cmd); err != nil { + return err + } + } + + // Wait for the file to appear + for range 10 { + _, err = os.Stat(hostURL.Host) + if err == nil { + break + } + select { + case <-cmd.Context().Done(): + return cmd.Context().Err() + case <-time.After(100 * time.Millisecond): + } + } + if err != nil { + return fmt.Errorf("failed to connect to crush server: %v", err) + } + + default: + // TODO: implement TCP support + } + + c, err := setupApp(cmd, hostURL) if err != nil { return err } + tries := 5 + for i := range tries { + err := c.Health() + if err == nil { + break + } + select { + case <-cmd.Context().Done(): + return cmd.Context().Err() + case <-time.After(100 * time.Millisecond): + } + if i == tries-1 { + return fmt.Errorf("failed to connect to crush server after %d attempts: %v", tries, err) + } + } + m, err := tui.New(c) if err != nil { return fmt.Errorf("failed to create TUI model: %v", err) @@ -145,7 +193,7 @@ func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) { // setupApp handles the common setup logic for both interactive and non-interactive modes. // It returns the app instance, config, cleanup function, and any error. -func setupApp(cmd *cobra.Command) (*client.Client, error) { +func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, error) { debug, _ := cmd.Flags().GetBool("debug") yolo, _ := cmd.Flags().GetBool("yolo") dataDir, _ := cmd.Flags().GetString("data-dir") @@ -156,7 +204,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) { return nil, err } - c, err := client.NewClient(cwd, "unix", clientHost) + c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host) if err != nil { return nil, err } @@ -178,17 +226,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) { var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`) -func ensureServerRunning(cmd *cobra.Command) error { - stat, err := os.Stat(clientHost) - if err == nil && stat.Mode()&os.ModeSocket == 0 { - return fmt.Errorf("crush server socket path exists but is not a socket: %s", clientHost) - } else if err == nil && stat.Mode()&os.ModeSocket != 0 { - // Socket exists, assume server is running. - return nil - } else if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("failed to stat crush server socket: %v", err) - } - +func startDetachedServer(cmd *cobra.Command) error { // Start the server as a detached process if the socket does not exist. exe, err := os.Executable() if err != nil { @@ -204,7 +242,7 @@ func ensureServerRunning(cmd *cobra.Command) error { c := exec.CommandContext(cmd.Context(), exe, "server") stdoutPath := filepath.Join(chDir, "stdout.log") stderrPath := filepath.Join(chDir, "stderr.log") - detachProcess(c, stdoutPath, stderrPath) + detachProcess(c) stdout, err := os.Create(stdoutPath) if err != nil { @@ -228,23 +266,6 @@ func ensureServerRunning(cmd *cobra.Command) error { return fmt.Errorf("failed to detach crush server process: %v", err) } - // Wait for the server to start and create the socket. - for range 10 { - stat, err := os.Stat(clientHost) - if err == nil && stat.Mode()&os.ModeSocket != 0 { - // Socket exists, server is running. - return nil - } else if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("failed to stat crush server socket: %v", err) - } - // Sleep for 100ms before checking again. - select { - case <-cmd.Context().Done(): - return fmt.Errorf("context cancelled while waiting for crush server to start") - case <-time.After(100 * time.Millisecond): - } - } - return nil } diff --git a/internal/cmd/root_other.go b/internal/cmd/root_other.go index 8106c4bbcce3b8f1b3303768d8db57b930013fb0..6d178a07a6e55c85c7fdd4d6a4d98d923aad5a71 100644 --- a/internal/cmd/root_other.go +++ b/internal/cmd/root_other.go @@ -8,7 +8,7 @@ import ( "syscall" ) -func detachProcess(c *exec.Cmd, _, _ string) { +func detachProcess(c *exec.Cmd) { if c.SysProcAttr == nil { c.SysProcAttr = &syscall.SysProcAttr{} } diff --git a/internal/cmd/root_windows.go b/internal/cmd/root_windows.go index e87d9d5615060fc74b35428671f8fc851f91bde6..134b6de258cd798ccfc2dbb7803099f11e31b052 100644 --- a/internal/cmd/root_windows.go +++ b/internal/cmd/root_windows.go @@ -5,18 +5,14 @@ package cmd import ( "os/exec" + "syscall" + + "golang.org/x/sys/windows" ) -func detachProcess(c *exec.Cmd, stdoutPath, stderrPath string) { - argv1 := c.Args[0] - c.Path = "cmd" - c.Args = []string{ - "cmd", - "/c", - argv1, - ">", - stdoutPath, - "2>", - stderrPath, +func detachProcess(c *exec.Cmd) { + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} } + c.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP | windows.DETACHED_PROCESS } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index d50ddef40bc824a3963bac16f4f709f7d579dced..ee19f920cfcc87525572caf1991446b1de16b5a2 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -5,6 +5,7 @@ import ( "log/slog" "strings" + "github.com/charmbracelet/crush/internal/server" "github.com/spf13/cobra" ) @@ -25,8 +26,12 @@ crush run -q "Generate a README for this project" `, RunE: func(cmd *cobra.Command, args []string) error { quiet, _ := cmd.Flags().GetBool("quiet") + hostURL, err := server.ParseHostURL(clientHost) + if err != nil { + return fmt.Errorf("invalid host URL: %v", err) + } - c, err := setupApp(cmd) + c, err := setupApp(cmd, hostURL) if err != nil { return err } diff --git a/internal/cmd/server.go b/internal/cmd/server.go index db188b3e1240cf3171bdf4bbe26dde43635a0661..cae06b93d61ae1fcd7dc805819ee0aa1fc59ba14 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -43,7 +43,12 @@ var serverCmd = &cobra.Command{ slog.SetLogLoggerLevel(slog.LevelDebug) } - srv := server.NewServer(cfg, "unix", serverHost) + hostURL, err := server.ParseHostURL(serverHost) + if err != nil { + return fmt.Errorf("invalid server host: %v", err) + } + + srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host) srv.SetLogger(slog.Default()) slog.Info("Starting Crush server...", "addr", srv.Addr) diff --git a/internal/gorust/pipeline.go b/internal/gorust/pipeline.go new file mode 100644 index 0000000000000000000000000000000000000000..4eb01ba5f270908cd82bed6836a376d8ba770512 --- /dev/null +++ b/internal/gorust/pipeline.go @@ -0,0 +1,233 @@ +package gorust + +import ( + "context" + "fmt" + "sync" + "time" +) + +type Stage[T, R any] interface { + Process(ctx context.Context, input T) (R, error) + Name() string +} + +type Pipeline[T any] struct { + stages []Stage[any, any] + mu sync.RWMutex + opts PipelineOptions +} + +type PipelineOptions struct { + MaxConcurrency int + Timeout time.Duration + RetryAttempts int + RetryDelay time.Duration +} + +type PipelineResult[T any] struct { + Output T + Error error + Stage string + Took time.Duration +} + +func NewPipeline[T any](opts PipelineOptions) *Pipeline[T] { + if opts.MaxConcurrency <= 0 { + opts.MaxConcurrency = 10 + } + if opts.Timeout <= 0 { + opts.Timeout = 30 * time.Second + } + if opts.RetryAttempts <= 0 { + opts.RetryAttempts = 3 + } + if opts.RetryDelay <= 0 { + opts.RetryDelay = time.Second + } + + return &Pipeline[T]{ + stages: make([]Stage[any, any], 0), + opts: opts, + } +} + +func (p *Pipeline[T]) AddStage(stage Stage[any, any]) *Pipeline[T] { + p.mu.Lock() + defer p.mu.Unlock() + p.stages = append(p.stages, stage) + return p +} + +func (p *Pipeline[T]) Execute(ctx context.Context, input T) <-chan PipelineResult[any] { + results := make(chan PipelineResult[any], len(p.stages)) + + go func() { + defer close(results) + + current := input + for _, stage := range p.stages { + select { + case <-ctx.Done(): + results <- PipelineResult[any]{ + Error: ctx.Err(), + Stage: stage.Name(), + } + return + default: + result := p.executeStageWithRetry(ctx, stage, current) + results <- result + + if result.Error != nil { + return + } + current = result.Output + } + } + }() + + return results +} + +func (p *Pipeline[T]) executeStageWithRetry(ctx context.Context, stage Stage[any, any], input any) PipelineResult[any] { + var lastErr error + start := time.Now() + + for attempt := 0; attempt < p.opts.RetryAttempts; attempt++ { + if attempt > 0 { + select { + case <-ctx.Done(): + return PipelineResult[any]{ + Error: ctx.Err(), + Stage: stage.Name(), + Took: time.Since(start), + } + case <-time.After(p.opts.RetryDelay): + } + } + + stageCtx, cancel := context.WithTimeout(ctx, p.opts.Timeout) + output, err := stage.Process(stageCtx, input) + cancel() + + if err == nil { + return PipelineResult[any]{ + Output: output, + Stage: stage.Name(), + Took: time.Since(start), + } + } + + lastErr = err + } + + return PipelineResult[any]{ + Error: fmt.Errorf("stage %s failed after %d attempts: %w", stage.Name(), p.opts.RetryAttempts, lastErr), + Stage: stage.Name(), + Took: time.Since(start), + } +} + +type TransformStage[T, R any] struct { + name string + transform func(context.Context, T) (R, error) +} + +func NewTransformStage[T, R any](name string, transform func(context.Context, T) (R, error)) *TransformStage[T, R] { + return &TransformStage[T, R]{ + name: name, + transform: transform, + } +} + +func (s *TransformStage[T, R]) Name() string { + return s.name +} + +func (s *TransformStage[T, R]) Process(ctx context.Context, input T) (R, error) { + return s.transform(ctx, input) +} + +type FilterStage[T any] struct { + name string + predicate func(context.Context, T) (bool, error) +} + +func NewFilterStage[T any](name string, predicate func(context.Context, T) (bool, error)) *FilterStage[T] { + return &FilterStage[T]{ + name: name, + predicate: predicate, + } +} + +func (s *FilterStage[T]) Name() string { + return s.name +} + +func (s *FilterStage[T]) Process(ctx context.Context, input T) (T, error) { + keep, err := s.predicate(ctx, input) + if err != nil { + var zero T + return zero, err + } + + if !keep { + var zero T + return zero, fmt.Errorf("item filtered out") + } + + return input, nil +} + +type BatchProcessor[T, R any] struct { + name string + batchSize int + processor func(context.Context, []T) ([]R, error) + buffer []T + mu sync.Mutex +} + +func NewBatchProcessor[T, R any](name string, batchSize int, processor func(context.Context, []T) ([]R, error)) *BatchProcessor[T, R] { + return &BatchProcessor[T, R]{ + name: name, + batchSize: batchSize, + processor: processor, + buffer: make([]T, 0, batchSize), + } +} + +func (b *BatchProcessor[T, R]) Name() string { + return b.name +} + +func (b *BatchProcessor[T, R]) Process(ctx context.Context, input T) ([]R, error) { + b.mu.Lock() + defer b.mu.Unlock() + + b.buffer = append(b.buffer, input) + + if len(b.buffer) >= b.batchSize { + batch := make([]T, len(b.buffer)) + copy(batch, b.buffer) + b.buffer = b.buffer[:0] + + return b.processor(ctx, batch) + } + + return nil, nil +} + +func (b *BatchProcessor[T, R]) Flush(ctx context.Context) ([]R, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if len(b.buffer) == 0 { + return nil, nil + } + + batch := make([]T, len(b.buffer)) + copy(batch, b.buffer) + b.buffer = b.buffer[:0] + + return b.processor(ctx, batch) +} \ No newline at end of file diff --git a/internal/server/net_windows.go b/internal/server/net_windows.go index dbf019020ad4fe64b564e342ad1ea1b8b66a900b..fc1ed8e2c298b740b3611b42a7485cb880ff76ca 100644 --- a/internal/server/net_windows.go +++ b/internal/server/net_windows.go @@ -12,11 +12,12 @@ import ( func listen(network, address string) (net.Listener, error) { switch network { case "npipe": - return winio.ListenPipe(address, &winio.PipeConfig{ + cfg := &winio.PipeConfig{ MessageMode: true, InputBufferSize: 65536, OutputBufferSize: 65536, - }) + } + return winio.ListenPipe(address, cfg) default: return net.Listen(network, address) } diff --git a/internal/server/server.go b/internal/server/server.go index f0e74b934d206a9d5e2896c19d8ca6ffde51cd07..d4599860d09be6387c754d7122d3515025167259 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "log/slog" "net" "net/http" + "net/url" "os/user" "runtime" "strings" @@ -51,6 +52,29 @@ func (i *Instance) Path() string { return i.path } +// ParseHostURL parses a host URL into a [url.URL]. +func ParseHostURL(host string) (*url.URL, error) { + proto, addr, ok := strings.Cut(host, "://") + if !ok { + return nil, fmt.Errorf("invalid host format: %s", host) + } + + var basePath string + if proto == "tcp" { + parsed, err := url.Parse("tcp://" + addr) + if err != nil { + return nil, fmt.Errorf("invalid tcp address: %v", err) + } + addr = parsed.Host + basePath = parsed.Path + } + return &url.URL{ + Scheme: proto, + Host: addr, + Path: basePath, + }, nil +} + // DefaultHost returns the default server host. func DefaultHost() string { sock := "crush.sock" @@ -86,11 +110,11 @@ func (s *Server) SetLogger(logger *slog.Logger) { // DefaultServer returns a new [Server] instance with the default address. func DefaultServer(cfg *config.Config) *Server { - proto, addr, ok := strings.Cut(DefaultHost(), "://") - if !ok { + hostURL, err := ParseHostURL(DefaultHost()) + if err != nil { panic("invalid default host") } - return NewServer(cfg, proto, addr) + return NewServer(cfg, hostURL.Scheme, hostURL.Host) } // NewServer is a helper to create a new [Server] instance with the given