diff --git a/go.mod b/go.mod index 31320cd8d28607ca8af704ae803b721389709939..3d5d476f0dc422152f4c8ea995a49e167e401f45 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,6 @@ require ( github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec github.com/google/uuid v1.6.0 - github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1 github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 github.com/mark3labs/mcp-go v0.40.0 @@ -48,12 +47,6 @@ require ( mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5 ) -require ( - github.com/hashicorp/errwrap v1.0.0 // indirect - github.com/hashicorp/go-msgpack/v2 v2.1.3 // indirect - github.com/hashicorp/go-multierror v1.1.1 // indirect -) - require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect diff --git a/go.sum b/go.sum index 4ce162c241099724f3534b21d741f0744807ad24..3b12028e691864df7b4b0bf91abb779f55db6a33 100644 --- a/go.sum +++ b/go.sum @@ -162,14 +162,6 @@ github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-msgpack/v2 v2.1.3 h1:cB1w4Zrk0O3jQBTcFMKqYQWRFfsSQ/TYKNyUUVyCP2c= -github.com/hashicorp/go-msgpack/v2 v2.1.3/go.mod h1:SjlwKKFnwBXvxD/I1bEcfJIBbEJ+MCUn39TxymNR5ZU= -github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= -github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1 h1:Y1sd8ZCCUUlUetCk+3MCpOwdWd+WicHdk2zk2yUM0qw= -github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1/go.mod h1:wASEfI5dofjm9S9Jp3JM4pfoBZy8Z07JUE2wHNi0zuc= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/app/app.go b/internal/app/app.go index 2b3d81fb58acdeb2570a765c0a25ec53b65121da..249b4a392a496c127c8adf03436f83b1372de1d5 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -94,6 +94,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { return app, nil } +// Events returns the application's event channel. +func (app *App) Events() <-chan tea.Msg { + return app.events +} + // Config returns the application configuration. func (app *App) Config() *config.Config { return app.config diff --git a/internal/app/lsp_events.go b/internal/app/lsp_events.go index 08e54582b95d8db725bffc7ff8bd43d4a37528b1..8877a02a1a623af9339e660d5710881beefc75cf 100644 --- a/internal/app/lsp_events.go +++ b/internal/app/lsp_events.go @@ -18,13 +18,22 @@ const ( LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed" ) +func (e LSPEventType) MarshalText() ([]byte, error) { + return []byte(e), nil +} + +func (e *LSPEventType) UnmarshalText(data []byte) error { + *e = LSPEventType(data) + return nil +} + // LSPEvent represents an event in the LSP system type LSPEvent struct { - Type LSPEventType - Name string - State lsp.ServerState - Error error - DiagnosticCount int + Type LSPEventType `json:"type"` + Name string `json:"name"` + State lsp.ServerState `json:"state"` + Error error `json:"error,omitempty"` + DiagnosticCount int `json:"diagnostic_count,omitempty"` } // LSPClientInfo holds information about an LSP client's state diff --git a/internal/client/client.go b/internal/client/client.go index a7ed20a973a6204a056d611edcc21be3a8a10353..c60e6672947a13b788bc83351a5651f2b27d712c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -1,17 +1,19 @@ package client import ( - "net/rpc" + "context" + "encoding/json" + "net" + "net/http" + "time" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/server" - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" ) // Client represents an RPC client connected to a Crush server. type Client struct { - rpc *rpc.Client + h *http.Client } // DefaultClient creates a new [Client] connected to the default server address. @@ -22,20 +24,34 @@ func DefaultClient() (*Client, error) { // NewClient creates a new [Client] connected to the server at the given // network and address. func NewClient(network, address string) (*Client, error) { - rpc, err := msgpackrpc.Dial(network, address) - if err != nil { - return nil, err + var p http.Protocols + p.SetHTTP1(true) + p.SetUnencryptedHTTP2(true) + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Protocols = &p + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + return d.DialContext(ctx, network, address) + } + h := &http.Client{ + Transport: tr, } - return &Client{rpc: rpc}, nil + return &Client{h: h}, nil } // GetConfig retrieves the server's configuration via RPC. func (c *Client) GetConfig() (*config.Config, error) { var cfg config.Config - var args proto.Args - err := c.rpc.Call("ServerProto.GetConfig", &args, &cfg) + rsp, err := c.h.Get("http://localhost/v1/config") if err != nil { return nil, err } + defer rsp.Body.Close() + if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil { + return nil, err + } return &cfg, nil } diff --git a/internal/history/file.go b/internal/history/file.go index 7317f012fd83b31990bbc701261eed91794a52a5..f7c5a04b715785ac992dd0283d230daf3e5114cc 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -16,13 +16,13 @@ const ( ) type File struct { - ID string - SessionID string - Path string - Content string - Version int64 - CreatedAt int64 - UpdatedAt int64 + ID string `json:"id"` + SessionID string `json:"session_id"` + Path string `json:"path"` + Content string `json:"content"` + Version int64 `json:"version"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type Service interface { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ec48fc2956ac5ed3baa031ba2ed4b2f905b65ae0..e9419d99f63c5ba7b22096ea2a5e4992dd3ad5d1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -39,15 +39,24 @@ const ( AgentEventTypeSummarize AgentEventType = "summarize" ) +func (t AgentEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +func (t *AgentEventType) UnmarshalText(text []byte) error { + *t = AgentEventType(text) + return nil +} + type AgentEvent struct { - Type AgentEventType - Message message.Message - Error error + Type AgentEventType `json:"type"` + Message message.Message `json:"message"` + Error error `json:"error,omitempty"` // When summarizing - SessionID string - Progress string - Done bool + SessionID string `json:"session_id,omitempty"` + Progress string `json:"progress,omitempty"` + Done bool `json:"done,omitempty"` } type Service interface { diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 90011c43a0fce476c119c7a981ea6760c294b806..095ea0b9a8de9e3a2e7e0565e9c59a1cf6623774 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -34,6 +34,26 @@ const ( MCPStateError ) +func (s MCPState) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s *MCPState) UnmarshalText(data []byte) error { + switch string(data) { + case "disabled": + *s = MCPStateDisabled + case "starting": + *s = MCPStateStarting + case "connected": + *s = MCPStateConnected + case "error": + *s = MCPStateError + default: + return fmt.Errorf("unknown mcp state: %s", data) + } + return nil +} + func (s MCPState) String() string { switch s { case MCPStateDisabled: @@ -56,13 +76,22 @@ const ( MCPEventStateChanged MCPEventType = "state_changed" ) +func (t MCPEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +func (t *MCPEventType) UnmarshalText(data []byte) error { + *t = MCPEventType(data) + return nil +} + // MCPEvent represents an event in the MCP system type MCPEvent struct { - Type MCPEventType - Name string - State MCPState - Error error - ToolCount int + Type MCPEventType `json:"type"` + Name string `json:"name"` + State MCPState `json:"state"` + Error error `json:"error,omitempty"` + ToolCount int `json:"tool_count,omitempty"` } // MCPClientInfo holds information about an MCP client's state diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 70146d3ad181459db3d2193383373159f72b2022..db2042efd5dfc84738afeb6895a3e916ea95a1de 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -157,6 +157,37 @@ const ( StateDisabled ) +func (s ServerState) MarshalText() ([]byte, error) { + switch s { + case StateStarting: + return []byte("starting"), nil + case StateReady: + return []byte("ready"), nil + case StateError: + return []byte("error"), nil + case StateDisabled: + return []byte("disabled"), nil + default: + return nil, fmt.Errorf("unknown server state: %d", s) + } +} + +func (s *ServerState) UnmarshalText(data []byte) error { + switch strings.ToLower(string(data)) { + case "starting": + *s = StateStarting + case "ready": + *s = StateReady + case "error": + *s = StateError + case "disabled": + *s = StateDisabled + default: + return fmt.Errorf("unknown server state: %s", data) + } + return nil +} + // GetServerState returns the current state of the LSP server func (c *Client) GetServerState() ServerState { if val := c.serverState.Load(); val != nil { diff --git a/internal/message/content.go b/internal/message/content.go index b3f212187c86fb57667d95943fd15b8c6e3cccdb..b3021c48cc60247e016181e7d79134cdeca3c856 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -17,6 +17,15 @@ const ( Tool MessageRole = "tool" ) +func (r MessageRole) MarshalText() ([]byte, error) { + return []byte(r), nil +} + +func (r *MessageRole) UnmarshalText(data []byte) error { + *r = MessageRole(data) + return nil +} + type FinishReason string const ( @@ -31,6 +40,15 @@ const ( FinishReasonUnknown FinishReason = "unknown" ) +func (fr FinishReason) MarshalText() ([]byte, error) { + return []byte(fr), nil +} + +func (fr *FinishReason) UnmarshalText(data []byte) error { + *fr = FinishReason(data) + return nil +} + type ContentPart interface { isPart() } @@ -114,14 +132,14 @@ type Finish struct { func (Finish) isPart() {} type Message struct { - ID string - Role MessageRole - SessionID string - Parts []ContentPart - Model string - Provider string - CreatedAt int64 - UpdatedAt int64 + ID string `json:"id"` + Role MessageRole `json:"role"` + SessionID string `json:"session_id"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } func (m *Message) Content() TextContent { diff --git a/internal/message/message.go b/internal/message/message.go index 7cd823bc3129df5f807ec478d9d6c02364c6cfec..106aa8846cee88e9bb17804de72d3d7c6743e873 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -13,10 +13,10 @@ import ( ) type CreateMessageParams struct { - Role MessageRole - Parts []ContentPart - Model string - Provider string + Role MessageRole `json:"role"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider,omitempty"` } type Service interface { diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 9a9139334f20db9727911e221bd95ab8f68394b0..6dd6c29bd378fa81caa4b4abc849618ef4674036 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -1,4 +1,14 @@ package proto -// Args represents generic arguments that apply to all RPC routines. -type Args struct{} +// Instance represents a running app.App instance with its associated resources +// and state. +type Instance struct { + ID string `json:"id"` + Path string `json:"path"` + YOLO bool `json:"yolo"` +} + +// Error represents an error response. +type Error struct { + Message string `json:"message"` +} diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index 2fb0a741353bfc5054641815da9ad3292f49e6a3..af3df38bdc8cde1f7255c26ef887934412ba537b 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -18,11 +18,20 @@ type ( // Event represents an event in the lifecycle of a resource Event[T any] struct { - Type EventType - Payload T + Type EventType `json:"type"` + Payload T `json:"payload"` } Publisher[T any] interface { Publish(EventType, T) } ) + +func (t EventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +func (t *EventType) UnmarshalText(data []byte) error { + *t = EventType(data) + return nil +} diff --git a/internal/server/codec.go b/internal/server/codec.go deleted file mode 100644 index 9311a5767848ba25aa0e09d9b4de2768b05a1024..0000000000000000000000000000000000000000 --- a/internal/server/codec.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -import ( - "log/slog" - "net/rpc" - - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" -) - -// ServerCodec is a wrapper around msgpackrpc.ServerCodec that adds logging -// functionality. -type ServerCodec struct { - *msgpackrpc.MsgpackCodec - logger *slog.Logger -} - -var _ rpc.ServerCodec = (*ServerCodec)(nil) - -// ReadRequestHeader reads the request header and logs it. -func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error { - err := c.MsgpackCodec.ReadRequestHeader(r) - if c.logger != nil { - c.logger.Debug("rpc request", - slog.String("service_method", r.ServiceMethod), - slog.Int("seq", int(r.Seq)), - ) - } - return err -} - -// WriteResponse writes the response and logs it. -func (c *ServerCodec) WriteResponse(r *rpc.Response, body any) error { - err := c.MsgpackCodec.WriteResponse(r, body) - if c.logger != nil { - c.logger.Debug("rpc response", - slog.String("service_method", r.ServiceMethod), - slog.String("error", r.Error), - slog.Int("seq", int(r.Seq)), - ) - } - return err -} diff --git a/internal/server/logging.go b/internal/server/logging.go new file mode 100644 index 0000000000000000000000000000000000000000..736e3d57cfb6697a07cc61a03c4157a42140df54 --- /dev/null +++ b/internal/server/logging.go @@ -0,0 +1,51 @@ +package server + +import ( + "log/slog" + "net/http" + "time" +) + +func (s *Server) loggingHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s.logger == nil { + next.ServeHTTP(w, r) + return + } + + start := time.Now() + lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} + s.logger.Debug("HTTP request", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.String("remote_addr", r.RemoteAddr), + slog.String("user_agent", r.UserAgent()), + ) + + next.ServeHTTP(lrw, r) + duration := time.Since(start) + + s.logger.Debug("HTTP response", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", lrw.statusCode), + slog.Duration("duration", duration), + slog.String("remote_addr", r.RemoteAddr), + slog.String("user_agent", r.UserAgent()), + ) + }) +} + +type loggingResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func (lrw *loggingResponseWriter) WriteHeader(code int) { + lrw.statusCode = code + lrw.ResponseWriter.WriteHeader(code) +} + +func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter { + return lrw.ResponseWriter +} diff --git a/internal/server/proto.go b/internal/server/proto.go index 4d95366d3af5c59722bd53bddea3583e331a305c..b732f866dca6fe694257df060a349d9d1f5cf1a2 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -1,17 +1,187 @@ package server import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + "path/filepath" + + "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/proto" ) -// ServerProto defines the RPC methods exposed by the Crush server. -type ServerProto struct { - *Server +func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { + jsonEncode(w, s.cfg) +} + +func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) { + instances := []proto.Instance{} + for _, ins := range s.instances.Seq2() { + instances = append(instances, proto.Instance{ + ID: ins.ID(), + Path: ins.Path(), + YOLO: ins.cfg.Permissions != nil && ins.cfg.Permissions.SkipRequests, + }) + } + jsonEncode(w, instances) +} + +func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) { + flusher := http.NewResponseController(w) + id := r.PathValue("id") + ins, ok := s.instances.Get(id) + if !ok { + s.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + for { + select { + case <-r.Context().Done(): + return + case ev := <-ins.App.Events(): + data, err := json.Marshal(ev) + if err != nil { + s.logError(r, "failed to marshal event", "error", err) + continue + } + + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + } +} + +func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { + var ids []string + id := r.URL.Query().Get("id") + if id != "" { + ids = append(ids, id) + } + + // Get IDs from body + var args []proto.Instance + if err := json.NewDecoder(r.Body).Decode(&args); err != nil { + s.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + ids = append(ids, func() []string { + out := make([]string, len(args)) + for i, arg := range args { + out[i] = arg.ID + } + return out + }()...) + + for _, id := range ids { + s.instances.Del(id) + } +} + +func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) { + var args proto.Instance + if err := json.NewDecoder(r.Body).Decode(&args); err != nil { + s.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + ctx := r.Context() + hasher := sha256.New() + hasher.Write([]byte(filepath.Clean(args.Path))) + id := hex.EncodeToString(hasher.Sum(nil)) + if existing, ok := s.instances.Get(id); ok { + jsonEncode(w, proto.Instance{ + ID: existing.ID(), + Path: existing.Path(), + YOLO: existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests, + }) + return + } + + cfg, err := config.Init(args.Path, s.cfg.Options.DataDirectory, s.cfg.Options.Debug) + if err != nil { + s.logError(r, "failed to initialize config", "error", err) + jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err)) + return + } + + if cfg.Permissions == nil { + cfg.Permissions = &config.Permissions{} + } + cfg.Permissions.SkipRequests = args.YOLO + + if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { + s.logError(r, "failed to create data directory", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to create data directory") + return + } + + // Connect to DB; this will also run migrations. + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + s.logError(r, "failed to connect to database", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to connect to database") + return + } + + appInstance, err := app.New(ctx, conn, cfg) + if err != nil { + slog.Error("failed to create app instance", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to create app instance") + return + } + + ins := &Instance{ + App: appInstance, + State: InstanceStateCreated, + id: id, + path: args.Path, + cfg: cfg, + } + + s.instances.Set(id, ins) + jsonEncode(w, proto.Instance{ + ID: id, + Path: args.Path, + YOLO: cfg.Permissions.SkipRequests, + }) } -// GetConfig is an RPC routine that returns the server's configuration. -func (s *ServerProto) GetConfig(args *proto.Args, reply *config.Config) error { - *reply = *s.cfg +func createDotCrushDir(dir string) error { + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create data directory: %q %w", dir, err) + } + + gitIgnorePath := filepath.Join(dir, ".gitignore") + if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { + if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { + return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) + } + } + return nil } + +func jsonEncode(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(v) +} + +func jsonError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(proto.Error{Message: message}) +} diff --git a/internal/server/server.go b/internal/server/server.go index f4f75eda1a9dffd99a366aeaca02e01ae80df3a7..7a7aeb468abe2878aad64de54423c3bc88265223 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,20 +5,16 @@ import ( "fmt" "log/slog" "net" - "net/rpc" + "net/http" "os" "os/user" "path/filepath" "runtime" "strings" - "sync/atomic" - "time" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" - - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" ) // ErrServerClosed is returned when the server is closed. @@ -41,6 +37,8 @@ const ( type Instance struct { *app.App State InstanceState + ln net.Listener + cfg *config.Config id string path string } @@ -58,15 +56,15 @@ func (i *Instance) Path() string { // DefaultAddr returns the default address path for the Crush server based on // the operating system. func DefaultAddr() string { - sock := "crush.sock" + sockPath := "crush.sock" user, err := user.Current() if err == nil && user.Uid != "" { - sock = fmt.Sprintf("crush-%s.sock", user.Uid) + sockPath = fmt.Sprintf("crush-%s.sock", user.Uid) } if runtime.GOOS == "windows" { - return fmt.Sprintf(`\\.\pipe\%s`, sock) + return fmt.Sprintf(`\\.\pipe\%s`, sockPath) } - return filepath.Join(os.TempDir(), sock) + return filepath.Join(os.TempDir(), sockPath) } // Server represents a Crush server instance bound to a specific address. @@ -74,14 +72,13 @@ type Server struct { // Addr can be a TCP address, a Unix socket path, or a Windows named pipe. Addr string + h *http.Server + ln net.Listener + // instances is a map of running applications managed by the server. instances *csync.Map[string, *Instance] - // listeners is the network listener for the server. - listeners *csync.Map[*net.Listener, struct{}] cfg *config.Config logger *slog.Logger - - shutdown atomic.Bool } // SetLogger sets the logger for the server. @@ -109,44 +106,33 @@ func NewServer(cfg *config.Config, network, address string) *Server { s.Addr = address s.cfg = cfg s.instances = csync.NewMap[string, *Instance]() - rpc.Register(&ServerProto{s}) + + var p http.Protocols + p.SetHTTP1(true) + p.SetUnencryptedHTTP2(true) + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/config", s.handleGetConfig) + mux.HandleFunc("GET /v1/instances", s.handleGetInstances) + mux.HandleFunc("POST /v1/instances", s.handlePostInstances) + mux.HandleFunc("DELETE /v1/instances", s.handleDeleteInstances) + mux.HandleFunc("GET /v1/instances/{id}/events", s.handleGetInstanceEvents) + s.h = &http.Server{ + Protocols: &p, + Handler: s.loggingHandler(mux), + } return s } // Serve accepts incoming connections on the listener. func (s *Server) Serve(ln net.Listener) error { - if s.listeners == nil { - s.listeners = csync.NewMap[*net.Listener, struct{}]() - } - s.listeners.Set(&ln, struct{}{}) - - var tempDelay time.Duration // how long to sleep on accept failure - for { - conn, err := ln.Accept() - if err != nil { - if s.shuttingDown() { - return ErrServerClosed - } - if ne, ok := err.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - return fmt.Errorf("failed to accept connection: %w", err) - } - go s.handleConn(conn) - } + return s.h.Serve(ln) } // ListenAndServe starts the server and begins accepting connections. func (s *Server) ListenAndServe() error { + if s.ln != nil { + return fmt.Errorf("server already started") + } ln, err := listen("unix", s.Addr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", s.Addr, err) @@ -154,63 +140,33 @@ func (s *Server) ListenAndServe() error { return s.Serve(ln) } +func (s *Server) closeListener() { + if s.ln != nil { + s.ln.Close() + s.ln = nil + } +} + // Close force close all listeners and connections. func (s *Server) Close() error { - s.shutdown.Store(true) - var firstErr error - for k := range s.listeners.Seq2() { - if err := (*k).Close(); err != nil && firstErr == nil { - firstErr = err - } - s.listeners.Del(k) - } - return firstErr + defer func() { s.closeListener() }() + return s.h.Close() } // Shutdown gracefully shuts down the server without interrupting active // connections. It stops accepting new connections and waits for existing // connections to finish. func (s *Server) Shutdown(ctx context.Context) error { - // TODO: implement graceful shutdown - return s.Close() -} - -func (s *Server) handleConn(conn net.Conn) { - s.info("accepted connection", "remote_addr", conn.LocalAddr()) - codec := &ServerCodec{ - MsgpackCodec: msgpackrpc.NewCodec(true, true, conn), - logger: s.logger.With( - slog.String("remote_addr", conn.RemoteAddr().String()), - slog.String("local_addr", conn.LocalAddr().String()), - ), - } - rpc.ServeCodec(codec) -} - -func (s *Server) shuttingDown() bool { - return s.shutdown.Load() -} - -func (s *Server) info(msg string, args ...any) { - if s.logger != nil { - s.logger.Info(msg, args...) - } -} - -func (s *Server) debug(msg string, args ...any) { - if s.logger != nil { - s.logger.Debug(msg, args...) - } -} - -func (s *Server) error(msg string, args ...any) { - if s.logger != nil { - s.logger.Error(msg, args...) - } + defer func() { s.closeListener() }() + return s.h.Shutdown(ctx) } -func (s *Server) warn(msg string, args ...any) { +func (s *Server) logError(r *http.Request, msg string, args ...any) { if s.logger != nil { - s.logger.Warn(msg, args...) + s.logger.With( + slog.String("method", r.Method), + slog.String("url", r.URL.String()), + slog.String("remote_addr", r.RemoteAddr), + ).Error(msg, args...) } } diff --git a/internal/session/session.go b/internal/session/session.go index d988dac3414fa7dd00d13b375e1309f8d6c515dd..7b57a37c3bed304ef11211f511e6d993bc497ef4 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -10,16 +10,16 @@ import ( ) type Session struct { - ID string - ParentSessionID string - Title string - MessageCount int64 - PromptTokens int64 - CompletionTokens int64 - SummaryMessageID string - Cost float64 - CreatedAt int64 - UpdatedAt int64 + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID string `json:"summary_message_id"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type Service interface {