From b754264264b1172934518631ba995f9ebe4ddcd8 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 24 Sep 2025 17:23:34 -0400 Subject: [PATCH] fix: use npipe on Windows for default server address --- internal/client/client.go | 31 ++++++++++++++++++++-------- internal/client/dial_other.go | 14 +++++++++++++ internal/client/dial_windows.go | 15 ++++++++++++++ internal/cmd/root.go | 2 +- internal/cmd/server.go | 2 +- internal/server/net_windows.go | 13 ++++++++---- internal/server/server.go | 36 ++++++++++++++++----------------- 7 files changed, 81 insertions(+), 32 deletions(-) create mode 100644 internal/client/dial_other.go create mode 100644 internal/client/dial_windows.go diff --git a/internal/client/client.go b/internal/client/client.go index 686b94ee237d3b19da352b329b1e11a340871389..5e5887892004fb1c26957f60a4628a5568cb2525 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -7,6 +7,8 @@ import ( "net" "net/http" "path/filepath" + "runtime" + "strings" "time" "github.com/charmbracelet/crush/internal/config" @@ -23,7 +25,11 @@ type Client struct { // DefaultClient creates a new [Client] connected to the default server address. func DefaultClient(path string) (*Client, error) { - return NewClient(path, "unix", server.DefaultAddr()) + proto, addr, ok := strings.Cut(server.DefaultHost(), "://") + if !ok { + return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS) + } + return NewClient(path, proto, addr) } // NewClient creates a new [Client] connected to the server at the given @@ -34,13 +40,7 @@ func NewClient(path, network, address string) (*Client, error) { p.SetUnencryptedHTTP2(true) tr := http.DefaultTransport.(*http.Transport).Clone() tr.Protocols = &p - tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - d := net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } - return d.DialContext(ctx, network, address) - } + tr.DialContext = dialer h := &http.Client{ Transport: tr, Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams @@ -125,3 +125,18 @@ func (c *Client) ShutdownServer() error { } return nil } + +func dialer(ctx context.Context, network, address string) (net.Conn, error) { + switch network { + case "npipe": + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + return dialPipeContext(ctx, address) + default: + d := net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + return d.DialContext(ctx, network, address) + } +} diff --git a/internal/client/dial_other.go b/internal/client/dial_other.go new file mode 100644 index 0000000000000000000000000000000000000000..f2ba8569ba3326f2df82dc34bbf842eac30918d9 --- /dev/null +++ b/internal/client/dial_other.go @@ -0,0 +1,14 @@ +//go:build !windows +// +build !windows + +package client + +import ( + "context" + "net" + "syscall" +) + +func dialPipeContext(context.Context, string) (net.Conn, error) { + return nil, syscall.EAFNOSUPPORT +} diff --git a/internal/client/dial_windows.go b/internal/client/dial_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..750ce98b152c7583ce8e7889ef35c12098f6da8f --- /dev/null +++ b/internal/client/dial_windows.go @@ -0,0 +1,15 @@ +//go:build windows +// +build windows + +package client + +import ( + "context" + "net" + + "github.com/Microsoft/go-winio" +) + +func dialPipeContext(ctx context.Context, address string) (net.Conn, error) { + return winio.DialPipeContext(ctx, address) +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index af25869888ba51e9f7c02cd254ac849971671863..42253c3ffcc6911033b473a66cc7e4d29beb2883 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -36,7 +36,7 @@ func init() { rootCmd.Flags().BoolP("help", "h", false, "Help") rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)") - rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultAddr(), "Connect to a specific crush server host (for advanced users)") + rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)") rootCmd.AddCommand(runCmd) rootCmd.AddCommand(updateProvidersCmd) diff --git a/internal/cmd/server.go b/internal/cmd/server.go index 499889de7a8554296cebfc81f3bb0cd7a6aba738..db188b3e1240cf3171bdf4bbe26dde43635a0661 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -87,6 +87,6 @@ var serverCmd = &cobra.Command{ } func init() { - serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultAddr(), "Server host (TCP or Unix socket)") + serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultHost(), "Server host (TCP or Unix socket)") rootCmd.AddCommand(serverCmd) } diff --git a/internal/server/net_windows.go b/internal/server/net_windows.go index c0fef2de1e59b3f6894b77c2dd4a9fe5e7b27970..dbf019020ad4fe64b564e342ad1ea1b8b66a900b 100644 --- a/internal/server/net_windows.go +++ b/internal/server/net_windows.go @@ -5,14 +5,19 @@ package server import ( "net" - "strings" "github.com/Microsoft/go-winio" ) func listen(network, address string) (net.Listener, error) { - if !strings.HasPrefix(address, "tcp") { - return winio.ListenPipe(address, nil) + switch network { + case "npipe": + return winio.ListenPipe(address, &winio.PipeConfig{ + MessageMode: true, + InputBufferSize: 65536, + OutputBufferSize: 65536, + }) + default: + return net.Listen(network, address) } - return net.Listen(network, address) } diff --git a/internal/server/server.go b/internal/server/server.go index d88eb1c7f0bd7f555a8edc0ecef1a28c08ae6380..f0e74b934d206a9d5e2896c19d8ca6ffde51cd07 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -51,18 +51,17 @@ func (i *Instance) Path() string { return i.path } -// DefaultAddr returns the default address path for the Crush server based on -// the operating system. -func DefaultAddr() string { - sockPath := "crush.sock" - user, err := user.Current() - if err == nil && user.Uid != "" { - sockPath = fmt.Sprintf("crush-%s.sock", user.Uid) +// DefaultHost returns the default server host. +func DefaultHost() string { + sock := "crush.sock" + usr, err := user.Current() + if err == nil && usr.Uid != "" { + sock = fmt.Sprintf("crush-%s.sock", usr.Uid) } if runtime.GOOS == "windows" { - return fmt.Sprintf(`\\.\pipe\%s`, sockPath) + return fmt.Sprintf("npipe:////./pipe/%s", sock) } - return fmt.Sprintf("/tmp/%s", sockPath) + return fmt.Sprintf("unix:///tmp/%s", sock) } // Server represents a Crush server instance bound to a specific address. @@ -87,20 +86,17 @@ func (s *Server) SetLogger(logger *slog.Logger) { // DefaultServer returns a new [Server] instance with the default address. func DefaultServer(cfg *config.Config) *Server { - return NewServer(cfg, "unix", DefaultAddr()) + proto, addr, ok := strings.Cut(DefaultHost(), "://") + if !ok { + panic("invalid default host") + } + return NewServer(cfg, proto, addr) } // NewServer is a helper to create a new [Server] instance with the given // address. On Windows, if the address is not a "tcp" address, it will be // converted to a named pipe format. func NewServer(cfg *config.Config, network, address string) *Server { - if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") && - !strings.HasPrefix(address, `\\.\pipe\`) { - // On Windows, convert to named pipe format if not TCP - // (e.g., "mypipe" -> "\\.\pipe\mypipe") - address = fmt.Sprintf(`\\.\pipe\%s`, address) - } - s := new(Server) s.Addr = address s.cfg = cfg @@ -157,7 +153,11 @@ func (s *Server) ListenAndServe() error { if s.ln != nil { return fmt.Errorf("server already started") } - ln, err := listen("unix", s.Addr) + proto := "unix" + if runtime.GOOS == "windows" { + proto = "npipe" + } + ln, err := listen(proto, s.Addr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", s.Addr, err) }