@@ -14,11 +14,9 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
)
@@ -404,8 +402,8 @@ func (c *Client) InitiateAgentProcessing(ctx context.Context, id string) error {
return nil
}
-// ListMessages retrieves all messages for a session.
-func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) {
+// ListMessages retrieves all messages for a session as proto types.
+func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages", id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get messages: %w", err)
@@ -414,15 +412,15 @@ func (c *Client) ListMessages(ctx context.Context, id string, sessionID string)
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode)
}
- var protoMsgs []proto.Message
- if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+ var msgs []proto.Message
+ if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("failed to decode messages: %w", err)
}
- return protoToMessages(protoMsgs), nil
+ return msgs, nil
}
-// GetSession retrieves a specific session.
-func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*session.Session, error) {
+// GetSession retrieves a specific session as a proto type.
+func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*proto.Session, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
@@ -431,15 +429,15 @@ func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get session: status code %d", rsp.StatusCode)
}
- var sess session.Session
+ var sess proto.Session
if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
return nil, fmt.Errorf("failed to decode session: %w", err)
}
return &sess, nil
}
-// ListSessionHistoryFiles retrieves history files for a session.
-func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]history.File, error) {
+// ListSessionHistoryFiles retrieves history files for a session as proto types.
+func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]proto.File, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/history", id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get session history files: %w", err)
@@ -448,16 +446,16 @@ func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, session
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get session history files: status code %d", rsp.StatusCode)
}
- var files []history.File
+ var files []proto.File
if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil {
return nil, fmt.Errorf("failed to decode session history files: %w", err)
}
return files, nil
}
-// CreateSession creates a new session in a workspace.
-func (c *Client) CreateSession(ctx context.Context, id string, title string) (*session.Session, error) {
- rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
+// CreateSession creates a new session in a workspace as a proto type.
+func (c *Client) CreateSession(ctx context.Context, id string, title string) (*proto.Session, error) {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(proto.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
@@ -465,15 +463,15 @@ func (c *Client) CreateSession(ctx context.Context, id string, title string) (*s
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to create session: status code %d", rsp.StatusCode)
}
- var sess session.Session
+ var sess proto.Session
if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
return nil, fmt.Errorf("failed to decode session: %w", err)
}
return &sess, nil
}
-// ListSessions lists all sessions in a workspace.
-func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session, error) {
+// ListSessions lists all sessions in a workspace as proto types.
+func (c *Client) ListSessions(ctx context.Context, id string) ([]proto.Session, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get sessions: %w", err)
@@ -482,7 +480,7 @@ func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get sessions: status code %d", rsp.StatusCode)
}
- var sessions []session.Session
+ var sessions []proto.Session
if err := json.NewDecoder(rsp.Body).Decode(&sessions); err != nil {
return nil, fmt.Errorf("failed to decode sessions: %w", err)
}
@@ -556,8 +554,8 @@ func jsonBody(v any) *bytes.Buffer {
return b
}
-// SaveSession updates a session in a workspace.
-func (c *Client) SaveSession(ctx context.Context, id string, sess session.Session) (*session.Session, error) {
+// SaveSession updates a session in a workspace, returning a proto type.
+func (c *Client) SaveSession(ctx context.Context, id string, sess proto.Session) (*proto.Session, error) {
rsp, err := c.put(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sess.ID), nil, jsonBody(sess), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return nil, fmt.Errorf("failed to save session: %w", err)
@@ -566,7 +564,7 @@ func (c *Client) SaveSession(ctx context.Context, id string, sess session.Sessio
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to save session: status code %d", rsp.StatusCode)
}
- var saved session.Session
+ var saved proto.Session
if err := json.NewDecoder(rsp.Body).Decode(&saved); err != nil {
return nil, fmt.Errorf("failed to decode session: %w", err)
}
@@ -586,8 +584,8 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string)
return nil
}
-// ListUserMessages retrieves user-role messages for a session.
-func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) {
+// ListUserMessages retrieves user-role messages for a session as proto types.
+func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get user messages: %w", err)
@@ -596,15 +594,15 @@ func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID stri
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get user messages: status code %d", rsp.StatusCode)
}
- var protoMsgs []proto.Message
- if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+ var msgs []proto.Message
+ if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("failed to decode user messages: %w", err)
}
- return protoToMessages(protoMsgs), nil
+ return msgs, nil
}
-// ListAllUserMessages retrieves all user-role messages across sessions.
-func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.Message, error) {
+// ListAllUserMessages retrieves all user-role messages across sessions as proto types.
+func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]proto.Message, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/messages/user", id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get all user messages: %w", err)
@@ -613,11 +611,11 @@ func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get all user messages: status code %d", rsp.StatusCode)
}
- var protoMsgs []proto.Message
- if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) {
+ var msgs []proto.Message
+ if err := json.NewDecoder(rsp.Body).Decode(&msgs); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("failed to decode all user messages: %w", err)
}
- return protoToMessages(protoMsgs), nil
+ return msgs, nil
}
// CancelAgentSession cancels an ongoing agent operation for a session.
@@ -749,64 +747,3 @@ func (c *Client) LSPStopAll(ctx context.Context, id string) error {
}
return nil
}
-
-func protoToMessages(msgs []proto.Message) []message.Message {
- out := make([]message.Message, len(msgs))
- for i, m := range msgs {
- out[i] = protoToMessage(m)
- }
- return out
-}
-
-func protoToMessage(m proto.Message) message.Message {
- msg := message.Message{
- ID: m.ID,
- SessionID: m.SessionID,
- Role: message.MessageRole(m.Role),
- Model: m.Model,
- Provider: m.Provider,
- CreatedAt: m.CreatedAt,
- UpdatedAt: m.UpdatedAt,
- }
-
- for _, p := range m.Parts {
- switch v := p.(type) {
- case proto.TextContent:
- msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text})
- case proto.ReasoningContent:
- msg.Parts = append(msg.Parts, message.ReasoningContent{
- Thinking: v.Thinking,
- Signature: v.Signature,
- StartedAt: v.StartedAt,
- FinishedAt: v.FinishedAt,
- })
- case proto.ToolCall:
- msg.Parts = append(msg.Parts, message.ToolCall{
- ID: v.ID,
- Name: v.Name,
- Input: v.Input,
- Finished: v.Finished,
- })
- case proto.ToolResult:
- msg.Parts = append(msg.Parts, message.ToolResult{
- ToolCallID: v.ToolCallID,
- Name: v.Name,
- Content: v.Content,
- IsError: v.IsError,
- })
- case proto.Finish:
- msg.Parts = append(msg.Parts, message.Finish{
- Reason: message.FinishReason(v.Reason),
- Time: v.Time,
- Message: v.Message,
- Details: v.Details,
- })
- case proto.ImageURLContent:
- msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail})
- case proto.BinaryContent:
- msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data})
- }
- }
-
- return msg
-}
@@ -448,7 +448,7 @@ func validateModelMatches(matches []modelMatch, modelID, label string) (modelMat
// If continueSessionID is set it fetches that session; if useLast is set it
// returns the most recently updated top-level session; otherwise it creates a
// new one.
-func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*session.Session, error) {
+func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
switch {
case continueSessionID != "":
sess, err := c.GetSession(ctx, wsID, continueSessionID)
@@ -480,7 +480,7 @@ func resolveSession(ctx context.Context, c *client.Client, wsID, continueSession
// resolveSessionByID resolves a session ID that may be a full UUID or a hash
// prefix returned by crush session list.
-func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*session.Session, error) {
+func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
if sess, err := c.GetSession(ctx, wsID, id); err == nil {
return sess, nil
}
@@ -490,7 +490,7 @@ func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string)
return nil, err
}
- var matches []session.Session
+ var matches []proto.Session
for _, s := range sessions {
hash := session.HashID(s.ID)
if hash == id || strings.HasPrefix(hash, id) {
@@ -83,7 +83,7 @@ func (w *ClientWorkspace) CreateSession(ctx context.Context, title string) (sess
if err != nil {
return session.Session{}, err
}
- return *sess, nil
+ return protoToSession(*sess), nil
}
func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) {
@@ -91,19 +91,27 @@ func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (ses
if err != nil {
return session.Session{}, err
}
- return *sess, nil
+ return protoToSession(*sess), nil
}
func (w *ClientWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) {
- return w.client.ListSessions(ctx, w.workspaceID())
+ protoSessions, err := w.client.ListSessions(ctx, w.workspaceID())
+ if err != nil {
+ return nil, err
+ }
+ sessions := make([]session.Session, len(protoSessions))
+ for i, s := range protoSessions {
+ sessions[i] = protoToSession(s)
+ }
+ return sessions, nil
}
func (w *ClientWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) {
- saved, err := w.client.SaveSession(ctx, w.workspaceID(), sess)
+ saved, err := w.client.SaveSession(ctx, w.workspaceID(), sessionToProto(sess))
if err != nil {
return session.Session{}, err
}
- return *saved, nil
+ return protoToSession(*saved), nil
}
func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) error {
@@ -125,15 +133,27 @@ func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, str
// -- Messages --
func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
- return w.client.ListMessages(ctx, w.workspaceID(), sessionID)
+ msgs, err := w.client.ListMessages(ctx, w.workspaceID(), sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return protoToMessages(msgs), nil
}
func (w *ClientWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
- return w.client.ListUserMessages(ctx, w.workspaceID(), sessionID)
+ msgs, err := w.client.ListUserMessages(ctx, w.workspaceID(), sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return protoToMessages(msgs), nil
}
func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) {
- return w.client.ListAllUserMessages(ctx, w.workspaceID())
+ msgs, err := w.client.ListAllUserMessages(ctx, w.workspaceID())
+ if err != nil {
+ return nil, err
+ }
+ return protoToMessages(msgs), nil
}
// -- Agent --
@@ -304,7 +324,11 @@ func (w *ClientWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionI
// -- History --
func (w *ClientWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) {
- return w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID)
+ files, err := w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return protoToFiles(files), nil
}
// -- LSP --
@@ -688,3 +712,34 @@ func protoToMessage(m proto.Message) message.Message {
return msg
}
+
+func protoToMessages(msgs []proto.Message) []message.Message {
+ out := make([]message.Message, len(msgs))
+ for i, m := range msgs {
+ out[i] = protoToMessage(m)
+ }
+ return out
+}
+
+func protoToFiles(files []proto.File) []history.File {
+ out := make([]history.File, len(files))
+ for i, f := range files {
+ out[i] = protoToFile(f)
+ }
+ return out
+}
+
+func sessionToProto(s session.Session) proto.Session {
+ return proto.Session{
+ ID: s.ID,
+ ParentSessionID: s.ParentSessionID,
+ Title: s.Title,
+ SummaryMessageID: s.SummaryMessageID,
+ MessageCount: s.MessageCount,
+ PromptTokens: s.PromptTokens,
+ CompletionTokens: s.CompletionTokens,
+ Cost: s.Cost,
+ CreatedAt: s.CreatedAt,
+ UpdatedAt: s.UpdatedAt,
+ }
+}