1package server
2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os/user"
	"runtime"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
)
18
// ErrServerClosed is the error reported once the server has been closed.
// It is the very same value as [http.ErrServerClosed], so comparisons with
// either sentinel (including via errors.Is) behave identically.
var ErrServerClosed = http.ErrServerClosed
21
// InstanceState describes which lifecycle phase a running [app.App]
// instance is currently in.
type InstanceState uint8

// Lifecycle phases an [InstanceState] can take.
const (
	// InstanceStateCreated means the instance exists but has not been started.
	InstanceStateCreated InstanceState = iota
	// InstanceStateStarted means the instance is running.
	InstanceStateStarted
	// InstanceStateStopped means the instance has been stopped.
	InstanceStateStopped
)
33
// Instance represents a running [app.App] instance with its associated
// resources and state.
//
// The embedded *app.App promotes the application's methods onto Instance.
// The unexported id and path are exposed read-only through ID and Path.
type Instance struct {
	*app.App
	// State is the lifecycle state of the instance.
	State InstanceState
	// ln is a listener held by this instance.
	// NOTE(review): what it listens on and who closes it is not visible in
	// this file — confirm against the code that creates instances.
	ln    net.Listener
	// cfg is the configuration this instance was created with (presumably
	// per-instance rather than the server-wide config — verify).
	cfg   *config.Config
	// id is the unique identifier returned by ID.
	id    string
	// path is the path returned by Path.
	path  string
}
44
45// ID returns the unique identifier of the instance.
46func (i *Instance) ID() string {
47 return i.id
48}
49
50// Path returns the filesystem path associated with the instance.
51func (i *Instance) Path() string {
52 return i.path
53}
54
// ParseHostURL parses host, which must have the form "scheme://address",
// into a [url.URL] whose Scheme is the part before "://" and whose Host is
// the address.
//
// For the "tcp" scheme the address is parsed as a URL authority, so
// "tcp://127.0.0.1:8080/api" yields Host "127.0.0.1:8080" and Path "/api".
// For any other scheme (e.g. "unix" or "npipe") the address is kept verbatim
// in Host so it can carry a socket path or pipe name.
//
// It returns an error if host has no "://" separator or has an empty scheme,
// or if a "tcp" address cannot be parsed; in the latter case the error wraps
// the one returned by [url.Parse].
func ParseHostURL(host string) (*url.URL, error) {
	proto, addr, ok := strings.Cut(host, "://")
	if !ok || proto == "" {
		return nil, fmt.Errorf("invalid host format: %s", host)
	}

	u := &url.URL{Scheme: proto, Host: addr}
	if proto == "tcp" {
		parsed, err := url.Parse("tcp://" + addr)
		if err != nil {
			// %w keeps the *url.Error reachable via errors.As / errors.Is.
			return nil, fmt.Errorf("invalid tcp address: %w", err)
		}
		u.Host = parsed.Host
		u.Path = parsed.Path
	}
	return u, nil
}
77
// DefaultHost returns the host URL the server listens on by default.
//
// The socket name embeds the current user's UID when it can be determined
// ("crush-<uid>.sock"), falling back to "crush.sock". On Windows the result is
// an "npipe://" named-pipe URL; everywhere else it is a "unix://" socket under
// /tmp.
func DefaultHost() string {
	name := "crush.sock"
	if u, err := user.Current(); err == nil && u.Uid != "" {
		name = "crush-" + u.Uid + ".sock"
	}

	switch runtime.GOOS {
	case "windows":
		return "npipe:////./pipe/" + name
	default:
		return "unix:///tmp/" + name
	}
}
90
// Server represents a Crush server instance bound to a specific address.
//
// A Server is built by NewServer (or DefaultServer), which registers the v1
// HTTP routes; ListenAndServe or Serve then starts accepting connections.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr    string
	// network selects how Addr is listened on (passed with Addr to listen in
	// ListenAndServe); "tcp" also sets the http.Server's Addr in NewServer.
	network string

	// h is the underlying HTTP server configured in NewServer.
	h   *http.Server
	// ln is the listener released by closeListener (called from Close and
	// Shutdown).
	ln  net.Listener
	// ctx is set to context.Background() in NewServer.
	// NOTE(review): storing a context in a struct is discouraged by the
	// context package docs — consider passing it per call instead.
	ctx context.Context

	// instances is a map of running applications managed by the server.
	instances *csync.Map[string, *Instance]
	// cfg is the configuration the server was created with.
	cfg       *config.Config
	// logger receives request logs from logDebug/logError; nil disables them.
	logger    *slog.Logger
}
106
107// SetLogger sets the logger for the server.
108func (s *Server) SetLogger(logger *slog.Logger) {
109 s.logger = logger
110}
111
112// DefaultServer returns a new [Server] instance with the default address.
113func DefaultServer(cfg *config.Config) *Server {
114 hostURL, err := ParseHostURL(DefaultHost())
115 if err != nil {
116 panic("invalid default host")
117 }
118 return NewServer(cfg, hostURL.Scheme, hostURL.Host)
119}
120
121// NewServer is a helper to create a new [Server] instance with the given
122// address. On Windows, if the address is not a "tcp" address, it will be
123// converted to a named pipe format.
124func NewServer(cfg *config.Config, network, address string) *Server {
125 s := new(Server)
126 s.Addr = address
127 s.network = network
128 s.cfg = cfg
129 s.instances = csync.NewMap[string, *Instance]()
130 s.ctx = context.Background()
131
132 var p http.Protocols
133 p.SetHTTP1(true)
134 p.SetUnencryptedHTTP2(true)
135 c := &controllerV1{Server: s}
136 mux := http.NewServeMux()
137 mux.HandleFunc("GET /v1/health", c.handleGetHealth)
138 mux.HandleFunc("GET /v1/version", c.handleGetVersion)
139 mux.HandleFunc("GET /v1/config", c.handleGetConfig)
140 mux.HandleFunc("POST /v1/control", c.handlePostControl)
141 mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
142 mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
143 mux.HandleFunc("DELETE /v1/instances/{id}", c.handleDeleteInstances)
144 mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
145 mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
146 mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
147 mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
148 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
149 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
150 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
151 mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
152 mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
153 mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
154 mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
155 mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
156 mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
157 mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
158 mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
159 mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
160 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
161 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
162 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
163 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
164 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
165 s.h = &http.Server{
166 Protocols: &p,
167 Handler: s.loggingHandler(mux),
168 }
169 if network == "tcp" {
170 s.h.Addr = address
171 }
172 return s
173}
174
175// Serve accepts incoming connections on the listener.
176func (s *Server) Serve(ln net.Listener) error {
177 return s.h.Serve(ln)
178}
179
180// ListenAndServe starts the server and begins accepting connections.
181func (s *Server) ListenAndServe() error {
182 if s.ln != nil {
183 return fmt.Errorf("server already started")
184 }
185 ln, err := listen(s.network, s.Addr)
186 if err != nil {
187 return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
188 }
189 return s.Serve(ln)
190}
191
192func (s *Server) closeListener() {
193 if s.ln != nil {
194 s.ln.Close()
195 s.ln = nil
196 }
197}
198
199// Close force close all listeners and connections.
200func (s *Server) Close() error {
201 defer func() { s.closeListener() }()
202 return s.h.Close()
203}
204
205// Shutdown gracefully shuts down the server without interrupting active
206// connections. It stops accepting new connections and waits for existing
207// connections to finish.
208func (s *Server) Shutdown(ctx context.Context) error {
209 defer func() { s.closeListener() }()
210 return s.h.Shutdown(ctx)
211}
212
213func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
214 if s.logger != nil {
215 s.logger.With(
216 slog.String("method", r.Method),
217 slog.String("url", r.URL.String()),
218 slog.String("remote_addr", r.RemoteAddr),
219 ).Debug(msg, args...)
220 }
221}
222
223func (s *Server) logError(r *http.Request, msg string, args ...any) {
224 if s.logger != nil {
225 s.logger.With(
226 slog.String("method", r.Method),
227 slog.String("url", r.URL.String()),
228 slog.String("remote_addr", r.RemoteAddr),
229 ).Error(msg, args...)
230 }
231}