refactor: update client methods to return proto types

Ayman Bagabas created

Change summary

internal/client/proto.go               | 125 ++++++---------------------
internal/cmd/run.go                    |   6 
internal/workspace/client_workspace.go |  73 ++++++++++++++--
3 files changed, 98 insertions(+), 106 deletions(-)

Detailed changes

internal/client/proto.go 🔗

@@ -14,11 +14,9 @@ import (
 	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/history"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/pubsub"
-	"github.com/charmbracelet/crush/internal/session"
 	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 )
 
@@ -404,8 +402,8 @@ func (c *Client) InitiateAgentProcessing(ctx context.Context, id string) error {
 	return nil
 }
 
-// ListMessages retrieves all messages for a session.
-func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) {
+// ListMessages retrieves all messages for a session as proto types.
+func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages", id, sessionID), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get messages: %w", err)
@@ -414,15 +412,15 @@ func (c *Client) ListMessages(ctx context.Context, id string, sessionID string)
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode)
 	}
-	var protoMsgs []proto.Message
-	if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+	var msgs []proto.Message
+	if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
 		return nil, fmt.Errorf("failed to decode messages: %w", err)
 	}
-	return protoToMessages(protoMsgs), nil
+	return msgs, nil
 }
 
-// GetSession retrieves a specific session.
-func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*session.Session, error) {
+// GetSession retrieves a specific session as a proto type.
+func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*proto.Session, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get session: %w", err)
@@ -431,15 +429,15 @@ func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get session: status code %d", rsp.StatusCode)
 	}
-	var sess session.Session
+	var sess proto.Session
 	if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
 		return nil, fmt.Errorf("failed to decode session: %w", err)
 	}
 	return &sess, nil
 }
 
-// ListSessionHistoryFiles retrieves history files for a session.
-func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]history.File, error) {
+// ListSessionHistoryFiles retrieves history files for a session as proto types.
+func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]proto.File, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/history", id, sessionID), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get session history files: %w", err)
@@ -448,16 +446,16 @@ func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, session
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get session history files: status code %d", rsp.StatusCode)
 	}
-	var files []history.File
+	var files []proto.File
 	if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil {
 		return nil, fmt.Errorf("failed to decode session history files: %w", err)
 	}
 	return files, nil
 }
 
-// CreateSession creates a new session in a workspace.
-func (c *Client) CreateSession(ctx context.Context, id string, title string) (*session.Session, error) {
-	rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
+// CreateSession creates a new session in a workspace as a proto type.
+func (c *Client) CreateSession(ctx context.Context, id string, title string) (*proto.Session, error) {
+	rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(proto.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create session: %w", err)
 	}
@@ -465,15 +463,15 @@ func (c *Client) CreateSession(ctx context.Context, id string, title string) (*s
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to create session: status code %d", rsp.StatusCode)
 	}
-	var sess session.Session
+	var sess proto.Session
 	if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
 		return nil, fmt.Errorf("failed to decode session: %w", err)
 	}
 	return &sess, nil
 }
 
-// ListSessions lists all sessions in a workspace.
-func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session, error) {
+// ListSessions lists all sessions in a workspace as proto types.
+func (c *Client) ListSessions(ctx context.Context, id string) ([]proto.Session, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get sessions: %w", err)
@@ -482,7 +480,7 @@ func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get sessions: status code %d", rsp.StatusCode)
 	}
-	var sessions []session.Session
+	var sessions []proto.Session
 	if err := json.NewDecoder(rsp.Body).Decode(&sessions); err != nil {
 		return nil, fmt.Errorf("failed to decode sessions: %w", err)
 	}
@@ -556,8 +554,8 @@ func jsonBody(v any) *bytes.Buffer {
 	return b
 }
 
-// SaveSession updates a session in a workspace.
-func (c *Client) SaveSession(ctx context.Context, id string, sess session.Session) (*session.Session, error) {
+// SaveSession updates a session in a workspace, returning a proto type.
+func (c *Client) SaveSession(ctx context.Context, id string, sess proto.Session) (*proto.Session, error) {
 	rsp, err := c.put(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sess.ID), nil, jsonBody(sess), http.Header{"Content-Type": []string{"application/json"}})
 	if err != nil {
 		return nil, fmt.Errorf("failed to save session: %w", err)
@@ -566,7 +564,7 @@ func (c *Client) SaveSession(ctx context.Context, id string, sess session.Sessio
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to save session: status code %d", rsp.StatusCode)
 	}
-	var saved session.Session
+	var saved proto.Session
 	if err := json.NewDecoder(rsp.Body).Decode(&saved); err != nil {
 		return nil, fmt.Errorf("failed to decode session: %w", err)
 	}
@@ -586,8 +584,8 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string)
 	return nil
 }
 
-// ListUserMessages retrieves user-role messages for a session.
-func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) {
+// ListUserMessages retrieves user-role messages for a session as proto types.
+func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get user messages: %w", err)
@@ -596,15 +594,15 @@ func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID stri
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get user messages: status code %d", rsp.StatusCode)
 	}
-	var protoMsgs []proto.Message
-	if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+	var msgs []proto.Message
+	if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
 		return nil, fmt.Errorf("failed to decode user messages: %w", err)
 	}
-	return protoToMessages(protoMsgs), nil
+	return msgs, nil
 }
 
-// ListAllUserMessages retrieves all user-role messages across sessions.
-func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.Message, error) {
+// ListAllUserMessages retrieves all user-role messages across sessions as proto types.
+func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]proto.Message, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/messages/user", id), nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get all user messages: %w", err)
@@ -613,11 +611,11 @@ func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.
 	if rsp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("failed to get all user messages: status code %d", rsp.StatusCode)
 	}
-	var protoMsgs []proto.Message
-	if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+	var msgs []proto.Message
+	if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
 		return nil, fmt.Errorf("failed to decode all user messages: %w", err)
 	}
-	return protoToMessages(protoMsgs), nil
+	return msgs, nil
 }
 
 // CancelAgentSession cancels an ongoing agent operation for a session.
@@ -749,64 +747,3 @@ func (c *Client) LSPStopAll(ctx context.Context, id string) error {
 	}
 	return nil
 }
-
-func protoToMessages(msgs []proto.Message) []message.Message {
-	out := make([]message.Message, len(msgs))
-	for i, m := range msgs {
-		out[i] = protoToMessage(m)
-	}
-	return out
-}
-
-func protoToMessage(m proto.Message) message.Message {
-	msg := message.Message{
-		ID:        m.ID,
-		SessionID: m.SessionID,
-		Role:      message.MessageRole(m.Role),
-		Model:     m.Model,
-		Provider:  m.Provider,
-		CreatedAt: m.CreatedAt,
-		UpdatedAt: m.UpdatedAt,
-	}
-
-	for _, p := range m.Parts {
-		switch v := p.(type) {
-		case proto.TextContent:
-			msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text})
-		case proto.ReasoningContent:
-			msg.Parts = append(msg.Parts, message.ReasoningContent{
-				Thinking:   v.Thinking,
-				Signature:  v.Signature,
-				StartedAt:  v.StartedAt,
-				FinishedAt: v.FinishedAt,
-			})
-		case proto.ToolCall:
-			msg.Parts = append(msg.Parts, message.ToolCall{
-				ID:       v.ID,
-				Name:     v.Name,
-				Input:    v.Input,
-				Finished: v.Finished,
-			})
-		case proto.ToolResult:
-			msg.Parts = append(msg.Parts, message.ToolResult{
-				ToolCallID: v.ToolCallID,
-				Name:       v.Name,
-				Content:    v.Content,
-				IsError:    v.IsError,
-			})
-		case proto.Finish:
-			msg.Parts = append(msg.Parts, message.Finish{
-				Reason:  message.FinishReason(v.Reason),
-				Time:    v.Time,
-				Message: v.Message,
-				Details: v.Details,
-			})
-		case proto.ImageURLContent:
-			msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail})
-		case proto.BinaryContent:
-			msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data})
-		}
-	}
-
-	return msg
-}

internal/cmd/run.go 🔗

@@ -448,7 +448,7 @@ func validateModelMatches(matches []modelMatch, modelID, label string) (modelMat
 // If continueSessionID is set it fetches that session; if useLast is set it
 // returns the most recently updated top-level session; otherwise it creates a
 // new one.
-func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*session.Session, error) {
+func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
 	switch {
 	case continueSessionID != "":
 		sess, err := c.GetSession(ctx, wsID, continueSessionID)
@@ -480,7 +480,7 @@ func resolveSession(ctx context.Context, c *client.Client, wsID, continueSession
 
 // resolveSessionByID resolves a session ID that may be a full UUID or a hash
 // prefix returned by crush session list.
-func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*session.Session, error) {
+func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
 	if sess, err := c.GetSession(ctx, wsID, id); err == nil {
 		return sess, nil
 	}
@@ -490,7 +490,7 @@ func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string)
 		return nil, err
 	}
 
-	var matches []session.Session
+	var matches []proto.Session
 	for _, s := range sessions {
 		hash := session.HashID(s.ID)
 		if hash == id || strings.HasPrefix(hash, id) {

internal/workspace/client_workspace.go 🔗

@@ -83,7 +83,7 @@ func (w *ClientWorkspace) CreateSession(ctx context.Context, title string) (sess
 	if err != nil {
 		return session.Session{}, err
 	}
-	return *sess, nil
+	return protoToSession(*sess), nil
 }
 
 func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) {
@@ -91,19 +91,27 @@ func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (ses
 	if err != nil {
 		return session.Session{}, err
 	}
-	return *sess, nil
+	return protoToSession(*sess), nil
 }
 
 func (w *ClientWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) {
-	return w.client.ListSessions(ctx, w.workspaceID())
+	protoSessions, err := w.client.ListSessions(ctx, w.workspaceID())
+	if err != nil {
+		return nil, err
+	}
+	sessions := make([]session.Session, len(protoSessions))
+	for i, s := range protoSessions {
+		sessions[i] = protoToSession(s)
+	}
+	return sessions, nil
 }
 
 func (w *ClientWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) {
-	saved, err := w.client.SaveSession(ctx, w.workspaceID(), sess)
+	saved, err := w.client.SaveSession(ctx, w.workspaceID(), sessionToProto(sess))
 	if err != nil {
 		return session.Session{}, err
 	}
-	return *saved, nil
+	return protoToSession(*saved), nil
 }
 
 func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) error {
@@ -125,15 +133,27 @@ func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, str
 // -- Messages --
 
 func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
-	return w.client.ListMessages(ctx, w.workspaceID(), sessionID)
+	msgs, err := w.client.ListMessages(ctx, w.workspaceID(), sessionID)
+	if err != nil {
+		return nil, err
+	}
+	return protoToMessages(msgs), nil
 }
 
 func (w *ClientWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
-	return w.client.ListUserMessages(ctx, w.workspaceID(), sessionID)
+	msgs, err := w.client.ListUserMessages(ctx, w.workspaceID(), sessionID)
+	if err != nil {
+		return nil, err
+	}
+	return protoToMessages(msgs), nil
 }
 
 func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) {
-	return w.client.ListAllUserMessages(ctx, w.workspaceID())
+	msgs, err := w.client.ListAllUserMessages(ctx, w.workspaceID())
+	if err != nil {
+		return nil, err
+	}
+	return protoToMessages(msgs), nil
 }
 
 // -- Agent --
@@ -304,7 +324,11 @@ func (w *ClientWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionI
 // -- History --
 
 func (w *ClientWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) {
-	return w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID)
+	files, err := w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID)
+	if err != nil {
+		return nil, err
+	}
+	return protoToFiles(files), nil
 }
 
 // -- LSP --
@@ -688,3 +712,34 @@ func protoToMessage(m proto.Message) message.Message {
 
 	return msg
 }
+
+func protoToMessages(msgs []proto.Message) []message.Message {
+	out := make([]message.Message, len(msgs))
+	for i, m := range msgs {
+		out[i] = protoToMessage(m)
+	}
+	return out
+}
+
+func protoToFiles(files []proto.File) []history.File {
+	out := make([]history.File, len(files))
+	for i, f := range files {
+		out[i] = protoToFile(f)
+	}
+	return out
+}
+
+func sessionToProto(s session.Session) proto.Session {
+	return proto.Session{
+		ID:               s.ID,
+		ParentSessionID:  s.ParentSessionID,
+		Title:            s.Title,
+		SummaryMessageID: s.SummaryMessageID,
+		MessageCount:     s.MessageCount,
+		PromptTokens:     s.PromptTokens,
+		CompletionTokens: s.CompletionTokens,
+		Cost:             s.Cost,
+		CreatedAt:        s.CreatedAt,
+		UpdatedAt:        s.UpdatedAt,
+	}
+}