server.go

  1package server
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os/user"
	"runtime"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
)
 18
 19// ErrServerClosed is returned when the server is closed.
 20var ErrServerClosed = http.ErrServerClosed
 21
 22// InstanceState represents the state of a running [app.App] instance.
 23type InstanceState uint8
 24
 25const (
 26	// InstanceStateCreated indicates that the instance has been created but not yet started.
 27	InstanceStateCreated InstanceState = iota
 28	// InstanceStateStarted indicates that the instance is currently running.
 29	InstanceStateStarted
 30	// InstanceStateStopped indicates that the instance has been stopped.
 31	InstanceStateStopped
 32)
 33
 34// Instance represents a running [app.App] instance with its associated
 35// resources and state.
 36type Instance struct {
 37	*app.App
 38	State InstanceState
 39	ln    net.Listener
 40	cfg   *config.Config
 41	id    string
 42	path  string
 43}
 44
 45// ID returns the unique identifier of the instance.
 46func (i *Instance) ID() string {
 47	return i.id
 48}
 49
 50// Path returns the filesystem path associated with the instance.
 51func (i *Instance) Path() string {
 52	return i.path
 53}
 54
 55// ParseHostURL parses a host URL into a [url.URL].
 56func ParseHostURL(host string) (*url.URL, error) {
 57	proto, addr, ok := strings.Cut(host, "://")
 58	if !ok {
 59		return nil, fmt.Errorf("invalid host format: %s", host)
 60	}
 61
 62	var basePath string
 63	if proto == "tcp" {
 64		parsed, err := url.Parse("tcp://" + addr)
 65		if err != nil {
 66			return nil, fmt.Errorf("invalid tcp address: %v", err)
 67		}
 68		addr = parsed.Host
 69		basePath = parsed.Path
 70	}
 71	return &url.URL{
 72		Scheme: proto,
 73		Host:   addr,
 74		Path:   basePath,
 75	}, nil
 76}
 77
 78// DefaultHost returns the default server host.
 79func DefaultHost() string {
 80	sock := "crush.sock"
 81	usr, err := user.Current()
 82	if err == nil && usr.Uid != "" {
 83		sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
 84	}
 85	if runtime.GOOS == "windows" {
 86		return fmt.Sprintf("npipe:////./pipe/%s", sock)
 87	}
 88	return fmt.Sprintf("unix:///tmp/%s", sock)
 89}
 90
 91// Server represents a Crush server instance bound to a specific address.
 92type Server struct {
 93	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 94	Addr    string
 95	network string
 96
 97	h   *http.Server
 98	ln  net.Listener
 99	ctx context.Context
100
101	// instances is a map of running applications managed by the server.
102	instances *csync.Map[string, *Instance]
103	cfg       *config.Config
104	logger    *slog.Logger
105}
106
107// SetLogger sets the logger for the server.
108func (s *Server) SetLogger(logger *slog.Logger) {
109	s.logger = logger
110}
111
112// DefaultServer returns a new [Server] instance with the default address.
113func DefaultServer(cfg *config.Config) *Server {
114	hostURL, err := ParseHostURL(DefaultHost())
115	if err != nil {
116		panic("invalid default host")
117	}
118	return NewServer(cfg, hostURL.Scheme, hostURL.Host)
119}
120
121// NewServer is a helper to create a new [Server] instance with the given
122// address. On Windows, if the address is not a "tcp" address, it will be
123// converted to a named pipe format.
124func NewServer(cfg *config.Config, network, address string) *Server {
125	s := new(Server)
126	s.Addr = address
127	s.network = network
128	s.cfg = cfg
129	s.instances = csync.NewMap[string, *Instance]()
130	s.ctx = context.Background()
131
132	var p http.Protocols
133	p.SetHTTP1(true)
134	p.SetUnencryptedHTTP2(true)
135	c := &controllerV1{Server: s}
136	mux := http.NewServeMux()
137	mux.HandleFunc("GET /v1/health", c.handleGetHealth)
138	mux.HandleFunc("GET /v1/version", c.handleGetVersion)
139	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
140	mux.HandleFunc("POST /v1/control", c.handlePostControl)
141	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
142	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
143	mux.HandleFunc("DELETE /v1/instances/{id}", c.handleDeleteInstances)
144	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
145	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
146	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
147	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
148	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
149	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
150	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
151	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
152	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
153	mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
154	mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
155	mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
156	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
157	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
158	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
159	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
160	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
161	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
162	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
163	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
164	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
165	s.h = &http.Server{
166		Protocols: &p,
167		Handler:   s.loggingHandler(mux),
168	}
169	if network == "tcp" {
170		s.h.Addr = address
171	}
172	return s
173}
174
175// Serve accepts incoming connections on the listener.
176func (s *Server) Serve(ln net.Listener) error {
177	return s.h.Serve(ln)
178}
179
180// ListenAndServe starts the server and begins accepting connections.
181func (s *Server) ListenAndServe() error {
182	if s.ln != nil {
183		return fmt.Errorf("server already started")
184	}
185	ln, err := listen(s.network, s.Addr)
186	if err != nil {
187		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
188	}
189	return s.Serve(ln)
190}
191
192func (s *Server) closeListener() {
193	if s.ln != nil {
194		s.ln.Close()
195		s.ln = nil
196	}
197}
198
199// Close force close all listeners and connections.
200func (s *Server) Close() error {
201	defer func() { s.closeListener() }()
202	return s.h.Close()
203}
204
205// Shutdown gracefully shuts down the server without interrupting active
206// connections. It stops accepting new connections and waits for existing
207// connections to finish.
208func (s *Server) Shutdown(ctx context.Context) error {
209	defer func() { s.closeListener() }()
210	return s.h.Shutdown(ctx)
211}
212
213func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
214	if s.logger != nil {
215		s.logger.With(
216			slog.String("method", r.Method),
217			slog.String("url", r.URL.String()),
218			slog.String("remote_addr", r.RemoteAddr),
219		).Debug(msg, args...)
220	}
221}
222
223func (s *Server) logError(r *http.Request, msg string, args ...any) {
224	if s.logger != nil {
225		s.logger.With(
226			slog.String("method", r.Method),
227			slog.String("url", r.URL.String()),
228			slog.String("remote_addr", r.RemoteAddr),
229		).Error(msg, args...)
230	}
231}