feat(server): add rest api for client-server communication

Ayman Bagabas created

Change summary

internal/server/proto.go  | 338 ++++++++++++++++++++++++++++++++++++++--
internal/server/server.go |  26 ++
2 files changed, 337 insertions(+), 27 deletions(-)

Detailed changes

internal/server/proto.go 🔗

@@ -13,16 +13,22 @@ import (
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/db"
+	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/charmbracelet/crush/internal/session"
 )
 
-func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
-	jsonEncode(w, s.cfg)
+type controllerV1 struct {
+	*Server
 }
 
-func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) {
+func (c *controllerV1) handleGetConfig(w http.ResponseWriter, r *http.Request) {
+	jsonEncode(w, c.cfg)
+}
+
+func (c *controllerV1) handleGetInstances(w http.ResponseWriter, r *http.Request) {
 	instances := []proto.Instance{}
-	for _, ins := range s.instances.Seq2() {
+	for _, ins := range c.instances.Seq2() {
 		instances = append(instances, proto.Instance{
 			ID:   ins.ID(),
 			Path: ins.Path(),
@@ -32,12 +38,297 @@ func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) {
 	jsonEncode(w, instances)
 }
 
-func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
+func (c *controllerV1) handleGetInstanceLSPDiagnostics(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var lsp *lsp.Client
+	lspName := r.PathValue("lsp")
+	for name, client := range ins.LSPClients.Seq2() {
+		if name == lspName {
+			lsp = client
+			break
+		}
+	}
+
+	if lsp == nil {
+		c.logError(r, "LSP client not found", "id", id, "lsp", lspName)
+		jsonError(w, http.StatusNotFound, "LSP client not found")
+		return
+	}
+
+	diagnostics := lsp.GetDiagnostics()
+	jsonEncode(w, diagnostics)
+}
+
+func (c *controllerV1) handleGetInstanceLSPs(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	lspClients := ins.GetLSPStates()
+	jsonEncode(w, lspClients)
+}
+
+func (c *controllerV1) handleGetInstanceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	queued := ins.App.CoderAgent.QueuedPrompts(sid)
+	jsonEncode(w, queued)
+}
+
+func (c *controllerV1) handlePostInstanceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	ins.App.CoderAgent.ClearQueue(sid)
+}
+
+func (c *controllerV1) handleGetInstanceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	if err := ins.App.CoderAgent.Summarize(r.Context(), sid); err != nil {
+		c.logError(r, "failed to summarize session", "error", err, "id", id, "sid", sid)
+		jsonError(w, http.StatusInternalServerError, "failed to summarize session")
+		return
+	}
+}
+
+func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	if ins.App.CoderAgent != nil {
+		ins.App.CoderAgent.Cancel(sid)
+	}
+}
+
+func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var msg proto.AgentMessage
+	if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
+		c.logError(r, "failed to decode request", "error", err)
+		jsonError(w, http.StatusBadRequest, "failed to decode request")
+		return
+	}
+
+	if ins.App.CoderAgent == nil {
+		c.logError(r, "coder agent not initialized", "id", id)
+		jsonError(w, http.StatusBadRequest, "coder agent not initialized")
+		return
+	}
+
+	if _, err := ins.App.CoderAgent.Run(r.Context(), msg.SessionID, msg.Prompt, msg.Attachments...); err != nil {
+		c.logError(r, "failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID)
+		jsonError(w, http.StatusInternalServerError, "failed to enqueue message")
+		return
+	}
+}
+
+func (c *controllerV1) handleGetInstanceAgent(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var agentInfo proto.AgentInfo
+	if ins.App.CoderAgent != nil {
+		agentInfo = proto.AgentInfo{
+			Model:  ins.App.CoderAgent.Model(),
+			IsBusy: ins.App.CoderAgent.IsBusy(),
+		}
+	}
+	jsonEncode(w, agentInfo)
+}
+
+func (c *controllerV1) handlePostInstanceAgentUpdate(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	if err := ins.App.UpdateAgentModel(); err != nil {
+		c.logError(r, "failed to update agent model", "error", err)
+		jsonError(w, http.StatusInternalServerError, "failed to update agent model")
+		return
+	}
+}
+
+func (c *controllerV1) handlePostInstanceAgentInit(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	if err := ins.App.InitCoderAgent(); err != nil {
+		c.logError(r, "failed to initialize coder agent", "error", err)
+		jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent")
+		return
+	}
+}
+
+func (c *controllerV1) handleGetInstanceSessionHistory(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	historyItems, err := ins.App.History.ListBySession(r.Context(), sid)
+	if err != nil {
+		c.logError(r, "failed to list history", "error", err, "id", id, "sid", sid)
+		jsonError(w, http.StatusInternalServerError, "failed to list history")
+		return
+	}
+
+	jsonEncode(w, historyItems)
+}
+
+func (c *controllerV1) handleGetInstanceSessionMessages(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	messages, err := ins.App.Messages.List(r.Context(), sid)
+	if err != nil {
+		c.logError(r, "failed to list messages", "error", err, "id", id, "sid", sid)
+		jsonError(w, http.StatusInternalServerError, "failed to list messages")
+		return
+	}
+
+	jsonEncode(w, messages)
+}
+
+func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	session, err := ins.App.Sessions.Get(r.Context(), sid)
+	if err != nil {
+		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
+		jsonError(w, http.StatusInternalServerError, "failed to get session")
+		return
+	}
+
+	jsonEncode(w, session)
+}
+
+func (c *controllerV1) handlePostInstanceSessions(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var args session.Session
+	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
+		c.logError(r, "failed to decode request", "error", err)
+		jsonError(w, http.StatusBadRequest, "failed to decode request")
+		return
+	}
+
+	sess, err := ins.App.Sessions.Create(r.Context(), args.Title)
+	if err != nil {
+		c.logError(r, "failed to create session", "error", err, "id", id)
+		jsonError(w, http.StatusInternalServerError, "failed to create session")
+		return
+	}
+
+	jsonEncode(w, sess)
+}
+
+func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sessions, err := ins.App.Sessions.List(r.Context())
+	if err != nil {
+		c.logError(r, "failed to list sessions", "error", err)
+		jsonError(w, http.StatusInternalServerError, "failed to list sessions")
+		return
+	}
+
+	jsonEncode(w, sessions)
+}
+
+func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
 	flusher := http.NewResponseController(w)
 	id := r.PathValue("id")
-	ins, ok := s.instances.Get(id)
+	ins, ok := c.instances.Get(id)
 	if !ok {
-		s.logError(r, "instance not found", "id", id)
+		c.logError(r, "instance not found", "id", id)
 		jsonError(w, http.StatusNotFound, "instance not found")
 		return
 	}
@@ -53,7 +344,7 @@ func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request)
 		case ev := <-ins.App.Events():
 			data, err := json.Marshal(ev)
 			if err != nil {
-				s.logError(r, "failed to marshal event", "error", err)
+				c.logError(r, "failed to marshal event", "error", err)
 				continue
 			}
 
@@ -63,7 +354,7 @@ func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request)
 	}
 }
 
-func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
+func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
 	var ids []string
 	id := r.URL.Query().Get("id")
 	if id != "" {
@@ -73,7 +364,7 @@ func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
 	// Get IDs from body
 	var args []proto.Instance
 	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
-		s.logError(r, "failed to decode request", "error", err)
+		c.logError(r, "failed to decode request", "error", err)
 		jsonError(w, http.StatusBadRequest, "failed to decode request")
 		return
 	}
@@ -86,14 +377,14 @@ func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
 	}()...)
 
 	for _, id := range ids {
-		s.instances.Del(id)
+		c.instances.Del(id)
 	}
 }
 
-func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) {
+func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Request) {
 	var args proto.Instance
 	if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
-		s.logError(r, "failed to decode request", "error", err)
+		c.logError(r, "failed to decode request", "error", err)
 		jsonError(w, http.StatusBadRequest, "failed to decode request")
 		return
 	}
@@ -102,18 +393,21 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) {
 	hasher := sha256.New()
 	hasher.Write([]byte(filepath.Clean(args.Path)))
 	id := hex.EncodeToString(hasher.Sum(nil))
-	if existing, ok := s.instances.Get(id); ok {
+	if existing, ok := c.instances.Get(id); ok {
 		jsonEncode(w, proto.Instance{
 			ID:   existing.ID(),
 			Path: existing.Path(),
-			YOLO: existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests,
+			// TODO: Investigate if this makes sense.
+			YOLO:    existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests,
+			Debug:   existing.cfg.Options.Debug,
+			DataDir: existing.cfg.Options.DataDirectory,
 		})
 		return
 	}
 
-	cfg, err := config.Init(args.Path, s.cfg.Options.DataDirectory, s.cfg.Options.Debug)
+	cfg, err := config.Init(args.Path, args.DataDir, args.Debug)
 	if err != nil {
-		s.logError(r, "failed to initialize config", "error", err)
+		c.logError(r, "failed to initialize config", "error", err)
 		jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err))
 		return
 	}
@@ -123,16 +417,16 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) {
 	}
 	cfg.Permissions.SkipRequests = args.YOLO
 
-	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
-		s.logError(r, "failed to create data directory", "error", err)
+	if err := createDotCrushDir(args.DataDir); err != nil {
+		c.logError(r, "failed to create data directory", "error", err)
 		jsonError(w, http.StatusInternalServerError, "failed to create data directory")
 		return
 	}
 
 	// Connect to DB; this will also run migrations.
-	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+	conn, err := db.Connect(ctx, args.DataDir)
 	if err != nil {
-		s.logError(r, "failed to connect to database", "error", err)
+		c.logError(r, "failed to connect to database", "error", err)
 		jsonError(w, http.StatusInternalServerError, "failed to connect to database")
 		return
 	}
@@ -152,7 +446,7 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) {
 		cfg:   cfg,
 	}
 
-	s.instances.Set(id, ins)
+	c.instances.Set(id, ins)
 	jsonEncode(w, proto.Instance{
 		ID:   id,
 		Path: args.Path,

internal/server/server.go 🔗

@@ -110,12 +110,28 @@ func NewServer(cfg *config.Config, network, address string) *Server {
 	var p http.Protocols
 	p.SetHTTP1(true)
 	p.SetUnencryptedHTTP2(true)
+	c := &controllerV1{Server: s}
 	mux := http.NewServeMux()
-	mux.HandleFunc("GET /v1/config", s.handleGetConfig)
-	mux.HandleFunc("GET /v1/instances", s.handleGetInstances)
-	mux.HandleFunc("POST /v1/instances", s.handlePostInstances)
-	mux.HandleFunc("DELETE /v1/instances", s.handleDeleteInstances)
-	mux.HandleFunc("GET /v1/instances/{id}/events", s.handleGetInstanceEvents)
+	mux.HandleFunc("GET /v1/config", c.handleGetConfig)
+	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
+	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
+	mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances)
+	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
+	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
+	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
+	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
+	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory)
+	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
+	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
+	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
+	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
+	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
+	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
+	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
+	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
+	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
+	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
+	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize)
 	s.h = &http.Server{
 		Protocols: &p,
 		Handler:   s.loggingHandler(mux),