From 72251dd90681476e158a931069d5b078261c0323 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Tue, 17 Mar 2026 15:29:12 +0300 Subject: [PATCH] refactor: update client methods to return proto types --- internal/client/proto.go | 125 ++++++------------------- internal/cmd/run.go | 6 +- internal/workspace/client_workspace.go | 73 +++++++++++++-- 3 files changed, 98 insertions(+), 106 deletions(-) diff --git a/internal/client/proto.go b/internal/client/proto.go index 7a3b3a3c6b4b375c776ac6740f4d055858205941..13cd5fa12b9e3f29e1cceb61696d695366cbf133 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -14,11 +14,9 @@ import ( "time" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" ) @@ -404,8 +402,8 @@ func (c *Client) InitiateAgentProcessing(ctx context.Context, id string) error { return nil } -// ListMessages retrieves all messages for a session. -func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) { +// ListMessages retrieves all messages for a session as proto types. +func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages", id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) @@ -414,15 +412,15 @@ func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode) } - var protoMsgs []proto.Message - if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { + var msgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to decode messages: %w", err) } - return protoToMessages(protoMsgs), nil + return msgs, nil } -// GetSession retrieves a specific session. -func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*session.Session, error) { +// GetSession retrieves a specific session as a proto type. +func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*proto.Session, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get session: %w", err) @@ -431,15 +429,15 @@ func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (* if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get session: status code %d", rsp.StatusCode) } - var sess session.Session + var sess proto.Session if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil { return nil, fmt.Errorf("failed to decode session: %w", err) } return &sess, nil } -// ListSessionHistoryFiles retrieves history files for a session. -func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]history.File, error) { +// ListSessionHistoryFiles retrieves history files for a session as proto types. +func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]proto.File, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/history", id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get session history files: %w", err) @@ -448,16 +446,16 @@ func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, session if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get session history files: status code %d", rsp.StatusCode) } - var files []history.File + var files []proto.File if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil { return nil, fmt.Errorf("failed to decode session history files: %w", err) } return files, nil } -// CreateSession creates a new session in a workspace. -func (c *Client) CreateSession(ctx context.Context, id string, title string) (*session.Session, error) { - rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}}) +// CreateSession creates a new session in a workspace as a proto type. +func (c *Client) CreateSession(ctx context.Context, id string, title string) (*proto.Session, error) { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(proto.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } @@ -465,15 +463,15 @@ func (c *Client) CreateSession(ctx context.Context, id string, title string) (*s if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to create session: status code %d", rsp.StatusCode) } - var sess session.Session + var sess proto.Session if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil { return nil, fmt.Errorf("failed to decode session: %w", err) } return &sess, nil } -// ListSessions lists all sessions in a workspace. -func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session, error) { +// ListSessions lists all sessions in a workspace as proto types. +func (c *Client) ListSessions(ctx context.Context, id string) ([]proto.Session, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get sessions: %w", err) @@ -482,7 +480,7 @@ func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get sessions: status code %d", rsp.StatusCode) } - var sessions []session.Session + var sessions []proto.Session if err := json.NewDecoder(rsp.Body).Decode(&sessions); err != nil { return nil, fmt.Errorf("failed to decode sessions: %w", err) } @@ -556,8 +554,8 @@ func jsonBody(v any) *bytes.Buffer { return b } -// SaveSession updates a session in a workspace. -func (c *Client) SaveSession(ctx context.Context, id string, sess session.Session) (*session.Session, error) { +// SaveSession updates a session in a workspace, returning a proto type. +func (c *Client) SaveSession(ctx context.Context, id string, sess proto.Session) (*proto.Session, error) { rsp, err := c.put(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sess.ID), nil, jsonBody(sess), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return nil, fmt.Errorf("failed to save session: %w", err) @@ -566,7 +564,7 @@ func (c *Client) SaveSession(ctx context.Context, id string, sess session.Sessio if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to save session: status code %d", rsp.StatusCode) } - var saved session.Session + var saved proto.Session if err := json.NewDecoder(rsp.Body).Decode(&saved); err != nil { return nil, fmt.Errorf("failed to decode session: %w", err) } @@ -586,8 +584,8 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string) return nil } -// ListUserMessages retrieves user-role messages for a session. -func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) { +// ListUserMessages retrieves user-role messages for a session as proto types. +func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get user messages: %w", err) @@ -596,15 +594,15 @@ func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID stri if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get user messages: status code %d", rsp.StatusCode) } - var protoMsgs []proto.Message - if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { + var msgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to decode user messages: %w", err) } - return protoToMessages(protoMsgs), nil + return msgs, nil } -// ListAllUserMessages retrieves all user-role messages across sessions. -func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.Message, error) { +// ListAllUserMessages retrieves all user-role messages across sessions as proto types. +func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]proto.Message, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/messages/user", id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get all user messages: %w", err) @@ -613,11 +611,11 @@ func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message. if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get all user messages: status code %d", rsp.StatusCode) } - var protoMsgs []proto.Message - if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { + var msgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to decode all user messages: %w", err) } - return protoToMessages(protoMsgs), nil + return msgs, nil } // CancelAgentSession cancels an ongoing agent operation for a session. @@ -749,64 +747,3 @@ func (c *Client) LSPStopAll(ctx context.Context, id string) error { } return nil } - -func protoToMessages(msgs []proto.Message) []message.Message { - out := make([]message.Message, len(msgs)) - for i, m := range msgs { - out[i] = protoToMessage(m) - } - return out -} - -func protoToMessage(m proto.Message) message.Message { - msg := message.Message{ - ID: m.ID, - SessionID: m.SessionID, - Role: message.MessageRole(m.Role), - Model: m.Model, - Provider: m.Provider, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - } - - for _, p := range m.Parts { - switch v := p.(type) { - case proto.TextContent: - msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text}) - case proto.ReasoningContent: - msg.Parts = append(msg.Parts, message.ReasoningContent{ - Thinking: v.Thinking, - Signature: v.Signature, - StartedAt: v.StartedAt, - FinishedAt: v.FinishedAt, - }) - case proto.ToolCall: - msg.Parts = append(msg.Parts, message.ToolCall{ - ID: v.ID, - Name: v.Name, - Input: v.Input, - Finished: v.Finished, - }) - case proto.ToolResult: - msg.Parts = append(msg.Parts, message.ToolResult{ - ToolCallID: v.ToolCallID, - Name: v.Name, - Content: v.Content, - IsError: v.IsError, - }) - case proto.Finish: - msg.Parts = append(msg.Parts, message.Finish{ - Reason: message.FinishReason(v.Reason), - Time: v.Time, - Message: v.Message, - Details: v.Details, - }) - case proto.ImageURLContent: - msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail}) - case proto.BinaryContent: - msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data}) - } - } - - return msg -} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 6b2b18c7414da89d34c7b306525887cea07ac9ad..010baeb3ea5dc933477af1d2eadab101c28132fc 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -448,7 +448,7 @@ func validateModelMatches(matches []modelMatch, modelID, label string) (modelMat // If continueSessionID is set it fetches that session; if useLast is set it // returns the most recently updated top-level session; otherwise it creates a // new one. -func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*session.Session, error) { +func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) { switch { case continueSessionID != "": sess, err := c.GetSession(ctx, wsID, continueSessionID) @@ -480,7 +480,7 @@ func resolveSession(ctx context.Context, c *client.Client, wsID, continueSession // resolveSessionByID resolves a session ID that may be a full UUID or a hash // prefix returned by crush session list. -func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*session.Session, error) { +func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) { if sess, err := c.GetSession(ctx, wsID, id); err == nil { return sess, nil } @@ -490,7 +490,7 @@ func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) return nil, err } - var matches []session.Session + var matches []proto.Session for _, s := range sessions { hash := session.HashID(s.ID) if hash == id || strings.HasPrefix(hash, id) { diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index cc68d3d9bf97fe2e4a23d26d2be1489a0ddb4c99..a3090a740e948b36111c29d1c1f59918df2a0c0b 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -83,7 +83,7 @@ func (w *ClientWorkspace) CreateSession(ctx context.Context, title string) (sess if err != nil { return session.Session{}, err } - return *sess, nil + return protoToSession(*sess), nil } func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) { @@ -91,19 +91,27 @@ func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (ses if err != nil { return session.Session{}, err } - return *sess, nil + return protoToSession(*sess), nil } func (w *ClientWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) { - return w.client.ListSessions(ctx, w.workspaceID()) + protoSessions, err := w.client.ListSessions(ctx, w.workspaceID()) + if err != nil { + return nil, err + } + sessions := make([]session.Session, len(protoSessions)) + for i, s := range protoSessions { + sessions[i] = protoToSession(s) + } + return sessions, nil } func (w *ClientWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) { - saved, err := w.client.SaveSession(ctx, w.workspaceID(), sess) + saved, err := w.client.SaveSession(ctx, w.workspaceID(), sessionToProto(sess)) if err != nil { return session.Session{}, err } - return *saved, nil + return protoToSession(*saved), nil } func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) error { @@ -125,15 +133,27 @@ func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, str // -- Messages -- func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { - return w.client.ListMessages(ctx, w.workspaceID(), sessionID) + msgs, err := w.client.ListMessages(ctx, w.workspaceID(), sessionID) + if err != nil { + return nil, err + } + return protoToMessages(msgs), nil } func (w *ClientWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) { - return w.client.ListUserMessages(ctx, w.workspaceID(), sessionID) + msgs, err := w.client.ListUserMessages(ctx, w.workspaceID(), sessionID) + if err != nil { + return nil, err + } + return protoToMessages(msgs), nil } func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) { - return w.client.ListAllUserMessages(ctx, w.workspaceID()) + msgs, err := w.client.ListAllUserMessages(ctx, w.workspaceID()) + if err != nil { + return nil, err + } + return protoToMessages(msgs), nil } // -- Agent -- @@ -304,7 +324,11 @@ func (w *ClientWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionI // -- History -- func (w *ClientWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) { - return w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID) + files, err := w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID) + if err != nil { + return nil, err + } + return protoToFiles(files), nil } // -- LSP -- @@ -688,3 +712,34 @@ func protoToMessage(m proto.Message) message.Message { return msg } + +func protoToMessages(msgs []proto.Message) []message.Message { + out := make([]message.Message, len(msgs)) + for i, m := range msgs { + out[i] = protoToMessage(m) + } + return out +} + +func protoToFiles(files []proto.File) []history.File { + out := make([]history.File, len(files)) + for i, f := range files { + out[i] = protoToFile(f) + } + return out +} + +func sessionToProto(s session.Session) proto.Session { + return proto.Session{ + ID: s.ID, + ParentSessionID: s.ParentSessionID, + Title: s.Title, + SummaryMessageID: s.SummaryMessageID, + MessageCount: s.MessageCount, + PromptTokens: s.PromptTokens, + CompletionTokens: s.CompletionTokens, + Cost: s.Cost, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + } +}