fix: use npipe on Windows for default server address

Ayman Bagabas created

Change summary

internal/client/client.go       | 31 ++++++++++++++++++++++-------
internal/client/dial_other.go   | 14 +++++++++++++
internal/client/dial_windows.go | 15 ++++++++++++++
internal/cmd/root.go            |  2 
internal/cmd/server.go          |  2 
internal/server/net_windows.go  | 13 ++++++++---
internal/server/server.go       | 36 +++++++++++++++++-----------------
7 files changed, 81 insertions(+), 32 deletions(-)

Detailed changes

internal/client/client.go 🔗

@@ -7,6 +7,8 @@ import (
 	"net"
 	"net/http"
 	"path/filepath"
+	"runtime"
+	"strings"
 	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
@@ -23,7 +25,11 @@ type Client struct {
 
 // DefaultClient creates a new [Client] connected to the default server address.
 func DefaultClient(path string) (*Client, error) {
-	return NewClient(path, "unix", server.DefaultAddr())
+	proto, addr, ok := strings.Cut(server.DefaultHost(), "://")
+	if !ok {
+		return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS)
+	}
+	return NewClient(path, proto, addr)
 }
 
 // NewClient creates a new [Client] connected to the server at the given
@@ -34,13 +40,7 @@ func NewClient(path, network, address string) (*Client, error) {
 	p.SetUnencryptedHTTP2(true)
 	tr := http.DefaultTransport.(*http.Transport).Clone()
 	tr.Protocols = &p
-	tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
-		d := net.Dialer{
-			Timeout:   30 * time.Second,
-			KeepAlive: 30 * time.Second,
-		}
-		return d.DialContext(ctx, network, address)
-	}
+	tr.DialContext = dialer
 	h := &http.Client{
 		Transport: tr,
 		Timeout:   0, // we need this to be 0 for long-lived connections and SSE streams
@@ -125,3 +125,18 @@ func (c *Client) ShutdownServer() error {
 	}
 	return nil
 }
+
+func dialer(ctx context.Context, network, address string) (net.Conn, error) {
+	switch network {
+	case "npipe":
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		defer cancel()
+		return dialPipeContext(ctx, address)
+	default:
+		d := net.Dialer{
+			Timeout:   30 * time.Second,
+			KeepAlive: 30 * time.Second,
+		}
+		return d.DialContext(ctx, network, address)
+	}
+}

internal/client/dial_other.go 🔗

@@ -0,0 +1,14 @@
+//go:build !windows
+// +build !windows
+
+package client
+
+import (
+	"context"
+	"net"
+	"syscall"
+)
+
+func dialPipeContext(context.Context, string) (net.Conn, error) {
+	return nil, syscall.EAFNOSUPPORT
+}

internal/client/dial_windows.go 🔗

@@ -0,0 +1,15 @@
+//go:build windows
+// +build windows
+
+package client
+
+import (
+	"context"
+	"net"
+
+	"github.com/Microsoft/go-winio"
+)
+
+func dialPipeContext(ctx context.Context, address string) (net.Conn, error) {
+	return winio.DialPipeContext(ctx, address)
+}

internal/cmd/root.go 🔗

@@ -36,7 +36,7 @@ func init() {
 	rootCmd.Flags().BoolP("help", "h", false, "Help")
 	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 
-	rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultAddr(), "Connect to a specific crush server host (for advanced users)")
+	rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 
 	rootCmd.AddCommand(runCmd)
 	rootCmd.AddCommand(updateProvidersCmd)

internal/cmd/server.go 🔗

@@ -87,6 +87,6 @@ var serverCmd = &cobra.Command{
 }
 
 func init() {
-	serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultAddr(), "Server host (TCP or Unix socket)")
+	serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultHost(), "Server host (TCP or Unix socket)")
 	rootCmd.AddCommand(serverCmd)
 }

internal/server/net_windows.go 🔗

@@ -5,14 +5,19 @@ package server
 
 import (
 	"net"
-	"strings"
 
 	"github.com/Microsoft/go-winio"
 )
 
 func listen(network, address string) (net.Listener, error) {
-	if !strings.HasPrefix(address, "tcp") {
-		return winio.ListenPipe(address, nil)
+	switch network {
+	case "npipe":
+		return winio.ListenPipe(address, &winio.PipeConfig{
+			MessageMode:      true,
+			InputBufferSize:  65536,
+			OutputBufferSize: 65536,
+		})
+	default:
+		return net.Listen(network, address)
 	}
-	return net.Listen(network, address)
 }

internal/server/server.go 🔗

@@ -51,18 +51,17 @@ func (i *Instance) Path() string {
 	return i.path
 }
 
-// DefaultAddr returns the default address path for the Crush server based on
-// the operating system.
-func DefaultAddr() string {
-	sockPath := "crush.sock"
-	user, err := user.Current()
-	if err == nil && user.Uid != "" {
-		sockPath = fmt.Sprintf("crush-%s.sock", user.Uid)
+// DefaultHost returns the default server host.
+func DefaultHost() string {
+	sock := "crush.sock"
+	usr, err := user.Current()
+	if err == nil && usr.Uid != "" {
+		sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
 	}
 	if runtime.GOOS == "windows" {
-		return fmt.Sprintf(`\\.\pipe\%s`, sockPath)
+		return fmt.Sprintf("npipe:////./pipe/%s", sock)
 	}
-	return fmt.Sprintf("/tmp/%s", sockPath)
+	return fmt.Sprintf("unix:///tmp/%s", sock)
 }
 
 // Server represents a Crush server instance bound to a specific address.
@@ -87,20 +86,17 @@ func (s *Server) SetLogger(logger *slog.Logger) {
 
 // DefaultServer returns a new [Server] instance with the default address.
 func DefaultServer(cfg *config.Config) *Server {
-	return NewServer(cfg, "unix", DefaultAddr())
+	proto, addr, ok := strings.Cut(DefaultHost(), "://")
+	if !ok {
+		panic("invalid default host")
+	}
+	return NewServer(cfg, proto, addr)
 }
 
 // NewServer is a helper to create a new [Server] instance with the given
 // address. On Windows, if the address is not a "tcp" address, it will be
 // converted to a named pipe format.
 func NewServer(cfg *config.Config, network, address string) *Server {
-	if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
-		!strings.HasPrefix(address, `\\.\pipe\`) {
-		// On Windows, convert to named pipe format if not TCP
-		// (e.g., "mypipe" -> "\\.\pipe\mypipe")
-		address = fmt.Sprintf(`\\.\pipe\%s`, address)
-	}
-
 	s := new(Server)
 	s.Addr = address
 	s.cfg = cfg
@@ -157,7 +153,11 @@ func (s *Server) ListenAndServe() error {
 	if s.ln != nil {
 		return fmt.Errorf("server already started")
 	}
-	ln, err := listen("unix", s.Addr)
+	proto := "unix"
+	if runtime.GOOS == "windows" {
+		proto = "npipe"
+	}
+	ln, err := listen(proto, s.Addr)
 	if err != nil {
 		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
 	}