fix: properly support windows npipe connections and detached process

Ayman Bagabas created

Change summary

internal/client/client.go      |  58 +++++---
internal/cmd/root.go           |  89 ++++++++-----
internal/cmd/root_other.go     |   2 
internal/cmd/root_windows.go   |  18 +-
internal/cmd/run.go            |   7 
internal/cmd/server.go         |   7 
internal/gorust/pipeline.go    | 233 ++++++++++++++++++++++++++++++++++++
internal/server/net_windows.go |   5 
internal/server/server.go      |  30 ++++
9 files changed, 372 insertions(+), 77 deletions(-)

Detailed changes

internal/client/client.go 🔗

@@ -7,8 +7,6 @@ import (
 	"net"
 	"net/http"
 	"path/filepath"
-	"runtime"
-	"strings"
 	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
@@ -16,39 +14,49 @@ import (
 	"github.com/charmbracelet/crush/internal/server"
 )
 
+// DummyHost is used to satisfy the http.Client's requirement for a URL.
+const DummyHost = "api.crush.localhost"
+
 // Client represents an RPC client connected to a Crush server.
 type Client struct {
-	h    *http.Client
-	id   string
-	path string
+	h     *http.Client
+	id    string
+	path  string
+	proto string
+	addr  string
 }
 
 // DefaultClient creates a new [Client] connected to the default server address.
 func DefaultClient(path string) (*Client, error) {
-	proto, addr, ok := strings.Cut(server.DefaultHost(), "://")
-	if !ok {
-		return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS)
+	host, err := server.ParseHostURL(server.DefaultHost())
+	if err != nil {
+		return nil, err
 	}
-	return NewClient(path, proto, addr)
+	return NewClient(path, host.Scheme, host.Host)
 }
 
 // NewClient creates a new [Client] connected to the server at the given
 // network and address.
 func NewClient(path, network, address string) (*Client, error) {
-	var p http.Protocols
+	c := new(Client)
+	c.path = filepath.Clean(path)
+	c.proto = network
+	c.addr = address
+	p := &http.Protocols{}
 	p.SetHTTP1(true)
 	p.SetUnencryptedHTTP2(true)
 	tr := http.DefaultTransport.(*http.Transport).Clone()
-	tr.Protocols = &p
-	tr.DialContext = dialer
-	h := &http.Client{
+	tr.Protocols = p
+	tr.DialContext = c.dialer
+	if c.proto == "npipe" || c.proto == "unix" {
+		// We don't need compression for local connections.
+		tr.DisableCompression = true
+	}
+	c.h = &http.Client{
 		Transport: tr,
 		Timeout:   0, // we need this to be 0 for long-lived connections and SSE streams
 	}
-	return &Client{
-		h:    h,
-		path: filepath.Clean(path),
-	}, nil
+	return c, nil
 }
 
 // ID returns the client's instance unique identifier.
@@ -126,17 +134,19 @@ func (c *Client) ShutdownServer() error {
 	return nil
 }
 
-func dialer(ctx context.Context, network, address string) (net.Conn, error) {
-	switch network {
+func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
+	d := net.Dialer{
+		Timeout:   30 * time.Second,
+		KeepAlive: 30 * time.Second,
+	}
+	switch c.proto {
 	case "npipe":
 		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 		defer cancel()
-		return dialPipeContext(ctx, address)
+		return dialPipeContext(ctx, c.addr)
+	case "unix":
+		return d.DialContext(ctx, "unix", c.addr)
 	default:
-		d := net.Dialer{
-			Timeout:   30 * time.Second,
-			KeepAlive: 30 * time.Second,
-		}
 		return d.DialContext(ctx, network, address)
 	}
 }

internal/cmd/root.go 🔗

@@ -7,6 +7,7 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
+	"net/url"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -71,15 +72,62 @@ crush run "Explain the use of context in Go"
 crush -y
   `,
 	RunE: func(cmd *cobra.Command, args []string) error {
-		if err := ensureServerRunning(cmd); err != nil {
-			return err
+		hostURL, err := server.ParseHostURL(clientHost)
+		if err != nil {
+			return fmt.Errorf("invalid host URL: %v", err)
 		}
 
-		c, err := setupApp(cmd)
+		switch hostURL.Scheme {
+		case "unix", "npipe":
+			_, err := os.Stat(hostURL.Host)
+			if err != nil && errors.Is(err, fs.ErrNotExist) {
+				slog.Info("Starting server...", "host", clientHost)
+				if err := startDetachedServer(cmd); err != nil {
+					return err
+				}
+			}
+
+			// Wait for the file to appear
+			for range 10 {
+				_, err = os.Stat(hostURL.Host)
+				if err == nil {
+					break
+				}
+				select {
+				case <-cmd.Context().Done():
+					return cmd.Context().Err()
+				case <-time.After(100 * time.Millisecond):
+				}
+			}
+			if err != nil {
+				return fmt.Errorf("failed to connect to crush server: %v", err)
+			}
+
+		default:
+			// TODO: implement TCP support
+		}
+
+		c, err := setupApp(cmd, hostURL)
 		if err != nil {
 			return err
 		}
 
+		tries := 5
+		for i := range tries {
+			err := c.Health()
+			if err == nil {
+				break
+			}
+			select {
+			case <-cmd.Context().Done():
+				return cmd.Context().Err()
+			case <-time.After(100 * time.Millisecond):
+			}
+			if i == tries-1 {
+				return fmt.Errorf("failed to connect to crush server after %d attempts: %v", tries, err)
+			}
+		}
+
 		m, err := tui.New(c)
 		if err != nil {
 			return fmt.Errorf("failed to create TUI model: %v", err)
@@ -145,7 +193,7 @@ func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) {
 
 // setupApp handles the common setup logic for both interactive and non-interactive modes.
 // It returns the app instance, config, cleanup function, and any error.
-func setupApp(cmd *cobra.Command) (*client.Client, error) {
+func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, error) {
 	debug, _ := cmd.Flags().GetBool("debug")
 	yolo, _ := cmd.Flags().GetBool("yolo")
 	dataDir, _ := cmd.Flags().GetString("data-dir")
@@ -156,7 +204,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) {
 		return nil, err
 	}
 
-	c, err := client.NewClient(cwd, "unix", clientHost)
+	c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
 	if err != nil {
 		return nil, err
 	}
@@ -178,17 +226,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) {
 
 var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
 
-func ensureServerRunning(cmd *cobra.Command) error {
-	stat, err := os.Stat(clientHost)
-	if err == nil && stat.Mode()&os.ModeSocket == 0 {
-		return fmt.Errorf("crush server socket path exists but is not a socket: %s", clientHost)
-	} else if err == nil && stat.Mode()&os.ModeSocket != 0 {
-		// Socket exists, assume server is running.
-		return nil
-	} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
-		return fmt.Errorf("failed to stat crush server socket: %v", err)
-	}
-
+func startDetachedServer(cmd *cobra.Command) error {
 	// Start the server as a detached process if the socket does not exist.
 	exe, err := os.Executable()
 	if err != nil {
@@ -204,7 +242,7 @@ func ensureServerRunning(cmd *cobra.Command) error {
 	c := exec.CommandContext(cmd.Context(), exe, "server")
 	stdoutPath := filepath.Join(chDir, "stdout.log")
 	stderrPath := filepath.Join(chDir, "stderr.log")
-	detachProcess(c, stdoutPath, stderrPath)
+	detachProcess(c)
 
 	stdout, err := os.Create(stdoutPath)
 	if err != nil {
@@ -228,23 +266,6 @@ func ensureServerRunning(cmd *cobra.Command) error {
 		return fmt.Errorf("failed to detach crush server process: %v", err)
 	}
 
-	// Wait for the server to start and create the socket.
-	for range 10 {
-		stat, err := os.Stat(clientHost)
-		if err == nil && stat.Mode()&os.ModeSocket != 0 {
-			// Socket exists, server is running.
-			return nil
-		} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
-			return fmt.Errorf("failed to stat crush server socket: %v", err)
-		}
-		// Sleep for 100ms before checking again.
-		select {
-		case <-cmd.Context().Done():
-			return fmt.Errorf("context cancelled while waiting for crush server to start")
-		case <-time.After(100 * time.Millisecond):
-		}
-	}
-
 	return nil
 }
 

internal/cmd/root_other.go 🔗

@@ -8,7 +8,7 @@ import (
 	"syscall"
 )
 
-func detachProcess(c *exec.Cmd, _, _ string) {
+func detachProcess(c *exec.Cmd) {
 	if c.SysProcAttr == nil {
 		c.SysProcAttr = &syscall.SysProcAttr{}
 	}

internal/cmd/root_windows.go 🔗

@@ -5,18 +5,14 @@ package cmd
 
 import (
 	"os/exec"
+	"syscall"
+
+	"golang.org/x/sys/windows"
 )
 
-func detachProcess(c *exec.Cmd, stdoutPath, stderrPath string) {
-	argv1 := c.Args[0]
-	c.Path = "cmd"
-	c.Args = []string{
-		"cmd",
-		"/c",
-		argv1,
-		">",
-		stdoutPath,
-		"2>",
-		stderrPath,
+func detachProcess(c *exec.Cmd) {
+	if c.SysProcAttr == nil {
+		c.SysProcAttr = &syscall.SysProcAttr{}
 	}
+	c.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP | windows.DETACHED_PROCESS
 }

internal/cmd/run.go 🔗

@@ -5,6 +5,7 @@ import (
 	"log/slog"
 	"strings"
 
+	"github.com/charmbracelet/crush/internal/server"
 	"github.com/spf13/cobra"
 )
 
@@ -25,8 +26,12 @@ crush run -q "Generate a README for this project"
   `,
 	RunE: func(cmd *cobra.Command, args []string) error {
 		quiet, _ := cmd.Flags().GetBool("quiet")
+		hostURL, err := server.ParseHostURL(clientHost)
+		if err != nil {
+			return fmt.Errorf("invalid host URL: %v", err)
+		}
 
-		c, err := setupApp(cmd)
+		c, err := setupApp(cmd, hostURL)
 		if err != nil {
 			return err
 		}

internal/cmd/server.go 🔗

@@ -43,7 +43,12 @@ var serverCmd = &cobra.Command{
 			slog.SetLogLoggerLevel(slog.LevelDebug)
 		}
 
-		srv := server.NewServer(cfg, "unix", serverHost)
+		hostURL, err := server.ParseHostURL(serverHost)
+		if err != nil {
+			return fmt.Errorf("invalid server host: %v", err)
+		}
+
+		srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host)
 		srv.SetLogger(slog.Default())
 		slog.Info("Starting Crush server...", "addr", srv.Addr)
 

internal/gorust/pipeline.go 🔗

@@ -0,0 +1,233 @@
+package gorust
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+)
+
+type Stage[T, R any] interface {
+	Process(ctx context.Context, input T) (R, error)
+	Name() string
+}
+
+type Pipeline[T any] struct {
+	stages []Stage[any, any]
+	mu     sync.RWMutex
+	opts   PipelineOptions
+}
+
+type PipelineOptions struct {
+	MaxConcurrency int
+	Timeout        time.Duration
+	RetryAttempts  int
+	RetryDelay     time.Duration
+}
+
+type PipelineResult[T any] struct {
+	Output T
+	Error  error
+	Stage  string
+	Took   time.Duration
+}
+
+func NewPipeline[T any](opts PipelineOptions) *Pipeline[T] {
+	if opts.MaxConcurrency <= 0 {
+		opts.MaxConcurrency = 10
+	}
+	if opts.Timeout <= 0 {
+		opts.Timeout = 30 * time.Second
+	}
+	if opts.RetryAttempts <= 0 {
+		opts.RetryAttempts = 3
+	}
+	if opts.RetryDelay <= 0 {
+		opts.RetryDelay = time.Second
+	}
+
+	return &Pipeline[T]{
+		stages: make([]Stage[any, any], 0),
+		opts:   opts,
+	}
+}
+
+func (p *Pipeline[T]) AddStage(stage Stage[any, any]) *Pipeline[T] {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.stages = append(p.stages, stage)
+	return p
+}
+
+func (p *Pipeline[T]) Execute(ctx context.Context, input T) <-chan PipelineResult[any] {
+	results := make(chan PipelineResult[any], len(p.stages))
+	
+	go func() {
+		defer close(results)
+		
+		current := input
+		for _, stage := range p.stages {
+			select {
+			case <-ctx.Done():
+				results <- PipelineResult[any]{
+					Error: ctx.Err(),
+					Stage: stage.Name(),
+				}
+				return
+			default:
+				result := p.executeStageWithRetry(ctx, stage, current)
+				results <- result
+				
+				if result.Error != nil {
+					return
+				}
+				current = result.Output
+			}
+		}
+	}()
+	
+	return results
+}
+
+func (p *Pipeline[T]) executeStageWithRetry(ctx context.Context, stage Stage[any, any], input any) PipelineResult[any] {
+	var lastErr error
+	start := time.Now()
+	
+	for attempt := 0; attempt < p.opts.RetryAttempts; attempt++ {
+		if attempt > 0 {
+			select {
+			case <-ctx.Done():
+				return PipelineResult[any]{
+					Error: ctx.Err(),
+					Stage: stage.Name(),
+					Took:  time.Since(start),
+				}
+			case <-time.After(p.opts.RetryDelay):
+			}
+		}
+		
+		stageCtx, cancel := context.WithTimeout(ctx, p.opts.Timeout)
+		output, err := stage.Process(stageCtx, input)
+		cancel()
+		
+		if err == nil {
+			return PipelineResult[any]{
+				Output: output,
+				Stage:  stage.Name(),
+				Took:   time.Since(start),
+			}
+		}
+		
+		lastErr = err
+	}
+	
+	return PipelineResult[any]{
+		Error: fmt.Errorf("stage %s failed after %d attempts: %w", stage.Name(), p.opts.RetryAttempts, lastErr),
+		Stage: stage.Name(),
+		Took:  time.Since(start),
+	}
+}
+
+type TransformStage[T, R any] struct {
+	name      string
+	transform func(context.Context, T) (R, error)
+}
+
+func NewTransformStage[T, R any](name string, transform func(context.Context, T) (R, error)) *TransformStage[T, R] {
+	return &TransformStage[T, R]{
+		name:      name,
+		transform: transform,
+	}
+}
+
+func (s *TransformStage[T, R]) Name() string {
+	return s.name
+}
+
+func (s *TransformStage[T, R]) Process(ctx context.Context, input T) (R, error) {
+	return s.transform(ctx, input)
+}
+
+type FilterStage[T any] struct {
+	name      string
+	predicate func(context.Context, T) (bool, error)
+}
+
+func NewFilterStage[T any](name string, predicate func(context.Context, T) (bool, error)) *FilterStage[T] {
+	return &FilterStage[T]{
+		name:      name,
+		predicate: predicate,
+	}
+}
+
+func (s *FilterStage[T]) Name() string {
+	return s.name
+}
+
+func (s *FilterStage[T]) Process(ctx context.Context, input T) (T, error) {
+	keep, err := s.predicate(ctx, input)
+	if err != nil {
+		var zero T
+		return zero, err
+	}
+	
+	if !keep {
+		var zero T
+		return zero, fmt.Errorf("item filtered out")
+	}
+	
+	return input, nil
+}
+
+type BatchProcessor[T, R any] struct {
+	name      string
+	batchSize int
+	processor func(context.Context, []T) ([]R, error)
+	buffer    []T
+	mu        sync.Mutex
+}
+
+func NewBatchProcessor[T, R any](name string, batchSize int, processor func(context.Context, []T) ([]R, error)) *BatchProcessor[T, R] {
+	return &BatchProcessor[T, R]{
+		name:      name,
+		batchSize: batchSize,
+		processor: processor,
+		buffer:    make([]T, 0, batchSize),
+	}
+}
+
+func (b *BatchProcessor[T, R]) Name() string {
+	return b.name
+}
+
+func (b *BatchProcessor[T, R]) Process(ctx context.Context, input T) ([]R, error) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	
+	b.buffer = append(b.buffer, input)
+	
+	if len(b.buffer) >= b.batchSize {
+		batch := make([]T, len(b.buffer))
+		copy(batch, b.buffer)
+		b.buffer = b.buffer[:0]
+		
+		return b.processor(ctx, batch)
+	}
+	
+	return nil, nil
+}
+
+func (b *BatchProcessor[T, R]) Flush(ctx context.Context) ([]R, error) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	
+	if len(b.buffer) == 0 {
+		return nil, nil
+	}
+	
+	batch := make([]T, len(b.buffer))
+	copy(batch, b.buffer)
+	b.buffer = b.buffer[:0]
+	
+	return b.processor(ctx, batch)
+}

internal/server/net_windows.go 🔗

@@ -12,11 +12,12 @@ import (
 func listen(network, address string) (net.Listener, error) {
 	switch network {
 	case "npipe":
-		return winio.ListenPipe(address, &winio.PipeConfig{
+		cfg := &winio.PipeConfig{
 			MessageMode:      true,
 			InputBufferSize:  65536,
 			OutputBufferSize: 65536,
-		})
+		}
+		return winio.ListenPipe(address, cfg)
 	default:
 		return net.Listen(network, address)
 	}

internal/server/server.go 🔗

@@ -6,6 +6,7 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"net/url"
 	"os/user"
 	"runtime"
 	"strings"
@@ -51,6 +52,29 @@ func (i *Instance) Path() string {
 	return i.path
 }
 
+// ParseHostURL parses a host URL into a [url.URL].
+func ParseHostURL(host string) (*url.URL, error) {
+	proto, addr, ok := strings.Cut(host, "://")
+	if !ok {
+		return nil, fmt.Errorf("invalid host format: %s", host)
+	}
+
+	var basePath string
+	if proto == "tcp" {
+		parsed, err := url.Parse("tcp://" + addr)
+		if err != nil {
+			return nil, fmt.Errorf("invalid tcp address: %v", err)
+		}
+		addr = parsed.Host
+		basePath = parsed.Path
+	}
+	return &url.URL{
+		Scheme: proto,
+		Host:   addr,
+		Path:   basePath,
+	}, nil
+}
+
 // DefaultHost returns the default server host.
 func DefaultHost() string {
 	sock := "crush.sock"
@@ -86,11 +110,11 @@ func (s *Server) SetLogger(logger *slog.Logger) {
 
 // DefaultServer returns a new [Server] instance with the default address.
 func DefaultServer(cfg *config.Config) *Server {
-	proto, addr, ok := strings.Cut(DefaultHost(), "://")
-	if !ok {
+	hostURL, err := ParseHostURL(DefaultHost())
+	if err != nil {
 		panic("invalid default host")
 	}
-	return NewServer(cfg, proto, addr)
+	return NewServer(cfg, hostURL.Scheme, hostURL.Host)
 }
 
 // NewServer is a helper to create a new [Server] instance with the given