server.go

  1package server
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"net"
  8	"net/http"
  9	"os/user"
 10	"runtime"
 11	"strings"
 12
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16)
 17
 18// ErrServerClosed is returned when the server is closed.
 19var ErrServerClosed = fmt.Errorf("server closed")
 20
 21// InstanceState represents the state of a running [app.App] instance.
 22type InstanceState uint8
 23
 24const (
 25	// InstanceStateCreated indicates that the instance has been created but not yet started.
 26	InstanceStateCreated InstanceState = iota
 27	// InstanceStateStarted indicates that the instance is currently running.
 28	InstanceStateStarted
 29	// InstanceStateStopped indicates that the instance has been stopped.
 30	InstanceStateStopped
 31)
 32
 33// Instance represents a running [app.App] instance with its associated
 34// resources and state.
 35type Instance struct {
 36	*app.App
 37	State InstanceState
 38	ln    net.Listener
 39	cfg   *config.Config
 40	id    string
 41	path  string
 42}
 43
 44// ID returns the unique identifier of the instance.
 45func (i *Instance) ID() string {
 46	return i.id
 47}
 48
 49// Path returns the filesystem path associated with the instance.
 50func (i *Instance) Path() string {
 51	return i.path
 52}
 53
 54// DefaultAddr returns the default address path for the Crush server based on
 55// the operating system.
 56func DefaultAddr() string {
 57	sockPath := "crush.sock"
 58	user, err := user.Current()
 59	if err == nil && user.Uid != "" {
 60		sockPath = fmt.Sprintf("crush-%s.sock", user.Uid)
 61	}
 62	if runtime.GOOS == "windows" {
 63		return fmt.Sprintf(`\\.\pipe\%s`, sockPath)
 64	}
 65	return fmt.Sprintf("/tmp/%s", sockPath)
 66}
 67
 68// Server represents a Crush server instance bound to a specific address.
 69type Server struct {
 70	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 71	Addr string
 72
 73	h   *http.Server
 74	ln  net.Listener
 75	ctx context.Context
 76
 77	// instances is a map of running applications managed by the server.
 78	instances *csync.Map[string, *Instance]
 79	cfg       *config.Config
 80	logger    *slog.Logger
 81}
 82
 83// SetLogger sets the logger for the server.
 84func (s *Server) SetLogger(logger *slog.Logger) {
 85	s.logger = logger
 86}
 87
 88// DefaultServer returns a new [Server] instance with the default address.
 89func DefaultServer(cfg *config.Config) *Server {
 90	return NewServer(cfg, "unix", DefaultAddr())
 91}
 92
 93// NewServer is a helper to create a new [Server] instance with the given
 94// address. On Windows, if the address is not a "tcp" address, it will be
 95// converted to a named pipe format.
 96func NewServer(cfg *config.Config, network, address string) *Server {
 97	if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
 98		!strings.HasPrefix(address, `\\.\pipe\`) {
 99		// On Windows, convert to named pipe format if not TCP
100		// (e.g., "mypipe" -> "\\.\pipe\mypipe")
101		address = fmt.Sprintf(`\\.\pipe\%s`, address)
102	}
103
104	s := new(Server)
105	s.Addr = address
106	s.cfg = cfg
107	s.instances = csync.NewMap[string, *Instance]()
108	s.ctx = context.Background()
109
110	var p http.Protocols
111	p.SetHTTP1(true)
112	p.SetUnencryptedHTTP2(true)
113	c := &controllerV1{Server: s}
114	mux := http.NewServeMux()
115	mux.HandleFunc("GET /v1/health", c.handleGetHealth)
116	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
117	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
118	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
119	mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances)
120	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
121	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
122	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
123	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
124	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
125	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
126	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
127	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
128	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
129	mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
130	mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
131	mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
132	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
133	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
134	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
135	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
136	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
137	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
138	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
139	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
140	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
141	s.h = &http.Server{
142		Protocols: &p,
143		Handler:   s.loggingHandler(mux),
144	}
145	return s
146}
147
148// Serve accepts incoming connections on the listener.
149func (s *Server) Serve(ln net.Listener) error {
150	return s.h.Serve(ln)
151}
152
153// ListenAndServe starts the server and begins accepting connections.
154func (s *Server) ListenAndServe() error {
155	if s.ln != nil {
156		return fmt.Errorf("server already started")
157	}
158	ln, err := listen("unix", s.Addr)
159	if err != nil {
160		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
161	}
162	return s.Serve(ln)
163}
164
165func (s *Server) closeListener() {
166	if s.ln != nil {
167		s.ln.Close()
168		s.ln = nil
169	}
170}
171
172// Close force close all listeners and connections.
173func (s *Server) Close() error {
174	defer func() { s.closeListener() }()
175	return s.h.Close()
176}
177
178// Shutdown gracefully shuts down the server without interrupting active
179// connections. It stops accepting new connections and waits for existing
180// connections to finish.
181func (s *Server) Shutdown(ctx context.Context) error {
182	defer func() { s.closeListener() }()
183	return s.h.Shutdown(ctx)
184}
185
186func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
187	if s.logger != nil {
188		s.logger.With(
189			slog.String("method", r.Method),
190			slog.String("url", r.URL.String()),
191			slog.String("remote_addr", r.RemoteAddr),
192		).Debug(msg, args...)
193	}
194}
195
196func (s *Server) logError(r *http.Request, msg string, args ...any) {
197	if s.logger != nil {
198		s.logger.With(
199			slog.String("method", r.Method),
200			slog.String("url", r.URL.String()),
201			slog.String("remote_addr", r.RemoteAddr),
202		).Error(msg, args...)
203	}
204}