proto.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"net/http"
  9	"os"
 10	"path/filepath"
 11	"runtime"
 12
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/lsp"
 17	"github.com/charmbracelet/crush/internal/proto"
 18	"github.com/charmbracelet/crush/internal/session"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/google/uuid"
 21)
 22
 23type controllerV1 struct {
 24	*Server
 25}
 26
 27func (c *controllerV1) handleGetHealth(w http.ResponseWriter, r *http.Request) {
 28	w.WriteHeader(http.StatusOK)
 29}
 30
 31func (c *controllerV1) handleGetVersion(w http.ResponseWriter, r *http.Request) {
 32	jsonEncode(w, proto.VersionInfo{
 33		Version:   version.Version,
 34		Commit:    version.Commit,
 35		GoVersion: runtime.Version(),
 36		Platform:  fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH),
 37	})
 38}
 39
 40func (c *controllerV1) handlePostControl(w http.ResponseWriter, r *http.Request) {
 41	var req proto.ServerControl
 42	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 43		c.logError(r, "failed to decode request", "error", err)
 44		jsonError(w, http.StatusBadRequest, "failed to decode request")
 45		return
 46	}
 47
 48	switch req.Command {
 49	case "shutdown":
 50		go func() {
 51			slog.Info("shutting down server...")
 52			if err := c.Shutdown(context.Background()); err != nil {
 53				c.logError(r, "failed to shutdown server", "error", err)
 54			}
 55		}()
 56	default:
 57		c.logError(r, "unknown command", "command", req.Command)
 58		jsonError(w, http.StatusBadRequest, "unknown command")
 59		return
 60	}
 61}
 62
 63func (c *controllerV1) handleGetConfig(w http.ResponseWriter, r *http.Request) {
 64	jsonEncode(w, c.cfg)
 65}
 66
 67func (c *controllerV1) handleGetInstances(w http.ResponseWriter, r *http.Request) {
 68	instances := []proto.Instance{}
 69	for _, ins := range c.instances.Seq2() {
 70		// TODO: implement pagination?
 71		instances = append(instances, proto.Instance{
 72			ID:      ins.id,
 73			Path:    ins.path,
 74			YOLO:    ins.cfg.Permissions != nil && ins.cfg.Permissions.SkipRequests,
 75			DataDir: ins.cfg.Options.DataDirectory,
 76			Debug:   ins.cfg.Options.Debug,
 77			Config:  ins.cfg,
 78		})
 79	}
 80	jsonEncode(w, instances)
 81}
 82
 83func (c *controllerV1) handleGetInstanceLSPDiagnostics(w http.ResponseWriter, r *http.Request) {
 84	id := r.PathValue("id")
 85	ins, ok := c.instances.Get(id)
 86	if !ok {
 87		c.logError(r, "instance not found", "id", id)
 88		jsonError(w, http.StatusNotFound, "instance not found")
 89		return
 90	}
 91
 92	var lsp *lsp.Client
 93	lspName := r.PathValue("lsp")
 94	for name, client := range ins.LSPClients.Seq2() {
 95		if name == lspName {
 96			lsp = client
 97			break
 98		}
 99	}
100
101	if lsp == nil {
102		c.logError(r, "LSP client not found", "id", id, "lsp", lspName)
103		jsonError(w, http.StatusNotFound, "LSP client not found")
104		return
105	}
106
107	diagnostics := lsp.GetDiagnostics()
108	jsonEncode(w, diagnostics)
109}
110
// handleGetInstanceLSPs responds with the value of GetLSPStates for the
// instance named by the "id" path value, or 404 if no such instance exists.
func (c *controllerV1) handleGetInstanceLSPs(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	jsonEncode(w, instance.GetLSPStates())
}
123
// handleGetInstanceAgentSessionPromptQueued responds with the coder agent's
// QueuedPrompts result for session "sid" of instance "id". It returns 404 for
// an unknown instance, and 400 when the instance has no coder agent yet.
func (c *controllerV1) handleGetInstanceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	ins, ok := c.instances.Get(id)
	if !ok {
		c.logError(r, "instance not found", "id", id)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	// CoderAgent can be nil (the other agent handlers check for it); calling a
	// method through a nil agent would most likely panic.
	if ins.App.CoderAgent == nil {
		c.logError(r, "coder agent not initialized", "id", id)
		jsonError(w, http.StatusBadRequest, "coder agent not initialized")
		return
	}

	sid := r.PathValue("sid")
	queued := ins.App.CoderAgent.QueuedPrompts(sid)
	jsonEncode(w, queued)
}
137
// handlePostInstanceAgentSessionPromptClear clears the coder agent's prompt
// queue for session "sid" of instance "id". It returns 404 for an unknown
// instance, and is a no-op (empty 200) when the instance has no coder agent.
func (c *controllerV1) handlePostInstanceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	ins, ok := c.instances.Get(id)
	if !ok {
		c.logError(r, "instance not found", "id", id)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	sid := r.PathValue("sid")
	// CoderAgent can be nil before the agent is initialised; calling through
	// it would most likely panic. With no agent there is no queue to clear, so
	// mirror handlePostInstanceAgentSessionCancel and do nothing.
	if ins.App.CoderAgent != nil {
		ins.App.CoderAgent.ClearQueue(sid)
	}
}
150
// handleGetInstanceAgentSessionSummarize asks the coder agent of instance "id"
// to summarize session "sid", using the request context. It returns 404 for an
// unknown instance, 400 when the instance has no coder agent, and 500 when
// Summarize fails; success is an empty 200.
func (c *controllerV1) handleGetInstanceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	ins, ok := c.instances.Get(id)
	if !ok {
		c.logError(r, "instance not found", "id", id)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	// CoderAgent can be nil before the agent is initialised; calling through
	// it would most likely panic.
	if ins.App.CoderAgent == nil {
		c.logError(r, "coder agent not initialized", "id", id)
		jsonError(w, http.StatusBadRequest, "coder agent not initialized")
		return
	}

	sid := r.PathValue("sid")
	if err := ins.App.CoderAgent.Summarize(r.Context(), sid); err != nil {
		c.logError(r, "failed to summarize session", "error", err, "id", id, "sid", sid)
		jsonError(w, http.StatusInternalServerError, "failed to summarize session")
		return
	}
}
167
// handlePostInstanceAgentSessionCancel asks the coder agent of instance "id" to
// cancel work for session "sid". An unknown instance yields 404; an instance
// without a coder agent is a silent no-op.
func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	agent := instance.App.CoderAgent
	if agent == nil {
		return
	}
	agent.Cancel(r.PathValue("sid"))
}
182
183func (c *controllerV1) handleGetInstanceAgentSession(w http.ResponseWriter, r *http.Request) {
184	id := r.PathValue("id")
185	ins, ok := c.instances.Get(id)
186	if !ok {
187		c.logError(r, "instance not found", "id", id)
188		jsonError(w, http.StatusNotFound, "instance not found")
189		return
190	}
191
192	sid := r.PathValue("sid")
193	se, err := ins.App.Sessions.Get(r.Context(), sid)
194	if err != nil {
195		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
196		jsonError(w, http.StatusInternalServerError, "failed to get session")
197		return
198	}
199
200	var isSessionBusy bool
201	if ins.App.CoderAgent != nil {
202		isSessionBusy = ins.App.CoderAgent.IsSessionBusy(sid)
203	}
204
205	jsonEncode(w, proto.AgentSession{
206		Session: se,
207		IsBusy:  isSessionBusy,
208	})
209}
210
// handlePostInstanceAgent decodes a proto.AgentMessage from the JSON body and
// hands it to the coder agent of instance "id" via Run. Responses:
//   - 404 when the instance is unknown;
//   - 400 when the body fails to decode or the instance has no coder agent;
//   - 500 when Run fails;
//   - otherwise an empty 200.
func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	// NOTE(review): Accept is a request header; setting it on the response has
	// no standard meaning. Kept for compatibility.
	w.Header().Set("Accept", "application/json")

	var message proto.AgentMessage
	if err := json.NewDecoder(r.Body).Decode(&message); err != nil {
		c.logError(r, "failed to decode request", "error", err)
		jsonError(w, http.StatusBadRequest, "failed to decode request")
		return
	}

	agent := instance.App.CoderAgent
	if agent == nil {
		c.logError(r, "coder agent not initialized", "id", instanceID)
		jsonError(w, http.StatusBadRequest, "coder agent not initialized")
		return
	}

	// NOTE: This needs to be on the server's context because the agent runs
	// the request asynchronously.
	// TODO: Look into this one more and make it work synchronously.
	_, err := agent.Run(c.ctx, message.SessionID, message.Prompt, message.Attachments...)
	if err != nil {
		c.logError(r, "failed to enqueue message", "error", err, "id", instanceID, "sid", message.SessionID)
		jsonError(w, http.StatusInternalServerError, "failed to enqueue message")
		return
	}
}
244
245func (c *controllerV1) handleGetInstanceAgent(w http.ResponseWriter, r *http.Request) {
246	id := r.PathValue("id")
247	ins, ok := c.instances.Get(id)
248	if !ok {
249		c.logError(r, "instance not found", "id", id)
250		jsonError(w, http.StatusNotFound, "instance not found")
251		return
252	}
253
254	var agentInfo proto.AgentInfo
255	if ins.App.CoderAgent != nil {
256		agentInfo = proto.AgentInfo{
257			Model:  ins.App.CoderAgent.Model(),
258			IsBusy: ins.App.CoderAgent.IsBusy(),
259		}
260	}
261	jsonEncode(w, agentInfo)
262}
263
// handlePostInstanceAgentUpdate calls UpdateAgentModel on the app of instance
// "id". An unknown instance yields 404, a failed update yields 500, and
// success is an empty 200.
func (c *controllerV1) handlePostInstanceAgentUpdate(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	err := instance.App.UpdateAgentModel()
	if err == nil {
		return
	}
	c.logError(r, "failed to update agent model", "error", err)
	jsonError(w, http.StatusInternalServerError, "failed to update agent model")
}
279
// handlePostInstanceAgentInit calls InitCoderAgent on the app of instance "id".
// An unknown instance yields 404, an initialisation error yields 500, and
// success is an empty 200.
func (c *controllerV1) handlePostInstanceAgentInit(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	err := instance.App.InitCoderAgent()
	if err == nil {
		return
	}
	c.logError(r, "failed to initialize coder agent", "error", err)
	jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent")
}
295
// handleGetInstanceSessionHistory returns the history items recorded for
// session "sid" of instance "id" (History.ListBySession). An unknown instance
// yields 404; a listing error yields 500.
func (c *controllerV1) handleGetInstanceSessionHistory(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	sessionID := r.PathValue("sid")
	items, err := instance.App.History.ListBySession(r.Context(), sessionID)
	if err != nil {
		c.logError(r, "failed to list history", "error", err, "id", instanceID, "sid", sessionID)
		jsonError(w, http.StatusInternalServerError, "failed to list history")
		return
	}
	jsonEncode(w, items)
}
315
// handleGetInstanceSessionMessages returns the messages of session "sid" of
// instance "id" (Messages.List). An unknown instance yields 404; a listing
// error yields 500.
func (c *controllerV1) handleGetInstanceSessionMessages(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	sessionID := r.PathValue("sid")
	msgs, err := instance.App.Messages.List(r.Context(), sessionID)
	if err != nil {
		c.logError(r, "failed to list messages", "error", err, "id", instanceID, "sid", sessionID)
		jsonError(w, http.StatusInternalServerError, "failed to list messages")
		return
	}
	jsonEncode(w, msgs)
}
335
// handleGetInstanceSession returns session "sid" of instance "id" as JSON. An
// unknown instance yields 404; any error from the session service yields 500.
func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	ins, ok := c.instances.Get(id)
	if !ok {
		c.logError(r, "instance not found", "id", id)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	sid := r.PathValue("sid")
	// Named sess rather than session so the imported session package is not
	// shadowed.
	sess, err := ins.App.Sessions.Get(r.Context(), sid)
	if err != nil {
		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
		jsonError(w, http.StatusInternalServerError, "failed to get session")
		return
	}

	jsonEncode(w, sess)
}
355
// handlePostInstanceSessions creates a session in instance "id". The body is
// decoded as a session.Session, but only its Title is used. The created session
// is returned as JSON. Responses: 404 for an unknown instance, 400 for an
// undecodable body, 500 when creation fails.
func (c *controllerV1) handlePostInstanceSessions(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	var body session.Session
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		c.logError(r, "failed to decode request", "error", err)
		jsonError(w, http.StatusBadRequest, "failed to decode request")
		return
	}

	created, err := instance.App.Sessions.Create(r.Context(), body.Title)
	if err != nil {
		c.logError(r, "failed to create session", "error", err, "id", instanceID)
		jsonError(w, http.StatusInternalServerError, "failed to create session")
		return
	}
	jsonEncode(w, created)
}
381
// handleGetInstanceSessions lists all sessions of instance "id" as JSON. An
// unknown instance yields 404; a listing error yields 500.
func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	all, err := instance.App.Sessions.List(r.Context())
	if err != nil {
		c.logError(r, "failed to list sessions", "error", err)
		jsonError(w, http.StatusInternalServerError, "failed to list sessions")
		return
	}
	jsonEncode(w, all)
}
400
// handlePostInstancePermissionsGrant applies a permission decision, decoded
// from the body as a proto.PermissionGrant, to the permission service of
// instance "id". Each action maps to a method called with the request's
// Permission:
//   - PermissionAllow calls Grant;
//   - PermissionAllowForSession calls GrantPersistent;
//   - PermissionDeny calls Deny.
//
// Any other action is rejected with 400, as is an undecodable body.
func (c *controllerV1) handlePostInstancePermissionsGrant(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	var grant proto.PermissionGrant
	if err := json.NewDecoder(r.Body).Decode(&grant); err != nil {
		c.logError(r, "failed to decode request", "error", err)
		jsonError(w, http.StatusBadRequest, "failed to decode request")
		return
	}

	// NOTE(review): "allow for session" maps to GrantPersistent; confirm the
	// naming is intentional.
	switch {
	case grant.Action == proto.PermissionAllow:
		instance.App.Permissions.Grant(grant.Permission)
	case grant.Action == proto.PermissionAllowForSession:
		instance.App.Permissions.GrantPersistent(grant.Permission)
	case grant.Action == proto.PermissionDeny:
		instance.App.Permissions.Deny(grant.Permission)
	default:
		c.logError(r, "invalid permission action", "action", grant.Action)
		jsonError(w, http.StatusBadRequest, "invalid permission action")
	}
}
430
// handlePostInstancePermissionsSkip sets the skip-requests flag of instance
// "id"'s permission service from the Skip field of a proto.PermissionSkipRequest
// body. An unknown instance yields 404; an undecodable body yields 400.
func (c *controllerV1) handlePostInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	var body proto.PermissionSkipRequest
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		c.logError(r, "failed to decode request", "error", err)
		jsonError(w, http.StatusBadRequest, "failed to decode request")
		return
	}
	instance.App.Permissions.SetSkipRequests(body.Skip)
}
449
// handleGetInstancePermissionsSkip reports the skip-requests flag of instance
// "id"'s permission service as a proto.PermissionSkipRequest, or 404 if the
// instance is unknown.
func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	jsonEncode(w, proto.PermissionSkipRequest{Skip: instance.App.Permissions.SkipRequests()})
}
462
// handleGetInstanceEvents streams the events of instance "id" to the client as
// Server-Sent Events. Each event is JSON-encoded, written as a
// "data: <json>\n\n" frame, and flushed. Events that fail to marshal are logged
// and skipped. The stream ends when any of these happens:
//   - the request context is cancelled;
//   - the instance's event channel is closed;
//   - writing or flushing to the client fails.
func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	ins, ok := c.instances.Get(id)
	if !ok {
		c.logError(r, "instance not found", "id", id)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// Send the headers immediately so the client sees the stream open before
	// the first event arrives. A writer that cannot flush cannot stream SSE
	// at all, so give up early in that case.
	flusher := http.NewResponseController(w)
	if err := flusher.Flush(); err != nil {
		c.logError(r, "failed to flush event stream", "error", err)
		return
	}

	// Fetch the channel once instead of on every loop iteration.
	events := ins.App.Events()
	for {
		select {
		case <-r.Context().Done():
			c.logDebug(r, "stopping event stream")
			return
		case ev, open := <-events:
			// A closed channel would otherwise yield zero values forever,
			// spinning this loop and sending "data: null" frames.
			if !open {
				c.logDebug(r, "event channel closed, stopping event stream")
				return
			}
			c.logDebug(r, "sending event", "event", fmt.Sprintf("%T %+v", ev, ev))
			data, err := json.Marshal(ev)
			if err != nil {
				c.logError(r, "failed to marshal event", "error", err)
				continue
			}

			if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
				c.logDebug(r, "failed to write event, stopping event stream", "error", err)
				return
			}
			if err := flusher.Flush(); err != nil {
				c.logDebug(r, "failed to flush event, stopping event stream", "error", err)
				return
			}
		}
	}
}
495
// handleGetInstanceConfig serialises the configuration of instance "id" as
// JSON, or responds 404 if the instance is unknown.
func (c *controllerV1) handleGetInstanceConfig(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}
	jsonEncode(w, instance.cfg)
}
507
// handleDeleteInstances shuts down the app of instance "id", if registered, and
// then removes the id from the registry. Unknown ids are not an error: the
// response is an empty 200 either way.
func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	if instance, found := c.instances.Get(instanceID); found {
		instance.App.Shutdown()
	}
	c.instances.Del(instanceID)
}
516
// handleGetInstance returns a proto.Instance summary for instance "id": its
// path, data directory, debug flag, and full config. YOLO reports whether
// permission requests are being skipped. An unknown instance yields 404.
func (c *controllerV1) handleGetInstance(w http.ResponseWriter, r *http.Request) {
	instanceID := r.PathValue("id")
	instance, found := c.instances.Get(instanceID)
	if !found {
		c.logError(r, "instance not found", "id", instanceID)
		jsonError(w, http.StatusNotFound, "instance not found")
		return
	}

	cfg := instance.cfg
	yolo := cfg.Permissions != nil && cfg.Permissions.SkipRequests
	summary := proto.Instance{
		ID:      instance.id,
		Path:    instance.path,
		YOLO:    yolo,
		DataDir: cfg.Options.DataDirectory,
		Debug:   cfg.Options.Debug,
		Config:  cfg,
	}
	jsonEncode(w, summary)
}
535
// handlePostInstances creates and registers a new instance from a
// proto.Instance request body.
//
// Path is required. DataDir and Debug are passed to config.Init, and YOLO sets
// cfg.Permissions.SkipRequests. The handler then:
//  1. ensures the data directory exists (with a .gitignore);
//  2. connects to the database there (which also runs migrations);
//  3. builds the app;
//  4. stores the instance under a fresh UUID and echoes it back as JSON.
//
// Bad input or a config error yields 400; later failures yield 500.
func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Request) {
	var args proto.Instance
	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
		c.logError(r, "failed to decode request", "error", err)
		jsonError(w, http.StatusBadRequest, "failed to decode request")
		return
	}

	if args.Path == "" {
		c.logError(r, "path is required")
		jsonError(w, http.StatusBadRequest, "path is required")
		return
	}

	cfg, err := config.Init(args.Path, args.DataDir, args.Debug)
	if err != nil {
		c.logError(r, "failed to initialize config", "error", err)
		jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err))
		return
	}

	if cfg.Permissions == nil {
		cfg.Permissions = &config.Permissions{}
	}
	cfg.Permissions.SkipRequests = args.YOLO

	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
		c.logError(r, "failed to create data directory", "error", err)
		jsonError(w, http.StatusInternalServerError, "failed to create data directory")
		return
	}

	// Connect to DB; this will also run migrations.
	conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory)
	if err != nil {
		c.logError(r, "failed to connect to database", "error", err)
		jsonError(w, http.StatusInternalServerError, "failed to connect to database")
		return
	}

	appInstance, err := app.New(c.ctx, conn, cfg)
	if err != nil {
		c.logError(r, "failed to create app instance", "error", err)
		// No instance was registered, so nothing else will ever close this
		// connection; release it instead of leaking it.
		if cerr := conn.Close(); cerr != nil {
			c.logError(r, "failed to close database connection", "error", cerr)
		}
		jsonError(w, http.StatusInternalServerError, "failed to create app instance")
		return
	}

	id := uuid.New().String()
	ins := &Instance{
		App:   appInstance,
		State: InstanceStateCreated,
		id:    id,
		path:  args.Path,
		cfg:   cfg,
	}

	c.instances.Set(id, ins)
	jsonEncode(w, proto.Instance{
		ID:      id,
		Path:    args.Path,
		DataDir: cfg.Options.DataDirectory,
		Debug:   cfg.Options.Debug,
		YOLO:    cfg.Permissions.SkipRequests,
		Config:  cfg,
	})
}
602
// createDotCrushDir makes sure dir exists, creating any missing directories
// with mode 0700 (before umask). It then writes a ".gitignore" containing "*"
// into dir so git ignores everything inside it. An existing .gitignore is left
// untouched. Errors from creating the directory or writing the file are
// returned wrapped.
func createDotCrushDir(dir string) error {
	if mkErr := os.MkdirAll(dir, 0o700); mkErr != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, mkErr)
	}

	ignoreFile := filepath.Join(dir, ".gitignore")
	_, statErr := os.Stat(ignoreFile)
	if !os.IsNotExist(statErr) {
		// Either the file already exists, or Stat failed for another reason.
		// Writing the .gitignore is best-effort, so both cases are left alone.
		return nil
	}

	if writeErr := os.WriteFile(ignoreFile, []byte("*\n"), 0o644); writeErr != nil {
		return fmt.Errorf("failed to create .gitignore file: %q %w", ignoreFile, writeErr)
	}
	return nil
}
617
// jsonEncode writes v to w as JSON (newline-terminated) with a Content-Type of
// application/json. It never sets the status explicitly.
//
// The Encode error is deliberately ignored. json.Encoder marshals into an
// internal buffer and writes nothing if marshaling fails, so a marshal failure
// produces an empty body; a write failure means the client is already gone.
func jsonEncode(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
622
// jsonError sends an error response. It sets Content-Type to application/json,
// writes status, and encodes proto.Error{Message: message} as the body. The
// header is set before WriteHeader because later header changes would not be
// sent. The Encode error is ignored; the status line has already gone out.
func jsonError(w http.ResponseWriter, status int, message string) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	payload := proto.Error{Message: message}
	_ = json.NewEncoder(w).Encode(payload)
}