server.go

  1package server
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os/user"
	"runtime"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/backend"
	"github.com/charmbracelet/crush/internal/config"
)
 17
// ErrServerClosed is returned when the server is closed. It is the same value
// as [http.ErrServerClosed], so callers may compare against either one.
var ErrServerClosed = http.ErrServerClosed
 20
 21// ParseHostURL parses a host URL into a [url.URL].
 22func ParseHostURL(host string) (*url.URL, error) {
 23	proto, addr, ok := strings.Cut(host, "://")
 24	if !ok {
 25		return nil, fmt.Errorf("invalid host format: %s", host)
 26	}
 27
 28	var basePath string
 29	if proto == "tcp" {
 30		parsed, err := url.Parse("tcp://" + addr)
 31		if err != nil {
 32			return nil, fmt.Errorf("invalid tcp address: %v", err)
 33		}
 34		addr = parsed.Host
 35		basePath = parsed.Path
 36	}
 37	return &url.URL{
 38		Scheme: proto,
 39		Host:   addr,
 40		Path:   basePath,
 41	}, nil
 42}
 43
 44// DefaultHost returns the default server host.
 45func DefaultHost() string {
 46	sock := "crush.sock"
 47	usr, err := user.Current()
 48	if err == nil && usr.Uid != "" {
 49		sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
 50	}
 51	if runtime.GOOS == "windows" {
 52		return fmt.Sprintf("npipe:////./pipe/%s", sock)
 53	}
 54	return fmt.Sprintf("unix:///tmp/%s", sock)
 55}
 56
// Server represents a Crush server bound to a specific address.
//
// Instances are built by NewServer or DefaultServer, which populate the
// unexported fields; the methods in this file dereference h without a nil
// check.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr    string
	network string // network name given to NewServer; passed to listen with Addr

	h  *http.Server // HTTP server that Serve, Close and Shutdown delegate to
	ln net.Listener // closed and reset to nil by closeListener

	backend *backend.Backend // created in NewServer and handed to the v1 controller
	logger  *slog.Logger     // optional; logDebug and logError do nothing while nil
}
 69
// SetLogger sets the logger for the server.
//
// The logger is used by logDebug and logError, both of which skip logging
// while it is nil. The field is assigned without any synchronization.
func (s *Server) SetLogger(logger *slog.Logger) {
	s.logger = logger
}
 74
 75// DefaultServer returns a new [Server] with the default address.
 76func DefaultServer(cfg *config.ConfigStore) *Server {
 77	hostURL, err := ParseHostURL(DefaultHost())
 78	if err != nil {
 79		panic("invalid default host")
 80	}
 81	return NewServer(cfg, hostURL.Scheme, hostURL.Host)
 82}
 83
 84// NewServer creates a new [Server] with the given network and address.
 85func NewServer(cfg *config.ConfigStore, network, address string) *Server {
 86	s := new(Server)
 87	s.Addr = address
 88	s.network = network
 89
 90	// The backend is created with a shutdown callback that triggers
 91	// a graceful server shutdown (e.g. when the last workspace is
 92	// removed).
 93	s.backend = backend.New(context.Background(), cfg, func() {
 94		go func() {
 95			slog.Info("Shutting down server...")
 96			if err := s.Shutdown(context.Background()); err != nil {
 97				slog.Error("Failed to shutdown server", "error", err)
 98			}
 99		}()
100	})
101
102	var p http.Protocols
103	p.SetHTTP1(true)
104	p.SetUnencryptedHTTP2(true)
105	c := &controllerV1{backend: s.backend, server: s}
106	mux := http.NewServeMux()
107	mux.HandleFunc("GET /v1/health", c.handleGetHealth)
108	mux.HandleFunc("GET /v1/version", c.handleGetVersion)
109	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
110	mux.HandleFunc("POST /v1/control", c.handlePostControl)
111	mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces)
112	mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces)
113	mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces)
114	mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace)
115	mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig)
116	mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents)
117	mux.HandleFunc("GET /v1/workspaces/{id}/providers", c.handleGetWorkspaceProviders)
118	mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions)
119	mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions)
120	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
121	mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession)
122	mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession)
123	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
124	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
125	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages)
126	mux.HandleFunc("GET /v1/workspaces/{id}/messages/user", c.handleGetWorkspaceAllUserMessages)
127	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/filetracker/files", c.handleGetWorkspaceSessionFileTrackerFiles)
128	mux.HandleFunc("POST /v1/workspaces/{id}/filetracker/read", c.handlePostWorkspaceFileTrackerRead)
129	mux.HandleFunc("GET /v1/workspaces/{id}/filetracker/lastread", c.handleGetWorkspaceFileTrackerLastRead)
130	mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs)
131	mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics)
132	mux.HandleFunc("POST /v1/workspaces/{id}/lsps/start", c.handlePostWorkspaceLSPStart)
133	mux.HandleFunc("POST /v1/workspaces/{id}/lsps/stop", c.handlePostWorkspaceLSPStopAll)
134	mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip)
135	mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip)
136	mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant)
137	mux.HandleFunc("GET /v1/workspaces/{id}/agent", c.handleGetWorkspaceAgent)
138	mux.HandleFunc("POST /v1/workspaces/{id}/agent", c.handlePostWorkspaceAgent)
139	mux.HandleFunc("POST /v1/workspaces/{id}/agent/init", c.handlePostWorkspaceAgentInit)
140	mux.HandleFunc("POST /v1/workspaces/{id}/agent/update", c.handlePostWorkspaceAgentUpdate)
141	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession)
142	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel)
143	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued)
144	mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/list", c.handleGetWorkspaceAgentSessionPromptList)
145	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear)
146	mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handlePostWorkspaceAgentSessionSummarize)
147	mux.HandleFunc("GET /v1/workspaces/{id}/agent/default-small-model", c.handleGetWorkspaceAgentDefaultSmallModel)
148	mux.HandleFunc("POST /v1/workspaces/{id}/config/set", c.handlePostWorkspaceConfigSet)
149	mux.HandleFunc("POST /v1/workspaces/{id}/config/remove", c.handlePostWorkspaceConfigRemove)
150	mux.HandleFunc("POST /v1/workspaces/{id}/config/model", c.handlePostWorkspaceConfigModel)
151	mux.HandleFunc("POST /v1/workspaces/{id}/config/compact", c.handlePostWorkspaceConfigCompact)
152	mux.HandleFunc("POST /v1/workspaces/{id}/config/provider-key", c.handlePostWorkspaceConfigProviderKey)
153	mux.HandleFunc("POST /v1/workspaces/{id}/config/import-copilot", c.handlePostWorkspaceConfigImportCopilot)
154	mux.HandleFunc("POST /v1/workspaces/{id}/config/refresh-oauth", c.handlePostWorkspaceConfigRefreshOAuth)
155	mux.HandleFunc("GET /v1/workspaces/{id}/project/needs-init", c.handleGetWorkspaceProjectNeedsInit)
156	mux.HandleFunc("POST /v1/workspaces/{id}/project/init", c.handlePostWorkspaceProjectInit)
157	mux.HandleFunc("GET /v1/workspaces/{id}/project/init-prompt", c.handleGetWorkspaceProjectInitPrompt)
158	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-tools", c.handlePostWorkspaceMCPRefreshTools)
159	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/read-resource", c.handlePostWorkspaceMCPReadResource)
160	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/get-prompt", c.handlePostWorkspaceMCPGetPrompt)
161	mux.HandleFunc("GET /v1/workspaces/{id}/mcp/states", c.handleGetWorkspaceMCPStates)
162	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-prompts", c.handlePostWorkspaceMCPRefreshPrompts)
163	mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-resources", c.handlePostWorkspaceMCPRefreshResources)
164	s.h = &http.Server{
165		Protocols: &p,
166		Handler:   s.loggingHandler(mux),
167	}
168	if network == "tcp" {
169		s.h.Addr = address
170	}
171	return s
172}
173
// Serve accepts incoming connections on the listener.
//
// It hands ln to the underlying [http.Server] and returns that server's Serve
// result. Serve itself does not store ln on the Server.
func (s *Server) Serve(ln net.Listener) error {
	return s.h.Serve(ln)
}
178
179// ListenAndServe starts the server and begins accepting connections.
180func (s *Server) ListenAndServe() error {
181	if s.ln != nil {
182		return fmt.Errorf("server already started")
183	}
184	ln, err := listen(s.network, s.Addr)
185	if err != nil {
186		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
187	}
188	return s.Serve(ln)
189}
190
191func (s *Server) closeListener() {
192	if s.ln != nil {
193		s.ln.Close()
194		s.ln = nil
195	}
196}
197
198// Close force closes all listeners and connections.
199func (s *Server) Close() error {
200	defer func() { s.closeListener() }()
201	return s.h.Close()
202}
203
204// Shutdown gracefully shuts down the server without interrupting active
205// connections.
206func (s *Server) Shutdown(ctx context.Context) error {
207	defer func() { s.closeListener() }()
208	return s.h.Shutdown(ctx)
209}
210
211func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
212	if s.logger != nil {
213		s.logger.With(
214			slog.String("method", r.Method),
215			slog.String("url", r.URL.String()),
216			slog.String("remote_addr", r.RemoteAddr),
217		).Debug(msg, args...)
218	}
219}
220
221func (s *Server) logError(r *http.Request, msg string, args ...any) {
222	if s.logger != nil {
223		s.logger.With(
224			slog.String("method", r.Method),
225			slog.String("url", r.URL.String()),
226			slog.String("remote_addr", r.RemoteAddr),
227		).Error(msg, args...)
228	}
229}