server.go

  1package server
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"net"
  8	"net/http"
  9	"net/url"
 10	"os"
 11	"os/user"
 12	"path/filepath"
 13	"runtime"
 14	"strings"
 15
 16	"github.com/charmbracelet/crush/internal/backend"
 17	"github.com/charmbracelet/crush/internal/config"
 18	_ "github.com/charmbracelet/crush/internal/swagger"
 19	httpswagger "github.com/swaggo/http-swagger/v2"
 20)
 21
 22// maxUnixSocketPathLen is the maximum length of a Unix domain socket
 23// path. The macOS sun_path field is 104 bytes; Linux allows 108. We
 24// use 104 so the resulting path is portable across both platforms.
 25const maxUnixSocketPathLen = 104
 26
 27// socketDir returns the directory used for the Crush Unix socket.
 28// It prefers $XDG_RUNTIME_DIR when set (systemd's per-user runtime
 29// directory on Linux), and otherwise falls back to [os.TempDir],
 30// which resolves to the per-user private $TMPDIR on macOS and to
 31// /tmp on Linux.
 32func socketDir() string {
 33	if dir := os.Getenv("XDG_RUNTIME_DIR"); dir != "" {
 34		return dir
 35	}
 36	return os.TempDir()
 37}
 38
 39// ErrServerClosed is returned when the server is closed.
 40var ErrServerClosed = http.ErrServerClosed
 41
 42// ParseHostURL parses a host URL into a [url.URL].
 43func ParseHostURL(host string) (*url.URL, error) {
 44	proto, addr, ok := strings.Cut(host, "://")
 45	if !ok {
 46		return nil, fmt.Errorf("invalid host format: %s", host)
 47	}
 48
 49	var basePath string
 50	if proto == "tcp" {
 51		parsed, err := url.Parse("tcp://" + addr)
 52		if err != nil {
 53			return nil, fmt.Errorf("invalid tcp address: %v", err)
 54		}
 55		addr = parsed.Host
 56		basePath = parsed.Path
 57	}
 58	return &url.URL{
 59		Scheme: proto,
 60		Host:   addr,
 61		Path:   basePath,
 62	}, nil
 63}
 64
 65// DefaultHost returns the default server host.
 66//
 67// On Windows the address is a named pipe under \\.\pipe\. On Unix
 68// platforms the socket lives in the per-user runtime directory
 69// returned by [socketDir] and is named crush-<uid>.sock, falling
 70// back to crush.sock when the current uid cannot be determined. If
 71// the composed path would exceed [maxUnixSocketPathLen] bytes (the
 72// macOS sun_path limit), we fall back to /tmp/crush-<uid>.sock so
 73// the socket remains bindable.
 74func DefaultHost() string {
 75	sock := "crush.sock"
 76	usr, err := user.Current()
 77	if err == nil && usr.Uid != "" {
 78		sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
 79	}
 80	if runtime.GOOS == "windows" {
 81		return fmt.Sprintf("npipe:////./pipe/%s", sock)
 82	}
 83	path := filepath.Join(socketDir(), sock)
 84	if len(path) > maxUnixSocketPathLen {
 85		path = filepath.Join("/tmp", sock)
 86	}
 87	return "unix://" + path
 88}
 89
 90// Server represents a Crush server bound to a specific address.
 91type Server struct {
 92	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 93	Addr    string
 94	network string
 95
 96	h  *http.Server
 97	ln net.Listener
 98
 99	backend *backend.Backend
100	logger  *slog.Logger
101}
102
103// SetLogger sets the logger for the server.
104func (s *Server) SetLogger(logger *slog.Logger) {
105	s.logger = logger
106}
107
108// DefaultServer returns a new [Server] with the default address.
109func DefaultServer(cfg *config.ConfigStore) *Server {
110	hostURL, err := ParseHostURL(DefaultHost())
111	if err != nil {
112		panic("invalid default host")
113	}
114	return NewServer(cfg, hostURL.Scheme, hostURL.Host)
115}
116
117// NewServer creates a new [Server] with the given network and address.
118func NewServer(cfg *config.ConfigStore, network, address string) *Server {
119	s := new(Server)
120	s.Addr = address
121	s.network = network
122
123	// The backend is created with a shutdown callback that triggers
124	// a graceful server shutdown (e.g. when the last workspace is
125	// removed).
126	s.backend = backend.New(context.Background(), cfg, func() {
127		go func() {
128			slog.Info("Shutting down server...")
129			if err := s.Shutdown(context.Background()); err != nil {
130				slog.Error("Failed to shutdown server", "error", err)
131			}
132		}()
133	})
134	s.installHandler()
135	if network == "tcp" {
136		s.h.Addr = address
137	}
138	return s
139}
140
141// installHandler builds the protocol/router around s.backend and
142// assigns the resulting http.Server to s.h. Extracted from
143// [NewServer] so test harnesses can wire a Server around a
144// pre-constructed backend.
145func (s *Server) installHandler() {
146	var p http.Protocols
147	p.SetHTTP1(true)
148	p.SetUnencryptedHTTP2(true)
149	c := &controllerV1{backend: s.backend, server: s}
150	mux := http.NewServeMux()
151	mux.HandleFunc("GET /v1/health", c.handleGetHealth)
152	mux.HandleFunc("GET /v1/version", c.handleGetVersion)
153	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
154	mux.HandleFunc("POST /v1/control", c.handlePostControl)
155	mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces)
156	mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces)
157	mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces)
158	mux.HandleFunc("POST /v1/workspaces/{id}/current-session", c.handlePostWorkspaceCurrentSession)
159	mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace)
160	mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig)
161	mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents)
162	mux.HandleFunc("GET /v1/workspaces/{id}/providers", c.handleGetWorkspaceProviders)
163	mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions)
164	mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions)
165	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
166	mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession)
167	mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession)
168	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
169	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
170	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages)
171	mux.HandleFunc("GET /v1/workspaces/{id}/messages/user", c.handleGetWorkspaceAllUserMessages)
172	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/filetracker/files", c.handleGetWorkspaceSessionFileTrackerFiles)
173	mux.HandleFunc("POST /v1/workspaces/{id}/filetracker/read", c.handlePostWorkspaceFileTrackerRead)
174	mux.HandleFunc("GET /v1/workspaces/{id}/filetracker/lastread", c.handleGetWorkspaceFileTrackerLastRead)
175	mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs)
176	mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics)
177	mux.HandleFunc("POST /v1/workspaces/{id}/lsps/start", c.handlePostWorkspaceLSPStart)
178	mux.HandleFunc("POST /v1/workspaces/{id}/lsps/stop", c.handlePostWorkspaceLSPStopAll)
179	mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip)
180	mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip)
181	mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant)
182	mux.HandleFunc("GET /v1/workspaces/{id}/agent", c.handleGetWorkspaceAgent)
183	mux.HandleFunc("POST /v1/workspaces/{id}/agent", c.handlePostWorkspaceAgent)
184	mux.HandleFunc("POST /v1/workspaces/{id}/agent/init", c.handlePostWorkspaceAgentInit)
185	mux.HandleFunc("POST /v1/workspaces/{id}/agent/update", c.handlePostWorkspaceAgentUpdate)
186	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession)
187	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel)
188	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued)
189	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/list", c.handleGetWorkspaceAgentSessionPromptList)
190	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear)
191	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handlePostWorkspaceAgentSessionSummarize)
192	mux.HandleFunc("GET /v1/workspaces/{id}/agent/default-small-model", c.handleGetWorkspaceAgentDefaultSmallModel)
193	mux.HandleFunc("POST /v1/workspaces/{id}/config/set", c.handlePostWorkspaceConfigSet)
194	mux.HandleFunc("POST /v1/workspaces/{id}/config/remove", c.handlePostWorkspaceConfigRemove)
195	mux.HandleFunc("POST /v1/workspaces/{id}/config/model", c.handlePostWorkspaceConfigModel)
196	mux.HandleFunc("POST /v1/workspaces/{id}/config/compact", c.handlePostWorkspaceConfigCompact)
197	mux.HandleFunc("POST /v1/workspaces/{id}/config/provider-key", c.handlePostWorkspaceConfigProviderKey)
198	mux.HandleFunc("POST /v1/workspaces/{id}/config/import-copilot", c.handlePostWorkspaceConfigImportCopilot)
199	mux.HandleFunc("POST /v1/workspaces/{id}/config/refresh-oauth", c.handlePostWorkspaceConfigRefreshOAuth)
200	mux.HandleFunc("GET /v1/workspaces/{id}/project/needs-init", c.handleGetWorkspaceProjectNeedsInit)
201	mux.HandleFunc("POST /v1/workspaces/{id}/project/init", c.handlePostWorkspaceProjectInit)
202	mux.HandleFunc("GET /v1/workspaces/{id}/project/init-prompt", c.handleGetWorkspaceProjectInitPrompt)
203	mux.HandleFunc("GET /v1/workspaces/{id}/skills", c.handleGetWorkspaceSkills)
204	mux.HandleFunc("POST /v1/workspaces/{id}/skills/read", c.handlePostWorkspaceSkillRead)
205	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-tools", c.handlePostWorkspaceMCPRefreshTools)
206	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/read-resource", c.handlePostWorkspaceMCPReadResource)
207	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/get-prompt", c.handlePostWorkspaceMCPGetPrompt)
208	mux.HandleFunc("GET /v1/workspaces/{id}/mcp/states", c.handleGetWorkspaceMCPStates)
209	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-prompts", c.handlePostWorkspaceMCPRefreshPrompts)
210	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-resources", c.handlePostWorkspaceMCPRefreshResources)
211	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/docker/enable", c.handlePostWorkspaceMCPEnableDocker)
212	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/docker/disable", c.handlePostWorkspaceMCPDisableDocker)
213	mux.Handle("/v1/docs/", httpswagger.WrapHandler)
214	s.h = &http.Server{
215		Protocols: &p,
216		Handler:   s.recoverHandler(s.loggingHandler(mux)),
217	}
218}
219
220// Handler returns the server's HTTP handler. Exposed so test harnesses
221// can wrap it in an httptest.Server without going through the
222// production listener setup.
223func (s *Server) Handler() http.Handler {
224	return s.h.Handler
225}
226
227// Serve accepts incoming connections on the listener.
228func (s *Server) Serve(ln net.Listener) error {
229	return s.h.Serve(ln)
230}
231
232// ListenAndServe starts the server and begins accepting connections.
233func (s *Server) ListenAndServe() error {
234	if s.ln != nil {
235		return fmt.Errorf("server already started")
236	}
237	ln, removedStale, err := listen(s.network, s.Addr)
238	if err != nil {
239		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
240	}
241	if removedStale && s.logger != nil {
242		s.logger.Warn("Removed stale socket before binding", "address", s.Addr)
243	}
244	return s.Serve(ln)
245}
246
247func (s *Server) closeListener() {
248	if s.ln != nil {
249		s.ln.Close()
250		s.ln = nil
251	}
252}
253
254// Close force closes all listeners and connections.
255func (s *Server) Close() error {
256	defer func() { s.closeListener() }()
257	return s.h.Close()
258}
259
260// Shutdown gracefully shuts down the server without interrupting active
261// connections.
262func (s *Server) Shutdown(ctx context.Context) error {
263	defer func() { s.closeListener() }()
264	return s.h.Shutdown(ctx)
265}
266
267func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
268	if s.logger != nil {
269		s.logger.With(
270			slog.String("method", r.Method),
271			slog.String("url", r.URL.String()),
272			slog.String("remote_addr", r.RemoteAddr),
273		).Debug(msg, args...)
274	}
275}
276
277func (s *Server) logError(r *http.Request, msg string, args ...any) {
278	if s.logger != nil {
279		s.logger.With(
280			slog.String("method", r.Method),
281			slog.String("url", r.URL.String()),
282			slog.String("remote_addr", r.RemoteAddr),
283		).Error(msg, args...)
284	}
285}