Detailed changes
@@ -7,6 +7,8 @@ import (
"net"
"net/http"
"path/filepath"
+ "runtime"
+ "strings"
"time"
"github.com/charmbracelet/crush/internal/config"
@@ -23,7 +25,11 @@ type Client struct {
// DefaultClient creates a new [Client] connected to the default server address.
func DefaultClient(path string) (*Client, error) {
- return NewClient(path, "unix", server.DefaultAddr())
+ proto, addr, ok := strings.Cut(server.DefaultHost(), "://")
+ if !ok {
+ return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS)
+ }
+ return NewClient(path, proto, addr)
}
// NewClient creates a new [Client] connected to the server at the given
@@ -34,13 +40,7 @@ func NewClient(path, network, address string) (*Client, error) {
p.SetUnencryptedHTTP2(true)
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Protocols = &p
- tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
- d := net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }
- return d.DialContext(ctx, network, address)
- }
+ tr.DialContext = dialer
h := &http.Client{
Transport: tr,
Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams
@@ -125,3 +125,18 @@ func (c *Client) ShutdownServer() error {
}
return nil
}
+
+func dialer(ctx context.Context, network, address string) (net.Conn, error) {
+ switch network {
+ case "npipe":
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ return dialPipeContext(ctx, address)
+ default:
+ d := net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+ return d.DialContext(ctx, network, address)
+ }
+}
@@ -0,0 +1,14 @@
+//go:build !windows
+// +build !windows
+
+package client
+
+import (
+ "context"
+ "net"
+ "syscall"
+)
+
+func dialPipeContext(context.Context, string) (net.Conn, error) {
+ return nil, syscall.EAFNOSUPPORT
+}
@@ -0,0 +1,15 @@
+//go:build windows
+// +build windows
+
+package client
+
+import (
+ "context"
+ "net"
+
+ "github.com/Microsoft/go-winio"
+)
+
+func dialPipeContext(ctx context.Context, address string) (net.Conn, error) {
+ return winio.DialPipeContext(ctx, address)
+}
@@ -36,7 +36,7 @@ func init() {
rootCmd.Flags().BoolP("help", "h", false, "Help")
rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
- rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultAddr(), "Connect to a specific crush server host (for advanced users)")
+ rootCmd.Flags().StringVar(&clientHost, "host", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
rootCmd.AddCommand(runCmd)
rootCmd.AddCommand(updateProvidersCmd)
@@ -87,6 +87,6 @@ var serverCmd = &cobra.Command{
}
func init() {
- serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultAddr(), "Server host (TCP or Unix socket)")
+ serverCmd.Flags().StringVar(&serverHost, "host", server.DefaultHost(), "Server host (TCP or Unix socket)")
rootCmd.AddCommand(serverCmd)
}
@@ -5,14 +5,19 @@ package server
import (
"net"
- "strings"
"github.com/Microsoft/go-winio"
)
func listen(network, address string) (net.Listener, error) {
- if !strings.HasPrefix(address, "tcp") {
- return winio.ListenPipe(address, nil)
+ switch network {
+ case "npipe":
+ return winio.ListenPipe(address, &winio.PipeConfig{
+ MessageMode: true,
+ InputBufferSize: 65536,
+ OutputBufferSize: 65536,
+ })
+ default:
+ return net.Listen(network, address)
}
- return net.Listen(network, address)
}
@@ -51,18 +51,17 @@ func (i *Instance) Path() string {
return i.path
}
-// DefaultAddr returns the default address path for the Crush server based on
-// the operating system.
-func DefaultAddr() string {
- sockPath := "crush.sock"
- user, err := user.Current()
- if err == nil && user.Uid != "" {
- sockPath = fmt.Sprintf("crush-%s.sock", user.Uid)
+// DefaultHost returns the default server host.
+func DefaultHost() string {
+ sock := "crush.sock"
+ usr, err := user.Current()
+ if err == nil && usr.Uid != "" {
+ sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
}
if runtime.GOOS == "windows" {
- return fmt.Sprintf(`\\.\pipe\%s`, sockPath)
+ return fmt.Sprintf("npipe:////./pipe/%s", sock)
}
- return fmt.Sprintf("/tmp/%s", sockPath)
+ return fmt.Sprintf("unix:///tmp/%s", sock)
}
// Server represents a Crush server instance bound to a specific address.
@@ -87,20 +86,17 @@ func (s *Server) SetLogger(logger *slog.Logger) {
// DefaultServer returns a new [Server] instance with the default address.
func DefaultServer(cfg *config.Config) *Server {
- return NewServer(cfg, "unix", DefaultAddr())
+ proto, addr, ok := strings.Cut(DefaultHost(), "://")
+ if !ok {
+ panic("invalid default host")
+ }
+ return NewServer(cfg, proto, addr)
}
// NewServer is a helper to create a new [Server] instance with the given
// address. On Windows, if the address is not a "tcp" address, it will be
// converted to a named pipe format.
func NewServer(cfg *config.Config, network, address string) *Server {
- if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
- !strings.HasPrefix(address, `\\.\pipe\`) {
- // On Windows, convert to named pipe format if not TCP
- // (e.g., "mypipe" -> "\\.\pipe\mypipe")
- address = fmt.Sprintf(`\\.\pipe\%s`, address)
- }
-
s := new(Server)
s.Addr = address
s.cfg = cfg
@@ -157,7 +153,11 @@ func (s *Server) ListenAndServe() error {
if s.ln != nil {
return fmt.Errorf("server already started")
}
- ln, err := listen("unix", s.Addr)
+ proto := "unix"
+ if runtime.GOOS == "windows" {
+ proto = "npipe"
+ }
+ ln, err := listen(proto, s.Addr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
}