Detailed changes
@@ -7,8 +7,6 @@ import (
"net"
"net/http"
"path/filepath"
- "runtime"
- "strings"
"time"
"github.com/charmbracelet/crush/internal/config"
@@ -16,39 +14,49 @@ import (
"github.com/charmbracelet/crush/internal/server"
)
+// DummyHost is used to satisfy the http.Client's requirement for a URL.
+const DummyHost = "api.crush.localhost"
+
// Client represents an RPC client connected to a Crush server.
type Client struct {
- h *http.Client
- id string
- path string
+ h *http.Client
+ id string
+ path string
+ proto string
+ addr string
}
// DefaultClient creates a new [Client] connected to the default server address.
func DefaultClient(path string) (*Client, error) {
- proto, addr, ok := strings.Cut(server.DefaultHost(), "://")
- if !ok {
- return nil, fmt.Errorf("failed to determine default server address for platform %s", runtime.GOOS)
+ host, err := server.ParseHostURL(server.DefaultHost())
+ if err != nil {
+ return nil, err
}
- return NewClient(path, proto, addr)
+ return NewClient(path, host.Scheme, host.Host)
}
// NewClient creates a new [Client] connected to the server at the given
// network and address.
func NewClient(path, network, address string) (*Client, error) {
- var p http.Protocols
+ c := new(Client)
+ c.path = filepath.Clean(path)
+ c.proto = network
+ c.addr = address
+ p := &http.Protocols{}
p.SetHTTP1(true)
p.SetUnencryptedHTTP2(true)
tr := http.DefaultTransport.(*http.Transport).Clone()
- tr.Protocols = &p
- tr.DialContext = dialer
- h := &http.Client{
+ tr.Protocols = p
+ tr.DialContext = c.dialer
+ if c.proto == "npipe" || c.proto == "unix" {
+ // We don't need compression for local connections.
+ tr.DisableCompression = true
+ }
+ c.h = &http.Client{
Transport: tr,
Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams
}
- return &Client{
- h: h,
- path: filepath.Clean(path),
- }, nil
+ return c, nil
}
// ID returns the client's instance unique identifier.
@@ -126,17 +134,19 @@ func (c *Client) ShutdownServer() error {
return nil
}
-func dialer(ctx context.Context, network, address string) (net.Conn, error) {
- switch network {
+func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
+ d := net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+ switch c.proto {
case "npipe":
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- return dialPipeContext(ctx, address)
+ return dialPipeContext(ctx, c.addr)
+ case "unix":
+ return d.DialContext(ctx, "unix", c.addr)
default:
- d := net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }
return d.DialContext(ctx, network, address)
}
}
@@ -7,6 +7,7 @@ import (
"io"
"io/fs"
"log/slog"
+ "net/url"
"os"
"os/exec"
"path/filepath"
@@ -71,15 +72,62 @@ crush run "Explain the use of context in Go"
crush -y
`,
RunE: func(cmd *cobra.Command, args []string) error {
- if err := ensureServerRunning(cmd); err != nil {
- return err
+ hostURL, err := server.ParseHostURL(clientHost)
+ if err != nil {
+ return fmt.Errorf("invalid host URL: %v", err)
}
- c, err := setupApp(cmd)
+ switch hostURL.Scheme {
+ case "unix", "npipe":
+ _, err := os.Stat(hostURL.Host)
+ if err != nil && errors.Is(err, fs.ErrNotExist) {
+ slog.Info("Starting server...", "host", clientHost)
+ if err := startDetachedServer(cmd); err != nil {
+ return err
+ }
+ }
+
+ // Wait for the file to appear
+ for range 10 {
+ _, err = os.Stat(hostURL.Host)
+ if err == nil {
+ break
+ }
+ select {
+ case <-cmd.Context().Done():
+ return cmd.Context().Err()
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("failed to connect to crush server: %v", err)
+ }
+
+ default:
+ // TODO: implement TCP support
+ }
+
+ c, err := setupApp(cmd, hostURL)
if err != nil {
return err
}
+ tries := 5
+ for i := range tries {
+ err := c.Health()
+ if err == nil {
+ break
+ }
+ select {
+ case <-cmd.Context().Done():
+ return cmd.Context().Err()
+ case <-time.After(100 * time.Millisecond):
+ }
+ if i == tries-1 {
+ return fmt.Errorf("failed to connect to crush server after %d attempts: %v", tries, err)
+ }
+ }
+
m, err := tui.New(c)
if err != nil {
return fmt.Errorf("failed to create TUI model: %v", err)
@@ -145,7 +193,7 @@ func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) {
// setupApp handles the common setup logic for both interactive and non-interactive modes.
// It returns the app instance, config, cleanup function, and any error.
-func setupApp(cmd *cobra.Command) (*client.Client, error) {
+func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, error) {
debug, _ := cmd.Flags().GetBool("debug")
yolo, _ := cmd.Flags().GetBool("yolo")
dataDir, _ := cmd.Flags().GetString("data-dir")
@@ -156,7 +204,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) {
return nil, err
}
- c, err := client.NewClient(cwd, "unix", clientHost)
+ c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
if err != nil {
return nil, err
}
@@ -178,17 +226,7 @@ func setupApp(cmd *cobra.Command) (*client.Client, error) {
var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
-func ensureServerRunning(cmd *cobra.Command) error {
- stat, err := os.Stat(clientHost)
- if err == nil && stat.Mode()&os.ModeSocket == 0 {
- return fmt.Errorf("crush server socket path exists but is not a socket: %s", clientHost)
- } else if err == nil && stat.Mode()&os.ModeSocket != 0 {
- // Socket exists, assume server is running.
- return nil
- } else if err != nil && !errors.Is(err, fs.ErrNotExist) {
- return fmt.Errorf("failed to stat crush server socket: %v", err)
- }
-
+func startDetachedServer(cmd *cobra.Command) error {
// Start the server as a detached process if the socket does not exist.
exe, err := os.Executable()
if err != nil {
@@ -204,7 +242,7 @@ func ensureServerRunning(cmd *cobra.Command) error {
c := exec.CommandContext(cmd.Context(), exe, "server")
stdoutPath := filepath.Join(chDir, "stdout.log")
stderrPath := filepath.Join(chDir, "stderr.log")
- detachProcess(c, stdoutPath, stderrPath)
+ detachProcess(c)
stdout, err := os.Create(stdoutPath)
if err != nil {
@@ -228,23 +266,6 @@ func ensureServerRunning(cmd *cobra.Command) error {
return fmt.Errorf("failed to detach crush server process: %v", err)
}
- // Wait for the server to start and create the socket.
- for range 10 {
- stat, err := os.Stat(clientHost)
- if err == nil && stat.Mode()&os.ModeSocket != 0 {
- // Socket exists, server is running.
- return nil
- } else if err != nil && !errors.Is(err, fs.ErrNotExist) {
- return fmt.Errorf("failed to stat crush server socket: %v", err)
- }
- // Sleep for 100ms before checking again.
- select {
- case <-cmd.Context().Done():
- return fmt.Errorf("context cancelled while waiting for crush server to start")
- case <-time.After(100 * time.Millisecond):
- }
- }
-
return nil
}
@@ -8,7 +8,7 @@ import (
"syscall"
)
-func detachProcess(c *exec.Cmd, _, _ string) {
+func detachProcess(c *exec.Cmd) {
if c.SysProcAttr == nil {
c.SysProcAttr = &syscall.SysProcAttr{}
}
@@ -5,18 +5,14 @@ package cmd
import (
"os/exec"
+ "syscall"
+
+ "golang.org/x/sys/windows"
)
-func detachProcess(c *exec.Cmd, stdoutPath, stderrPath string) {
- argv1 := c.Args[0]
- c.Path = "cmd"
- c.Args = []string{
- "cmd",
- "/c",
- argv1,
- ">",
- stdoutPath,
- "2>",
- stderrPath,
+func detachProcess(c *exec.Cmd) {
+ if c.SysProcAttr == nil {
+ c.SysProcAttr = &syscall.SysProcAttr{}
}
+ c.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP | windows.DETACHED_PROCESS
}
@@ -5,6 +5,7 @@ import (
"log/slog"
"strings"
+ "github.com/charmbracelet/crush/internal/server"
"github.com/spf13/cobra"
)
@@ -25,8 +26,12 @@ crush run -q "Generate a README for this project"
`,
RunE: func(cmd *cobra.Command, args []string) error {
quiet, _ := cmd.Flags().GetBool("quiet")
+ hostURL, err := server.ParseHostURL(clientHost)
+ if err != nil {
+ return fmt.Errorf("invalid host URL: %v", err)
+ }
- c, err := setupApp(cmd)
+ c, err := setupApp(cmd, hostURL)
if err != nil {
return err
}
@@ -43,7 +43,12 @@ var serverCmd = &cobra.Command{
slog.SetLogLoggerLevel(slog.LevelDebug)
}
- srv := server.NewServer(cfg, "unix", serverHost)
+ hostURL, err := server.ParseHostURL(serverHost)
+ if err != nil {
+ return fmt.Errorf("invalid server host: %v", err)
+ }
+
+ srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host)
srv.SetLogger(slog.Default())
slog.Info("Starting Crush server...", "addr", srv.Addr)
@@ -0,0 +1,233 @@
+package gorust
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+)
+
+type Stage[T, R any] interface {
+ Process(ctx context.Context, input T) (R, error)
+ Name() string
+}
+
+type Pipeline[T any] struct {
+ stages []Stage[any, any]
+ mu sync.RWMutex
+ opts PipelineOptions
+}
+
+type PipelineOptions struct {
+ MaxConcurrency int
+ Timeout time.Duration
+ RetryAttempts int
+ RetryDelay time.Duration
+}
+
+type PipelineResult[T any] struct {
+ Output T
+ Error error
+ Stage string
+ Took time.Duration
+}
+
+func NewPipeline[T any](opts PipelineOptions) *Pipeline[T] {
+ if opts.MaxConcurrency <= 0 {
+ opts.MaxConcurrency = 10
+ }
+ if opts.Timeout <= 0 {
+ opts.Timeout = 30 * time.Second
+ }
+ if opts.RetryAttempts <= 0 {
+ opts.RetryAttempts = 3
+ }
+ if opts.RetryDelay <= 0 {
+ opts.RetryDelay = time.Second
+ }
+
+ return &Pipeline[T]{
+ stages: make([]Stage[any, any], 0),
+ opts: opts,
+ }
+}
+
+func (p *Pipeline[T]) AddStage(stage Stage[any, any]) *Pipeline[T] {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.stages = append(p.stages, stage)
+ return p
+}
+
+func (p *Pipeline[T]) Execute(ctx context.Context, input T) <-chan PipelineResult[any] {
+ results := make(chan PipelineResult[any], len(p.stages))
+
+ go func() {
+ defer close(results)
+
+ current := input
+ for _, stage := range p.stages {
+ select {
+ case <-ctx.Done():
+ results <- PipelineResult[any]{
+ Error: ctx.Err(),
+ Stage: stage.Name(),
+ }
+ return
+ default:
+ result := p.executeStageWithRetry(ctx, stage, current)
+ results <- result
+
+ if result.Error != nil {
+ return
+ }
+ current = result.Output
+ }
+ }
+ }()
+
+ return results
+}
+
+func (p *Pipeline[T]) executeStageWithRetry(ctx context.Context, stage Stage[any, any], input any) PipelineResult[any] {
+ var lastErr error
+ start := time.Now()
+
+ for attempt := 0; attempt < p.opts.RetryAttempts; attempt++ {
+ if attempt > 0 {
+ select {
+ case <-ctx.Done():
+ return PipelineResult[any]{
+ Error: ctx.Err(),
+ Stage: stage.Name(),
+ Took: time.Since(start),
+ }
+ case <-time.After(p.opts.RetryDelay):
+ }
+ }
+
+ stageCtx, cancel := context.WithTimeout(ctx, p.opts.Timeout)
+ output, err := stage.Process(stageCtx, input)
+ cancel()
+
+ if err == nil {
+ return PipelineResult[any]{
+ Output: output,
+ Stage: stage.Name(),
+ Took: time.Since(start),
+ }
+ }
+
+ lastErr = err
+ }
+
+ return PipelineResult[any]{
+ Error: fmt.Errorf("stage %s failed after %d attempts: %w", stage.Name(), p.opts.RetryAttempts, lastErr),
+ Stage: stage.Name(),
+ Took: time.Since(start),
+ }
+}
+
+type TransformStage[T, R any] struct {
+ name string
+ transform func(context.Context, T) (R, error)
+}
+
+func NewTransformStage[T, R any](name string, transform func(context.Context, T) (R, error)) *TransformStage[T, R] {
+ return &TransformStage[T, R]{
+ name: name,
+ transform: transform,
+ }
+}
+
+func (s *TransformStage[T, R]) Name() string {
+ return s.name
+}
+
+func (s *TransformStage[T, R]) Process(ctx context.Context, input T) (R, error) {
+ return s.transform(ctx, input)
+}
+
+type FilterStage[T any] struct {
+ name string
+ predicate func(context.Context, T) (bool, error)
+}
+
+func NewFilterStage[T any](name string, predicate func(context.Context, T) (bool, error)) *FilterStage[T] {
+ return &FilterStage[T]{
+ name: name,
+ predicate: predicate,
+ }
+}
+
+func (s *FilterStage[T]) Name() string {
+ return s.name
+}
+
+func (s *FilterStage[T]) Process(ctx context.Context, input T) (T, error) {
+ keep, err := s.predicate(ctx, input)
+ if err != nil {
+ var zero T
+ return zero, err
+ }
+
+ if !keep {
+ var zero T
+ return zero, fmt.Errorf("item filtered out")
+ }
+
+ return input, nil
+}
+
+type BatchProcessor[T, R any] struct {
+ name string
+ batchSize int
+ processor func(context.Context, []T) ([]R, error)
+ buffer []T
+ mu sync.Mutex
+}
+
+func NewBatchProcessor[T, R any](name string, batchSize int, processor func(context.Context, []T) ([]R, error)) *BatchProcessor[T, R] {
+ return &BatchProcessor[T, R]{
+ name: name,
+ batchSize: batchSize,
+ processor: processor,
+ buffer: make([]T, 0, batchSize),
+ }
+}
+
+func (b *BatchProcessor[T, R]) Name() string {
+ return b.name
+}
+
+func (b *BatchProcessor[T, R]) Process(ctx context.Context, input T) ([]R, error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ b.buffer = append(b.buffer, input)
+
+ if len(b.buffer) >= b.batchSize {
+ batch := make([]T, len(b.buffer))
+ copy(batch, b.buffer)
+ b.buffer = b.buffer[:0]
+
+ return b.processor(ctx, batch)
+ }
+
+ return nil, nil
+}
+
+func (b *BatchProcessor[T, R]) Flush(ctx context.Context) ([]R, error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if len(b.buffer) == 0 {
+ return nil, nil
+ }
+
+ batch := make([]T, len(b.buffer))
+ copy(batch, b.buffer)
+ b.buffer = b.buffer[:0]
+
+ return b.processor(ctx, batch)
+}
@@ -12,11 +12,12 @@ import (
func listen(network, address string) (net.Listener, error) {
switch network {
case "npipe":
- return winio.ListenPipe(address, &winio.PipeConfig{
+ cfg := &winio.PipeConfig{
MessageMode: true,
InputBufferSize: 65536,
OutputBufferSize: 65536,
- })
+ }
+ return winio.ListenPipe(address, cfg)
default:
return net.Listen(network, address)
}
@@ -6,6 +6,7 @@ import (
"log/slog"
"net"
"net/http"
+ "net/url"
"os/user"
"runtime"
"strings"
@@ -51,6 +52,29 @@ func (i *Instance) Path() string {
return i.path
}
+// ParseHostURL parses a host URL into a [url.URL].
+func ParseHostURL(host string) (*url.URL, error) {
+ proto, addr, ok := strings.Cut(host, "://")
+ if !ok {
+ return nil, fmt.Errorf("invalid host format: %s", host)
+ }
+
+ var basePath string
+ if proto == "tcp" {
+ parsed, err := url.Parse("tcp://" + addr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid tcp address: %v", err)
+ }
+ addr = parsed.Host
+ basePath = parsed.Path
+ }
+ return &url.URL{
+ Scheme: proto,
+ Host: addr,
+ Path: basePath,
+ }, nil
+}
+
// DefaultHost returns the default server host.
func DefaultHost() string {
sock := "crush.sock"
@@ -86,11 +110,11 @@ func (s *Server) SetLogger(logger *slog.Logger) {
// DefaultServer returns a new [Server] instance with the default address.
func DefaultServer(cfg *config.Config) *Server {
- proto, addr, ok := strings.Cut(DefaultHost(), "://")
- if !ok {
+ hostURL, err := ParseHostURL(DefaultHost())
+ if err != nil {
panic("invalid default host")
}
- return NewServer(cfg, proto, addr)
+ return NewServer(cfg, hostURL.Scheme, hostURL.Host)
}
// NewServer is a helper to create a new [Server] instance with the given