proto.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"net/http"
  9	"os"
 10	"path/filepath"
 11	"runtime"
 12
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/lsp"
 17	"github.com/charmbracelet/crush/internal/proto"
 18	"github.com/charmbracelet/crush/internal/session"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/google/uuid"
 21)
 22
 23type controllerV1 struct {
 24	*Server
 25}
 26
 27func (c *controllerV1) handleGetHealth(w http.ResponseWriter, r *http.Request) {
 28	w.WriteHeader(http.StatusOK)
 29}
 30
 31func (c *controllerV1) handleGetVersion(w http.ResponseWriter, r *http.Request) {
 32	jsonEncode(w, proto.VersionInfo{
 33		Version:   version.Version,
 34		Commit:    version.Commit,
 35		GoVersion: runtime.Version(),
 36		Platform:  fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH),
 37	})
 38}
 39
 40func (c *controllerV1) handlePostControl(w http.ResponseWriter, r *http.Request) {
 41	var req proto.ServerControl
 42	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 43		c.logError(r, "failed to decode request", "error", err)
 44		jsonError(w, http.StatusBadRequest, "failed to decode request")
 45		return
 46	}
 47
 48	switch req.Command {
 49	case "shutdown":
 50		go func() {
 51			slog.Info("shutting down server...")
 52			if err := c.Shutdown(context.Background()); err != nil {
 53				c.logError(r, "failed to shutdown server", "error", err)
 54			}
 55		}()
 56	default:
 57		c.logError(r, "unknown command", "command", req.Command)
 58		jsonError(w, http.StatusBadRequest, "unknown command")
 59		return
 60	}
 61}
 62
 63func (c *controllerV1) handleGetConfig(w http.ResponseWriter, r *http.Request) {
 64	jsonEncode(w, c.cfg)
 65}
 66
 67func (c *controllerV1) handleGetInstances(w http.ResponseWriter, r *http.Request) {
 68	instances := []proto.Instance{}
 69	for _, ins := range c.instances.Seq2() {
 70		// TODO: implement pagination?
 71		instances = append(instances, proto.Instance{
 72			ID:      ins.id,
 73			Path:    ins.path,
 74			YOLO:    ins.cfg.Permissions != nil && ins.cfg.Permissions.SkipRequests,
 75			DataDir: ins.cfg.Options.DataDirectory,
 76			Debug:   ins.cfg.Options.Debug,
 77			Config:  ins.cfg,
 78		})
 79	}
 80	jsonEncode(w, instances)
 81}
 82
 83func (c *controllerV1) handleGetInstanceLSPDiagnostics(w http.ResponseWriter, r *http.Request) {
 84	id := r.PathValue("id")
 85	ins, ok := c.instances.Get(id)
 86	if !ok {
 87		c.logError(r, "instance not found", "id", id)
 88		jsonError(w, http.StatusNotFound, "instance not found")
 89		return
 90	}
 91
 92	var lsp *lsp.Client
 93	lspName := r.PathValue("lsp")
 94	for name, client := range ins.LSPClients.Seq2() {
 95		if name == lspName {
 96			lsp = client
 97			break
 98		}
 99	}
100
101	if lsp == nil {
102		c.logError(r, "LSP client not found", "id", id, "lsp", lspName)
103		jsonError(w, http.StatusNotFound, "LSP client not found")
104		return
105	}
106
107	diagnostics := lsp.GetDiagnostics()
108	jsonEncode(w, diagnostics)
109}
110
111func (c *controllerV1) handleGetInstanceLSPs(w http.ResponseWriter, r *http.Request) {
112	id := r.PathValue("id")
113	ins, ok := c.instances.Get(id)
114	if !ok {
115		c.logError(r, "instance not found", "id", id)
116		jsonError(w, http.StatusNotFound, "instance not found")
117		return
118	}
119
120	lspClients := ins.GetLSPStates()
121	jsonEncode(w, lspClients)
122}
123
124func (c *controllerV1) handleGetInstanceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) {
125	id := r.PathValue("id")
126	ins, ok := c.instances.Get(id)
127	if !ok {
128		c.logError(r, "instance not found", "id", id)
129		jsonError(w, http.StatusNotFound, "instance not found")
130		return
131	}
132
133	sid := r.PathValue("sid")
134	queued := ins.App.CoderAgent.QueuedPrompts(sid)
135	jsonEncode(w, queued)
136}
137
138func (c *controllerV1) handlePostInstanceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) {
139	id := r.PathValue("id")
140	ins, ok := c.instances.Get(id)
141	if !ok {
142		c.logError(r, "instance not found", "id", id)
143		jsonError(w, http.StatusNotFound, "instance not found")
144		return
145	}
146
147	sid := r.PathValue("sid")
148	ins.App.CoderAgent.ClearQueue(sid)
149}
150
151func (c *controllerV1) handleGetInstanceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) {
152	id := r.PathValue("id")
153	ins, ok := c.instances.Get(id)
154	if !ok {
155		c.logError(r, "instance not found", "id", id)
156		jsonError(w, http.StatusNotFound, "instance not found")
157		return
158	}
159
160	sid := r.PathValue("sid")
161	if err := ins.App.CoderAgent.Summarize(r.Context(), sid); err != nil {
162		c.logError(r, "failed to summarize session", "error", err, "id", id, "sid", sid)
163		jsonError(w, http.StatusInternalServerError, "failed to summarize session")
164		return
165	}
166}
167
168func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWriter, r *http.Request) {
169	id := r.PathValue("id")
170	ins, ok := c.instances.Get(id)
171	if !ok {
172		c.logError(r, "instance not found", "id", id)
173		jsonError(w, http.StatusNotFound, "instance not found")
174		return
175	}
176
177	sid := r.PathValue("sid")
178	if ins.App.CoderAgent != nil {
179		ins.App.CoderAgent.Cancel(sid)
180	}
181}
182
183func (c *controllerV1) handleGetInstanceAgentSession(w http.ResponseWriter, r *http.Request) {
184	id := r.PathValue("id")
185	ins, ok := c.instances.Get(id)
186	if !ok {
187		c.logError(r, "instance not found", "id", id)
188		jsonError(w, http.StatusNotFound, "instance not found")
189		return
190	}
191
192	sid := r.PathValue("sid")
193	se, err := ins.App.Sessions.Get(r.Context(), sid)
194	if err != nil {
195		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
196		jsonError(w, http.StatusInternalServerError, "failed to get session")
197		return
198	}
199
200	var isSessionBusy bool
201	if ins.App.CoderAgent != nil {
202		isSessionBusy = ins.App.CoderAgent.IsSessionBusy(sid)
203	}
204
205	jsonEncode(w, proto.AgentSession{
206		Session: se,
207		IsBusy:  isSessionBusy,
208	})
209}
210
211func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) {
212	id := r.PathValue("id")
213	ins, ok := c.instances.Get(id)
214	if !ok {
215		c.logError(r, "instance not found", "id", id)
216		jsonError(w, http.StatusNotFound, "instance not found")
217		return
218	}
219
220	w.Header().Set("Accept", "application/json")
221
222	var msg proto.AgentMessage
223	if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
224		c.logError(r, "failed to decode request", "error", err)
225		jsonError(w, http.StatusBadRequest, "failed to decode request")
226		return
227	}
228
229	if ins.App.CoderAgent == nil {
230		c.logError(r, "coder agent not initialized", "id", id)
231		jsonError(w, http.StatusBadRequest, "coder agent not initialized")
232		return
233	}
234
235	// NOTE: This needs to be on the server's context because the agent runs
236	// the request asynchronously.
237	// TODO: Look into this one more and make it work synchronously.
238	if _, err := ins.App.CoderAgent.Run(c.ctx, msg.SessionID, msg.Prompt, msg.Attachments...); err != nil {
239		c.logError(r, "failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID)
240		jsonError(w, http.StatusInternalServerError, "failed to enqueue message")
241		return
242	}
243}
244
245func (c *controllerV1) handleGetInstanceAgent(w http.ResponseWriter, r *http.Request) {
246	id := r.PathValue("id")
247	ins, ok := c.instances.Get(id)
248	if !ok {
249		c.logError(r, "instance not found", "id", id)
250		jsonError(w, http.StatusNotFound, "instance not found")
251		return
252	}
253
254	var agentInfo proto.AgentInfo
255	if ins.App.CoderAgent != nil {
256		agentInfo = proto.AgentInfo{
257			Model:  ins.App.CoderAgent.Model(),
258			IsBusy: ins.App.CoderAgent.IsBusy(),
259		}
260	}
261	jsonEncode(w, agentInfo)
262}
263
264func (c *controllerV1) handlePostInstanceAgentUpdate(w http.ResponseWriter, r *http.Request) {
265	id := r.PathValue("id")
266	ins, ok := c.instances.Get(id)
267	if !ok {
268		c.logError(r, "instance not found", "id", id)
269		jsonError(w, http.StatusNotFound, "instance not found")
270		return
271	}
272
273	if err := ins.App.UpdateAgentModel(); err != nil {
274		c.logError(r, "failed to update agent model", "error", err)
275		jsonError(w, http.StatusInternalServerError, "failed to update agent model")
276		return
277	}
278}
279
280func (c *controllerV1) handlePostInstanceAgentInit(w http.ResponseWriter, r *http.Request) {
281	id := r.PathValue("id")
282	ins, ok := c.instances.Get(id)
283	if !ok {
284		c.logError(r, "instance not found", "id", id)
285		jsonError(w, http.StatusNotFound, "instance not found")
286		return
287	}
288
289	if err := ins.App.InitCoderAgent(); err != nil {
290		c.logError(r, "failed to initialize coder agent", "error", err)
291		jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent")
292		return
293	}
294}
295
296func (c *controllerV1) handleGetInstanceSessionHistory(w http.ResponseWriter, r *http.Request) {
297	id := r.PathValue("id")
298	ins, ok := c.instances.Get(id)
299	if !ok {
300		c.logError(r, "instance not found", "id", id)
301		jsonError(w, http.StatusNotFound, "instance not found")
302		return
303	}
304
305	sid := r.PathValue("sid")
306	historyItems, err := ins.App.History.ListBySession(r.Context(), sid)
307	if err != nil {
308		c.logError(r, "failed to list history", "error", err, "id", id, "sid", sid)
309		jsonError(w, http.StatusInternalServerError, "failed to list history")
310		return
311	}
312
313	jsonEncode(w, historyItems)
314}
315
316func (c *controllerV1) handleGetInstanceSessionMessages(w http.ResponseWriter, r *http.Request) {
317	id := r.PathValue("id")
318	ins, ok := c.instances.Get(id)
319	if !ok {
320		c.logError(r, "instance not found", "id", id)
321		jsonError(w, http.StatusNotFound, "instance not found")
322		return
323	}
324
325	sid := r.PathValue("sid")
326	messages, err := ins.App.Messages.List(r.Context(), sid)
327	if err != nil {
328		c.logError(r, "failed to list messages", "error", err, "id", id, "sid", sid)
329		jsonError(w, http.StatusInternalServerError, "failed to list messages")
330		return
331	}
332
333	jsonEncode(w, messages)
334}
335
336func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.Request) {
337	id := r.PathValue("id")
338	ins, ok := c.instances.Get(id)
339	if !ok {
340		c.logError(r, "instance not found", "id", id)
341		jsonError(w, http.StatusNotFound, "instance not found")
342		return
343	}
344
345	sid := r.PathValue("sid")
346	session, err := ins.App.Sessions.Get(r.Context(), sid)
347	if err != nil {
348		c.logError(r, "failedto get session", "error", err, "id", id, "sid", sid)
349		jsonError(w, http.StatusInternalServerError, "failed to get session")
350		return
351	}
352
353	jsonEncode(w, session)
354}
355
356func (c *controllerV1) handlePostInstanceSessions(w http.ResponseWriter, r *http.Request) {
357	id := r.PathValue("id")
358	ins, ok := c.instances.Get(id)
359	if !ok {
360		c.logError(r, "instance not found", "id", id)
361		jsonError(w, http.StatusNotFound, "instance not found")
362		return
363	}
364
365	var args session.Session
366	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
367		c.logError(r, "failed to decode request", "error", err)
368		jsonError(w, http.StatusBadRequest, "failed to decode request")
369		return
370	}
371
372	sess, err := ins.App.Sessions.Create(r.Context(), args.Title)
373	if err != nil {
374		c.logError(r, "failed to create session", "error", err, "id", id)
375		jsonError(w, http.StatusInternalServerError, "failed to create session")
376		return
377	}
378
379	jsonEncode(w, sess)
380}
381
382func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http.Request) {
383	id := r.PathValue("id")
384	ins, ok := c.instances.Get(id)
385	if !ok {
386		c.logError(r, "instance not found", "id", id)
387		jsonError(w, http.StatusNotFound, "instance not found")
388		return
389	}
390
391	sessions, err := ins.App.Sessions.List(r.Context())
392	if err != nil {
393		c.logError(r, "failed to list sessions", "error", err)
394		jsonError(w, http.StatusInternalServerError, "failed to list sessions")
395		return
396	}
397
398	jsonEncode(w, sessions)
399}
400
401func (c *controllerV1) handlePostInstancePermissionsGrant(w http.ResponseWriter, r *http.Request) {
402	id := r.PathValue("id")
403	ins, ok := c.instances.Get(id)
404	if !ok {
405		c.logError(r, "instance not found", "id", id)
406		jsonError(w, http.StatusNotFound, "instance not found")
407		return
408	}
409
410	var req proto.PermissionGrant
411	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
412		c.logError(r, "failed to decode request", "error", err)
413		jsonError(w, http.StatusBadRequest, "failed to decode request")
414		return
415	}
416
417	switch req.Action {
418	case proto.PermissionAllow:
419		ins.App.Permissions.Grant(req.Permission)
420	case proto.PermissionAllowForSession:
421		ins.App.Permissions.GrantPersistent(req.Permission)
422	case proto.PermissionDeny:
423		ins.App.Permissions.Deny(req.Permission)
424	default:
425		c.logError(r, "invalid permission action", "action", req.Action)
426		jsonError(w, http.StatusBadRequest, "invalid permission action")
427		return
428	}
429}
430
431func (c *controllerV1) handlePostInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
432	id := r.PathValue("id")
433	ins, ok := c.instances.Get(id)
434	if !ok {
435		c.logError(r, "instance not found", "id", id)
436		jsonError(w, http.StatusNotFound, "instance not found")
437		return
438	}
439
440	var req proto.PermissionSkipRequest
441	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
442		c.logError(r, "failed to decode request", "error", err)
443		jsonError(w, http.StatusBadRequest, "failed to decode request")
444		return
445	}
446
447	ins.App.Permissions.SetSkipRequests(req.Skip)
448}
449
450func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
451	id := r.PathValue("id")
452	ins, ok := c.instances.Get(id)
453	if !ok {
454		c.logError(r, "instance not found", "id", id)
455		jsonError(w, http.StatusNotFound, "instance not found")
456		return
457	}
458
459	skip := ins.App.Permissions.SkipRequests()
460	jsonEncode(w, proto.PermissionSkipRequest{Skip: skip})
461}
462
463func (c *controllerV1) handleGetInstanceProviders(w http.ResponseWriter, r *http.Request) {
464	id := r.PathValue("id")
465	ins, ok := c.instances.Get(id)
466	if !ok {
467		c.logError(r, "instance not found", "id", id)
468		jsonError(w, http.StatusNotFound, "instance not found")
469		return
470	}
471
472	providers, _ := config.Providers(ins.cfg)
473	jsonEncode(w, providers)
474}
475
476func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
477	flusher := http.NewResponseController(w)
478	id := r.PathValue("id")
479	ins, ok := c.instances.Get(id)
480	if !ok {
481		c.logError(r, "instance not found", "id", id)
482		jsonError(w, http.StatusNotFound, "instance not found")
483		return
484	}
485
486	w.Header().Set("Content-Type", "text/event-stream")
487	w.Header().Set("Cache-Control", "no-cache")
488	w.Header().Set("Connection", "keep-alive")
489
490	for {
491		select {
492		case <-r.Context().Done():
493			c.logDebug(r, "stopping event stream")
494			return
495		case ev := <-ins.App.Events():
496			c.logDebug(r, "sending event", "event", fmt.Sprintf("%T %+v", ev, ev))
497			data, err := json.Marshal(ev)
498			if err != nil {
499				c.logError(r, "failed to marshal event", "error", err)
500				continue
501			}
502
503			fmt.Fprintf(w, "data: %s\n\n", data)
504			flusher.Flush()
505		}
506	}
507}
508
509func (c *controllerV1) handleGetInstanceConfig(w http.ResponseWriter, r *http.Request) {
510	id := r.PathValue("id")
511	ins, ok := c.instances.Get(id)
512	if !ok {
513		c.logError(r, "instance not found", "id", id)
514		jsonError(w, http.StatusNotFound, "instance not found")
515		return
516	}
517
518	jsonEncode(w, ins.cfg)
519}
520
521func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
522	id := r.PathValue("id")
523	ins, ok := c.instances.Get(id)
524	if ok {
525		ins.App.Shutdown()
526	}
527	c.instances.Del(id)
528}
529
530func (c *controllerV1) handleGetInstance(w http.ResponseWriter, r *http.Request) {
531	id := r.PathValue("id")
532	ins, ok := c.instances.Get(id)
533	if !ok {
534		c.logError(r, "instance not found", "id", id)
535		jsonError(w, http.StatusNotFound, "instance not found")
536		return
537	}
538
539	jsonEncode(w, proto.Instance{
540		ID:      ins.id,
541		Path:    ins.path,
542		YOLO:    ins.cfg.Permissions != nil && ins.cfg.Permissions.SkipRequests,
543		DataDir: ins.cfg.Options.DataDirectory,
544		Debug:   ins.cfg.Options.Debug,
545		Config:  ins.cfg,
546	})
547}
548
549func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Request) {
550	var args proto.Instance
551	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
552		c.logError(r, "failed to decode request", "error", err)
553		jsonError(w, http.StatusBadRequest, "failed to decode request")
554		return
555	}
556
557	if args.Path == "" {
558		c.logError(r, "path is required")
559		jsonError(w, http.StatusBadRequest, "path is required")
560		return
561	}
562
563	id := uuid.New().String()
564	cfg, err := config.Init(args.Path, args.DataDir, args.Debug, args.Env)
565	if err != nil {
566		c.logError(r, "failed to initialize config", "error", err)
567		jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err))
568		return
569	}
570
571	if cfg.Permissions == nil {
572		cfg.Permissions = &config.Permissions{}
573	}
574	cfg.Permissions.SkipRequests = args.YOLO
575
576	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
577		c.logError(r, "failed to create data directory", "error", err)
578		jsonError(w, http.StatusInternalServerError, "failed to create data directory")
579		return
580	}
581
582	// Connect to DB; this will also run migrations.
583	conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory)
584	if err != nil {
585		c.logError(r, "failed to connect to database", "error", err)
586		jsonError(w, http.StatusInternalServerError, "failed to connect to database")
587		return
588	}
589
590	appInstance, err := app.New(c.ctx, conn, cfg)
591	if err != nil {
592		slog.Error("failed to create app instance", "error", err)
593		jsonError(w, http.StatusInternalServerError, "failed to create app instance")
594		return
595	}
596
597	ins := &Instance{
598		App:  appInstance,
599		id:   id,
600		path: args.Path,
601		cfg:  cfg,
602		env:  args.Env,
603	}
604
605	c.instances.Set(id, ins)
606	jsonEncode(w, proto.Instance{
607		ID:      id,
608		Path:    args.Path,
609		DataDir: cfg.Options.DataDirectory,
610		Debug:   cfg.Options.Debug,
611		YOLO:    cfg.Permissions.SkipRequests,
612		Config:  cfg,
613		Env:     args.Env,
614	})
615}
616
// createDotCrushDir makes sure dir exists, creating it and any parents with
// mode 0700. It then writes a ".gitignore" containing "*" inside dir so the
// directory's contents are ignored by git.
//
// An existing .gitignore is left untouched. The file is only written when
// os.Stat reports that it does not exist.
func createDotCrushDir(dir string) error {
	if mkErr := os.MkdirAll(dir, 0o700); mkErr != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, mkErr)
	}

	ignoreFile := filepath.Join(dir, ".gitignore")
	_, statErr := os.Stat(ignoreFile)
	if !os.IsNotExist(statErr) {
		return nil
	}

	if writeErr := os.WriteFile(ignoreFile, []byte("*\n"), 0o644); writeErr != nil {
		return fmt.Errorf("failed to create .gitignore file: %q %w", ignoreFile, writeErr)
	}
	return nil
}
631
// jsonEncode writes v to w as a single JSON document followed by a newline,
// with Content-Type set to application/json. It does not set a status code.
//
// The Encode error is deliberately ignored: the response may already be
// partially written, so there is no reliable way to report the failure to
// the client from here.
func jsonEncode(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
636
637func jsonError(w http.ResponseWriter, status int, message string) {
638	w.Header().Set("Content-Type", "application/json")
639	w.WriteHeader(status)
640	_ = json.NewEncoder(w).Encode(proto.Error{Message: message})
641}