1package server
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "net/http"
9 "os"
10 "os/user"
11 "path/filepath"
12 "runtime"
13 "strings"
14
15 "github.com/charmbracelet/crush/internal/app"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/csync"
18)
19
// ErrServerClosed is the error returned by [Server.Serve] and
// [Server.ListenAndServe] once the server has been stopped through
// [Server.Close] or [Server.Shutdown].
//
// It is the same value as [http.ErrServerClosed]: both methods hand back the
// error produced by the underlying [http.Server] unchanged, so a separate
// sentinel value would never match with errors.Is.
var ErrServerClosed = http.ErrServerClosed
22
// InstanceState represents the state of a running [app.App] instance.
//
// The zero value is [InstanceStateCreated].
type InstanceState uint8

const (
	// InstanceStateCreated indicates that the instance has been created but
	// not yet started. It is the zero value of [InstanceState].
	InstanceStateCreated InstanceState = iota
	// InstanceStateStarted indicates that the instance is currently running.
	InstanceStateStarted
	// InstanceStateStopped indicates that the instance has been stopped.
	InstanceStateStopped
)
34
// Instance represents a running [app.App] instance with its associated
// resources and state.
//
// The embedded *app.App promotes the application's methods onto Instance.
type Instance struct {
	*app.App
	// State is the lifecycle state of the instance; see [InstanceState].
	State InstanceState
	// ln is a listener associated with the instance.
	// NOTE(review): nothing in this part of the file reads or writes it;
	// confirm which code owns it.
	ln net.Listener
	// cfg is presumably the configuration the instance was created with;
	// TODO confirm against the code that builds instances.
	cfg *config.Config
	// id is the identifier reported by ID.
	id string
	// path is the filesystem path reported by Path.
	path string
}
45
// ID returns the unique identifier of the instance.
//
// NOTE(review): presumably this is also the key under which the instance is
// stored in the server's instance map; confirm against the code that
// registers instances.
func (i *Instance) ID() string {
	return i.id
}
50
// Path returns the filesystem path associated with the instance.
//
// NOTE(review): likely the working directory the instance was opened for;
// confirm against the code that creates instances.
func (i *Instance) Path() string {
	return i.path
}
55
// DefaultAddr reports the address the Crush server uses when none is given
// explicitly.
//
// The socket name is "crush-<uid>.sock" when the current user and its UID can
// be determined, and "crush.sock" otherwise. On Windows the name is returned
// as a named pipe path (`\\.\pipe\<name>`); on every other platform it is
// placed inside the directory reported by [os.TempDir].
func DefaultAddr() string {
	name := "crush.sock"
	// Scope the user lookup to this statement so the local does not shadow
	// the os/user package for the rest of the function.
	if current, err := user.Current(); err == nil && current.Uid != "" {
		name = fmt.Sprintf("crush-%s.sock", current.Uid)
	}

	switch runtime.GOOS {
	case "windows":
		return fmt.Sprintf(`\\.\pipe\%s`, name)
	default:
		return filepath.Join(os.TempDir(), name)
	}
}
69
// Server represents a Crush server instance bound to a specific address.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	Addr string

	// h is the underlying HTTP server that Serve, Close and Shutdown
	// delegate to.
	h *http.Server
	// ln is the listener released by Close and Shutdown; ListenAndServe
	// refuses to start while it is non-nil.
	ln net.Listener
	// ctx is set to context.Background() by NewServer.
	// NOTE(review): presumably the base context for work started by request
	// handlers; confirm against the handlers.
	ctx context.Context

	// instances is a map of running applications managed by the server.
	instances *csync.Map[string, *Instance]
	// cfg is the configuration passed to NewServer.
	cfg *config.Config
	// logger receives per-request log records; while it is nil, logDebug and
	// logError do nothing.
	logger *slog.Logger
}
84
// SetLogger sets the logger for the server.
//
// The logger is used by logDebug and logError; passing nil turns both into
// no-ops. The field is assigned without any locking.
func (s *Server) SetLogger(logger *slog.Logger) {
	s.logger = logger
}
89
// DefaultServer returns a new [Server] instance with the default address.
//
// It is shorthand for NewServer(cfg, "unix", DefaultAddr()).
func DefaultServer(cfg *config.Config) *Server {
	return NewServer(cfg, "unix", DefaultAddr())
}
94
95// NewServer is a helper to create a new [Server] instance with the given
96// address. On Windows, if the address is not a "tcp" address, it will be
97// converted to a named pipe format.
98func NewServer(cfg *config.Config, network, address string) *Server {
99 if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
100 !strings.HasPrefix(address, `\\.\pipe\`) {
101 // On Windows, convert to named pipe format if not TCP
102 // (e.g., "mypipe" -> "\\.\pipe\mypipe")
103 address = fmt.Sprintf(`\\.\pipe\%s`, address)
104 }
105
106 s := new(Server)
107 s.Addr = address
108 s.cfg = cfg
109 s.instances = csync.NewMap[string, *Instance]()
110 s.ctx = context.Background()
111
112 var p http.Protocols
113 p.SetHTTP1(true)
114 p.SetUnencryptedHTTP2(true)
115 c := &controllerV1{Server: s}
116 mux := http.NewServeMux()
117 mux.HandleFunc("GET /v1/config", c.handleGetConfig)
118 mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
119 mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
120 mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances)
121 mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
122 mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
123 mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
124 mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
125 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
126 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
127 mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
128 mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
129 mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
130 mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
131 mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
132 mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
133 mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
134 mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
135 mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
136 mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
137 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
138 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
139 mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
140 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
141 mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
142 s.h = &http.Server{
143 Protocols: &p,
144 Handler: s.loggingHandler(mux),
145 }
146 return s
147}
148
// Serve accepts incoming connections on the listener.
//
// It delegates directly to the underlying [http.Server] and returns whatever
// error that call returns. The listener is not stored on the Server.
func (s *Server) Serve(ln net.Listener) error {
	return s.h.Serve(ln)
}
153
154// ListenAndServe starts the server and begins accepting connections.
155func (s *Server) ListenAndServe() error {
156 if s.ln != nil {
157 return fmt.Errorf("server already started")
158 }
159 ln, err := listen("unix", s.Addr)
160 if err != nil {
161 return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
162 }
163 return s.Serve(ln)
164}
165
166func (s *Server) closeListener() {
167 if s.ln != nil {
168 s.ln.Close()
169 s.ln = nil
170 }
171}
172
173// Close force close all listeners and connections.
174func (s *Server) Close() error {
175 defer func() { s.closeListener() }()
176 return s.h.Close()
177}
178
179// Shutdown gracefully shuts down the server without interrupting active
180// connections. It stops accepting new connections and waits for existing
181// connections to finish.
182func (s *Server) Shutdown(ctx context.Context) error {
183 defer func() { s.closeListener() }()
184 return s.h.Shutdown(ctx)
185}
186
187func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
188 if s.logger != nil {
189 s.logger.With(
190 slog.String("method", r.Method),
191 slog.String("url", r.URL.String()),
192 slog.String("remote_addr", r.RemoteAddr),
193 ).Debug(msg, args...)
194 }
195}
196
197func (s *Server) logError(r *http.Request, msg string, args ...any) {
198 if s.logger != nil {
199 s.logger.With(
200 slog.String("method", r.Method),
201 slog.String("url", r.URL.String()),
202 slog.String("remote_addr", r.RemoteAddr),
203 ).Error(msg, args...)
204 }
205}