1package server
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "net/http"
9 "net/url"
10 "os/user"
11 "runtime"
12 "strings"
13
14 "github.com/charmbracelet/crush/internal/app"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17)
18
// ErrServerClosed is the error a [Server] reports from Serve and
// ListenAndServe once it has been closed. It is the very same value as
// [http.ErrServerClosed], so callers may compare against either one.
var ErrServerClosed error = http.ErrServerClosed
21
// Instance represents a running [app.App] instance with its associated
// resources and state. Servers keep instances in a string-keyed map
// (Server.instances); presumably the key is the instance's id field —
// confirm against the controller code.
type Instance struct {
	// App is embedded, so its exported fields and methods are promoted
	// onto Instance.
	*app.App

	// ln is a listener associated with this instance. NOTE(review): where
	// it is opened and closed is not visible in this file — confirm.
	ln net.Listener

	// cfg is the configuration for this instance; presumably distinct from
	// the server-wide Server.cfg — confirm.
	cfg *config.Config

	// id identifies the instance; presumably the same value used as its key
	// in Server.instances — confirm.
	id string

	// path is presumably the instance's working/project directory — confirm.
	path string

	// env is presumably the instance's environment in "KEY=VALUE" form —
	// confirm against where instances are created.
	env []string
}
32
// ParseHostURL parses a host string of the form "<network>://<address>"
// into a [url.URL] whose Scheme holds the network and whose Host holds the
// address.
//
// For the "tcp" network the remainder is parsed as a URL authority, so a
// trailing path (for example "tcp://localhost:8080/api") is returned in
// Path and only "localhost:8080" ends up in Host. For every other network
// (such as "unix" or "npipe") everything after "://" is stored verbatim in
// Host, e.g. "unix:///tmp/crush.sock" yields Host "/tmp/crush.sock".
//
// It returns an error when host has no "://" separator, when the network
// part is empty, or when a tcp address cannot be parsed. In the last case
// the underlying *url.Error is wrapped and can be recovered with errors.As.
func ParseHostURL(host string) (*url.URL, error) {
	proto, addr, ok := strings.Cut(host, "://")
	if !ok {
		return nil, fmt.Errorf("invalid host format: %s", host)
	}
	// An empty network would only fail later, at listen time, with a far
	// less helpful message.
	if proto == "" {
		return nil, fmt.Errorf("invalid host format: missing network in %s", host)
	}

	var basePath string
	if proto == "tcp" {
		parsed, err := url.Parse("tcp://" + addr)
		if err != nil {
			// %w keeps the *url.Error in the chain for callers.
			return nil, fmt.Errorf("invalid tcp address: %w", err)
		}
		addr = parsed.Host
		basePath = parsed.Path
	}
	return &url.URL{
		Scheme: proto,
		Host:   addr,
		Path:   basePath,
	}, nil
}
55
// DefaultHost returns the host string a server listens on when none is
// configured. The socket name embeds the current user's UID when it can be
// determined ("crush-<uid>.sock"), falling back to "crush.sock". On Windows
// the result is a named-pipe URL; everywhere else it is a Unix socket in
// /tmp.
func DefaultHost() string {
	name := "crush.sock"
	if current, err := user.Current(); err == nil && current.Uid != "" {
		name = fmt.Sprintf("crush-%s.sock", current.Uid)
	}

	prefix := "unix:///tmp/"
	if runtime.GOOS == "windows" {
		prefix = "npipe:////./pipe/"
	}
	return prefix + name
}
68
// Server represents a Crush server instance bound to a specific address.
// It wraps an [http.Server] that serves the v1 HTTP API set up by
// NewServer and keeps track of the app instances it manages.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr string

	// network is the listener network given to NewServer (for example
	// "tcp" or "unix"). ListenAndServe passes it to listen along with Addr,
	// and NewServer copies Addr into h.Addr only when it is "tcp".
	network string

	// h is the underlying HTTP server constructed by NewServer.
	h *http.Server

	// ln is the listener that closeListener closes (and clears) during
	// Close and Shutdown; nil when no listener is tracked. ListenAndServe
	// refuses to start when it is non-nil.
	ln net.Listener

	// ctx is set to context.Background() by NewServer.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// consider passing one per call instead.
	ctx context.Context

	// instances is a map of running applications managed by the server.
	instances *csync.Map[string, *Instance]

	// cfg is the server-wide configuration passed to NewServer.
	cfg *config.Config

	// logger receives per-request log output; logDebug and logError do
	// nothing while it is nil. Set it with SetLogger.
	logger *slog.Logger
}
84
85// SetLogger sets the logger for the server.
86func (s *Server) SetLogger(logger *slog.Logger) {
87 s.logger = logger
88}
89
90// DefaultServer returns a new [Server] instance with the default address.
91func DefaultServer(cfg *config.Config) *Server {
92 hostURL, err := ParseHostURL(DefaultHost())
93 if err != nil {
94 panic("invalid default host")
95 }
96 return NewServer(cfg, hostURL.Scheme, hostURL.Host)
97}
98
99// NewServer is a helper to create a new [Server] instance with the given
100// address. On Windows, if the address is not a "tcp" address, it will be
101// converted to a named pipe format.
102func NewServer(cfg *config.Config, network, address string) *Server {
103 s := new(Server)
104 s.Addr = address
105 s.network = network
106 s.cfg = cfg
107 s.instances = csync.NewMap[string, *Instance]()
108 s.ctx = context.Background()
109
110 var p http.Protocols
111 p.SetHTTP1(true)
112 p.SetUnencryptedHTTP2(true)
113 c := &controllerV1{Server: s}
114 mux := http.NewServeMux()
115 mux.HandleFunc("GET /v1/health", c.handleGetHealth)
116 mux.HandleFunc("GET /v1/version", c.handleGetVersion)
117 mux.HandleFunc("GET /v1/config", c.handleGetConfig)
118 mux.HandleFunc("POST /v1/control", c.handlePostControl)
119 mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
120 mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
121 mux.HandleFunc("DELETE /v1/instances/{id}", c.handleDeleteInstances)
122 mux.HandleFunc("GET /v1/instances/{id}", c.handleGetInstance)
123 mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
124 mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
125 mux.HandleFunc("GET /v1/instances/{id}/providers", c.handleGetInstanceProviders)
126 mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
127 mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
128 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
129 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
130 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
131 mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
132 mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
133 mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
134 mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
135 mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
136 mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
137 mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
138 mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
139 mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
140 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
141 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
142 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
143 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
144 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
145 s.h = &http.Server{
146 Protocols: &p,
147 Handler: s.loggingHandler(mux),
148 }
149 if network == "tcp" {
150 s.h.Addr = address
151 }
152 return s
153}
154
155// Serve accepts incoming connections on the listener.
156func (s *Server) Serve(ln net.Listener) error {
157 return s.h.Serve(ln)
158}
159
160// ListenAndServe starts the server and begins accepting connections.
161func (s *Server) ListenAndServe() error {
162 if s.ln != nil {
163 return fmt.Errorf("server already started")
164 }
165 ln, err := listen(s.network, s.Addr)
166 if err != nil {
167 return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
168 }
169 return s.Serve(ln)
170}
171
172func (s *Server) closeListener() {
173 if s.ln != nil {
174 s.ln.Close()
175 s.ln = nil
176 }
177}
178
179// Close force close all listeners and connections.
180func (s *Server) Close() error {
181 defer func() { s.closeListener() }()
182 return s.h.Close()
183}
184
185// Shutdown gracefully shuts down the server without interrupting active
186// connections. It stops accepting new connections and waits for existing
187// connections to finish.
188func (s *Server) Shutdown(ctx context.Context) error {
189 defer func() { s.closeListener() }()
190 return s.h.Shutdown(ctx)
191}
192
193func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
194 if s.logger != nil {
195 s.logger.With(
196 slog.String("method", r.Method),
197 slog.String("url", r.URL.String()),
198 slog.String("remote_addr", r.RemoteAddr),
199 ).Debug(msg, args...)
200 }
201}
202
203func (s *Server) logError(r *http.Request, msg string, args ...any) {
204 if s.logger != nil {
205 s.logger.With(
206 slog.String("method", r.Method),
207 slog.String("url", r.URL.String()),
208 slog.String("remote_addr", r.RemoteAddr),
209 ).Error(msg, args...)
210 }
211}