server.go

  1package server
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os/user"
	"runtime"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
)
 18
 19// ErrServerClosed is returned when the server is closed.
 20var ErrServerClosed = http.ErrServerClosed
 21
 22// Instance represents a running [app.App] instance with its associated
 23// resources and state.
 24type Instance struct {
 25	*app.App
 26	ln   net.Listener
 27	cfg  *config.Config
 28	id   string
 29	path string
 30	env  []string
 31}
 32
 33// ParseHostURL parses a host URL into a [url.URL].
 34func ParseHostURL(host string) (*url.URL, error) {
 35	proto, addr, ok := strings.Cut(host, "://")
 36	if !ok {
 37		return nil, fmt.Errorf("invalid host format: %s", host)
 38	}
 39
 40	var basePath string
 41	if proto == "tcp" {
 42		parsed, err := url.Parse("tcp://" + addr)
 43		if err != nil {
 44			return nil, fmt.Errorf("invalid tcp address: %v", err)
 45		}
 46		addr = parsed.Host
 47		basePath = parsed.Path
 48	}
 49	return &url.URL{
 50		Scheme: proto,
 51		Host:   addr,
 52		Path:   basePath,
 53	}, nil
 54}
 55
 56// DefaultHost returns the default server host.
 57func DefaultHost() string {
 58	sock := "crush.sock"
 59	usr, err := user.Current()
 60	if err == nil && usr.Uid != "" {
 61		sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
 62	}
 63	if runtime.GOOS == "windows" {
 64		return fmt.Sprintf("npipe:////./pipe/%s", sock)
 65	}
 66	return fmt.Sprintf("unix:///tmp/%s", sock)
 67}
 68
 69// Server represents a Crush server instance bound to a specific address.
 70type Server struct {
 71	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 72	Addr    string
 73	network string
 74
 75	h   *http.Server
 76	ln  net.Listener
 77	ctx context.Context
 78
 79	// instances is a map of running applications managed by the server.
 80	instances *csync.Map[string, *Instance]
 81	cfg       *config.Config
 82	logger    *slog.Logger
 83}
 84
 85// SetLogger sets the logger for the server.
 86func (s *Server) SetLogger(logger *slog.Logger) {
 87	s.logger = logger
 88}
 89
 90// DefaultServer returns a new [Server] instance with the default address.
 91func DefaultServer(cfg *config.Config) *Server {
 92	hostURL, err := ParseHostURL(DefaultHost())
 93	if err != nil {
 94		panic("invalid default host")
 95	}
 96	return NewServer(cfg, hostURL.Scheme, hostURL.Host)
 97}
 98
 99// NewServer is a helper to create a new [Server] instance with the given
100// address. On Windows, if the address is not a "tcp" address, it will be
101// converted to a named pipe format.
102func NewServer(cfg *config.Config, network, address string) *Server {
103	s := new(Server)
104	s.Addr = address
105	s.network = network
106	s.cfg = cfg
107	s.instances = csync.NewMap[string, *Instance]()
108	s.ctx = context.Background()
109
110	var p http.Protocols
111	p.SetHTTP1(true)
112	p.SetUnencryptedHTTP2(true)
113	c := &controllerV1{Server: s}
114	mux := http.NewServeMux()
115	mux.HandleFunc("GET /v1/health", c.handleGetHealth)
116	mux.HandleFunc("GET /v1/version", c.handleGetVersion)
117	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
118	mux.HandleFunc("POST /v1/control", c.handlePostControl)
119	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
120	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
121	mux.HandleFunc("DELETE /v1/instances/{id}", c.handleDeleteInstances)
122	mux.HandleFunc("GET /v1/instances/{id}", c.handleGetInstance)
123	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
124	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
125	mux.HandleFunc("GET /v1/instances/{id}/providers", c.handleGetInstanceProviders)
126	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
127	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
128	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
129	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
130	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
131	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
132	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
133	mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
134	mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
135	mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
136	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
137	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
138	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
139	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
140	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
141	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
142	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
143	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
144	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
145	s.h = &http.Server{
146		Protocols: &p,
147		Handler:   s.loggingHandler(mux),
148	}
149	if network == "tcp" {
150		s.h.Addr = address
151	}
152	return s
153}
154
155// Serve accepts incoming connections on the listener.
156func (s *Server) Serve(ln net.Listener) error {
157	return s.h.Serve(ln)
158}
159
160// ListenAndServe starts the server and begins accepting connections.
161func (s *Server) ListenAndServe() error {
162	if s.ln != nil {
163		return fmt.Errorf("server already started")
164	}
165	ln, err := listen(s.network, s.Addr)
166	if err != nil {
167		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
168	}
169	return s.Serve(ln)
170}
171
172func (s *Server) closeListener() {
173	if s.ln != nil {
174		s.ln.Close()
175		s.ln = nil
176	}
177}
178
179// Close force close all listeners and connections.
180func (s *Server) Close() error {
181	defer func() { s.closeListener() }()
182	return s.h.Close()
183}
184
185// Shutdown gracefully shuts down the server without interrupting active
186// connections. It stops accepting new connections and waits for existing
187// connections to finish.
188func (s *Server) Shutdown(ctx context.Context) error {
189	defer func() { s.closeListener() }()
190	return s.h.Shutdown(ctx)
191}
192
193func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
194	if s.logger != nil {
195		s.logger.With(
196			slog.String("method", r.Method),
197			slog.String("url", r.URL.String()),
198			slog.String("remote_addr", r.RemoteAddr),
199		).Debug(msg, args...)
200	}
201}
202
203func (s *Server) logError(r *http.Request, msg string, args ...any) {
204	if s.logger != nil {
205		s.logger.With(
206			slog.String("method", r.Method),
207			slog.String("url", r.URL.String()),
208			slog.String("remote_addr", r.RemoteAddr),
209		).Error(msg, args...)
210	}
211}