server.go

  1package server
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"net"
  8	"net/http"
  9	"os"
 10	"os/user"
 11	"path/filepath"
 12	"runtime"
 13	"strings"
 14
 15	"github.com/charmbracelet/crush/internal/app"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18)
 19
// ErrServerClosed is returned when the server is closed.
//
// It is a package-level sentinel value whose message is "server closed".
// NOTE(review): nothing in the visible part of this file returns it — confirm
// which code paths produce it before relying on comparisons against it.
var ErrServerClosed = fmt.Errorf("server closed")
 22
// InstanceState represents the state of a running [app.App] instance.
//
// It is stored as a uint8 and its values are numbered with iota in
// declaration order, so the zero value is [InstanceStateCreated].
type InstanceState uint8

const (
	// InstanceStateCreated indicates that the instance has been created but not yet started.
	// It is the zero value of [InstanceState].
	InstanceStateCreated InstanceState = iota
	// InstanceStateStarted indicates that the instance is currently running.
	InstanceStateStarted
	// InstanceStateStopped indicates that the instance has been stopped.
	InstanceStateStopped
)
 34
// Instance represents a running [app.App] instance with its associated
// resources and state.
//
// The *app.App is embedded, so its exported methods and fields are promoted
// onto Instance. Only State is exported; id and path are exposed read-only
// through the ID and Path methods.
//
// NOTE(review): ln and cfg are not referenced in the visible part of this
// file — confirm where they are populated and how they are used.
type Instance struct {
	*app.App
	State InstanceState  // current lifecycle state of the instance
	ln    net.Listener   // listener associated with the instance
	cfg   *config.Config // configuration associated with the instance
	id    string         // identifier returned by ID
	path  string         // filesystem path returned by Path
}
 45
 46// ID returns the unique identifier of the instance.
 47func (i *Instance) ID() string {
 48	return i.id
 49}
 50
 51// Path returns the filesystem path associated with the instance.
 52func (i *Instance) Path() string {
 53	return i.path
 54}
 55
 56// DefaultAddr returns the default address path for the Crush server based on
 57// the operating system.
 58func DefaultAddr() string {
 59	sockPath := "crush.sock"
 60	user, err := user.Current()
 61	if err == nil && user.Uid != "" {
 62		sockPath = fmt.Sprintf("crush-%s.sock", user.Uid)
 63	}
 64	if runtime.GOOS == "windows" {
 65		return fmt.Sprintf(`\\.\pipe\%s`, sockPath)
 66	}
 67	return filepath.Join(os.TempDir(), sockPath)
 68}
 69
// Server represents a Crush server instance bound to a specific address.
//
// Values are normally built with NewServer or DefaultServer, which fill in
// the HTTP server, route table, instance map and context.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr string

	h   *http.Server    // underlying HTTP server; Serve, Close and Shutdown delegate to it
	ln  net.Listener    // listener checked by ListenAndServe and closed/cleared by closeListener
	ctx context.Context // set to context.Background() by NewServer

	// instances is a map of running applications managed by the server.
	instances *csync.Map[string, *Instance]
	cfg       *config.Config // configuration handed to NewServer
	logger    *slog.Logger   // optional; set via SetLogger, and logDebug/logError do nothing while it is nil
}
 84
 85// SetLogger sets the logger for the server.
 86func (s *Server) SetLogger(logger *slog.Logger) {
 87	s.logger = logger
 88}
 89
 90// DefaultServer returns a new [Server] instance with the default address.
 91func DefaultServer(cfg *config.Config) *Server {
 92	return NewServer(cfg, "unix", DefaultAddr())
 93}
 94
 95// NewServer is a helper to create a new [Server] instance with the given
 96// address. On Windows, if the address is not a "tcp" address, it will be
 97// converted to a named pipe format.
 98func NewServer(cfg *config.Config, network, address string) *Server {
 99	if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
100		!strings.HasPrefix(address, `\\.\pipe\`) {
101		// On Windows, convert to named pipe format if not TCP
102		// (e.g., "mypipe" -> "\\.\pipe\mypipe")
103		address = fmt.Sprintf(`\\.\pipe\%s`, address)
104	}
105
106	s := new(Server)
107	s.Addr = address
108	s.cfg = cfg
109	s.instances = csync.NewMap[string, *Instance]()
110	s.ctx = context.Background()
111
112	var p http.Protocols
113	p.SetHTTP1(true)
114	p.SetUnencryptedHTTP2(true)
115	c := &controllerV1{Server: s}
116	mux := http.NewServeMux()
117	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
118	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
119	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
120	mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances)
121	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
122	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
123	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
124	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
125	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
126	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
127	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
128	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
129	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
130	mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
131	mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
132	mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
133	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
134	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
135	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
136	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
137	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
138	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
139	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
140	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
141	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
142	s.h = &http.Server{
143		Protocols: &p,
144		Handler:   s.loggingHandler(mux),
145	}
146	return s
147}
148
149// Serve accepts incoming connections on the listener.
150func (s *Server) Serve(ln net.Listener) error {
151	return s.h.Serve(ln)
152}
153
154// ListenAndServe starts the server and begins accepting connections.
155func (s *Server) ListenAndServe() error {
156	if s.ln != nil {
157		return fmt.Errorf("server already started")
158	}
159	ln, err := listen("unix", s.Addr)
160	if err != nil {
161		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
162	}
163	return s.Serve(ln)
164}
165
166func (s *Server) closeListener() {
167	if s.ln != nil {
168		s.ln.Close()
169		s.ln = nil
170	}
171}
172
173// Close force close all listeners and connections.
174func (s *Server) Close() error {
175	defer func() { s.closeListener() }()
176	return s.h.Close()
177}
178
179// Shutdown gracefully shuts down the server without interrupting active
180// connections. It stops accepting new connections and waits for existing
181// connections to finish.
182func (s *Server) Shutdown(ctx context.Context) error {
183	defer func() { s.closeListener() }()
184	return s.h.Shutdown(ctx)
185}
186
187func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
188	if s.logger != nil {
189		s.logger.With(
190			slog.String("method", r.Method),
191			slog.String("url", r.URL.String()),
192			slog.String("remote_addr", r.RemoteAddr),
193		).Debug(msg, args...)
194	}
195}
196
197func (s *Server) logError(r *http.Request, msg string, args ...any) {
198	if s.logger != nil {
199		s.logger.With(
200			slog.String("method", r.Method),
201			slog.String("url", r.URL.String()),
202			slog.String("remote_addr", r.RemoteAddr),
203		).Error(msg, args...)
204	}
205}