From 797788a08658635315f0c7c9442257f9c4841d85 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Tue, 22 Jul 2025 10:13:43 -0400 Subject: [PATCH] Add verbs to list and export sessions --- internal/cmd/sessions.go | 444 +++++++++++++++++++++++++++++++++++ internal/db/db.go | 22 +- internal/db/files.sql.go | 2 +- internal/db/messages.sql.go | 2 +- internal/db/models.go | 2 +- internal/db/querier.go | 5 +- internal/db/sessions.sql.go | 83 ++++++- internal/db/sql/sessions.sql | 11 + internal/session/session.go | 29 +++ 9 files changed, 594 insertions(+), 6 deletions(-) create mode 100644 internal/cmd/sessions.go diff --git a/internal/cmd/sessions.go b/internal/cmd/sessions.go new file mode 100644 index 0000000000000000000000000000000000000000..2937e70c5f5db9fae6cfc1d77574a81925049b1b --- /dev/null +++ b/internal/cmd/sessions.go @@ -0,0 +1,444 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/session" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +// SessionWithChildren represents a session with its nested children +type SessionWithChildren struct { + session.Session + Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"` +} + +var sessionsCmd = &cobra.Command{ + Use: "sessions", + Short: "Manage sessions", + Long: `List and export sessions and their nested subsessions`, +} + +var listCmd = &cobra.Command{ + Use: "list", + Short: "List sessions", + Long: `List all sessions in a hierarchical format`, + RunE: func(cmd *cobra.Command, args []string) error { + format, _ := cmd.Flags().GetString("format") + return runSessionsList(cmd.Context(), format) + }, +} + +var exportCmd = &cobra.Command{ + Use: "export", + Short: "Export sessions", + Long: `Export all sessions and their nested subsessions to different formats`, + RunE: func(cmd *cobra.Command, args []string) error { + format, _ := cmd.Flags().GetString("format") + return runSessionsExport(cmd.Context(), format) + }, +} + +var exportConversationCmd = &cobra.Command{ + Use: "export-conversation ", + Short: "Export a single conversation", + Long: `Export a single session with all its messages as markdown for sharing`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + sessionID := args[0] + format, _ := cmd.Flags().GetString("format") + return runExportConversation(cmd.Context(), sessionID, format) + }, +} + +func init() { + rootCmd.AddCommand(sessionsCmd) + sessionsCmd.AddCommand(listCmd) + sessionsCmd.AddCommand(exportCmd) + sessionsCmd.AddCommand(exportConversationCmd) + + listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)") + exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)") + exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)") +} + +func runSessionsList(ctx context.Context, format string) error { + sessionService, err := createSessionService(ctx) + if err != nil { + return err + } + + sessions, err := buildSessionTree(ctx, sessionService) + if err != nil { + return err + } + + return formatOutput(sessions, format, false) +} + +func runSessionsExport(ctx context.Context, format string) error { + sessionService, err := createSessionService(ctx) + if err != nil { + return err + } + + sessions, err := buildSessionTree(ctx, sessionService) + if err != nil { + return err + } + + return formatOutput(sessions, format, true) +} + +func runExportConversation(ctx context.Context, sessionID, format string) error { + sessionService, messageService, err := createServices(ctx) + if err != nil { + return err + } + + // Get the session + sess, err := sessionService.Get(ctx, sessionID) + if err != nil { + return fmt.Errorf("failed to get session %s: %w", sessionID, err) + } + + // Get all messages for the session + messages, err := messageService.List(ctx, sessionID) + if err != nil { + return fmt.Errorf("failed to get messages for session %s: %w", sessionID, err) + } + + return formatConversation(sess, messages, format) +} + +func createSessionService(ctx context.Context) (session.Service, error) { + cwd, err := getCwd() + if err != nil { + return nil, err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return nil, err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return nil, err + } + + queries := db.New(conn) + return session.NewService(queries), nil +} + +func createServices(ctx context.Context) (session.Service, message.Service, error) { + cwd, err := getCwd() + if err != nil { + return nil, nil, err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return nil, nil, err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return nil, nil, err + } + + queries := db.New(conn) + sessionService := session.NewService(queries) + messageService := message.NewService(queries) + return sessionService, messageService, nil +} + +func getCwd() (string, error) { + // This could be enhanced to use the same logic as root.go + cwd, err := getCwdFromFlags() + if err != nil { + return "", err + } + return cwd, nil +} + +func getCwdFromFlags() (string, error) { + return os.Getwd() +} + +func buildSessionTree(ctx context.Context, sessionService session.Service) ([]SessionWithChildren, error) { + // Get all top-level sessions (no parent) + topLevelSessions, err := sessionService.List(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list sessions: %w", err) + } + + var result []SessionWithChildren + for _, sess := range topLevelSessions { + sessionWithChildren, err := buildSessionWithChildren(ctx, sessionService, sess) + if err != nil { + return nil, err + } + result = append(result, sessionWithChildren) + } + + return result, nil +} + +func buildSessionWithChildren(ctx context.Context, sessionService session.Service, sess session.Session) (SessionWithChildren, error) { + children, err := sessionService.ListChildren(ctx, sess.ID) + if err != nil { + return SessionWithChildren{}, fmt.Errorf("failed to list children for session %s: %w", sess.ID, err) + } + + var childrenWithChildren []SessionWithChildren + for _, child := range children { + childWithChildren, err := buildSessionWithChildren(ctx, sessionService, child) + if err != nil { + return SessionWithChildren{}, err + } + childrenWithChildren = append(childrenWithChildren, childWithChildren) + } + + return SessionWithChildren{ + Session: sess, + Children: childrenWithChildren, + }, nil +} + +func formatOutput(sessions []SessionWithChildren, format string, includeMetadata bool) error { + switch strings.ToLower(format) { + case "json": + return formatJSON(sessions) + case "yaml": + return formatYAML(sessions) + case "markdown", "md": + return formatMarkdown(sessions, includeMetadata) + case "text": + return formatText(sessions) + default: + return fmt.Errorf("unsupported format: %s", format) + } +} + +func formatJSON(sessions []SessionWithChildren) error { + data, err := json.MarshalIndent(sessions, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func formatYAML(sessions []SessionWithChildren) error { + data, err := yaml.Marshal(sessions) + if err != nil { + return fmt.Errorf("failed to marshal YAML: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func formatMarkdown(sessions []SessionWithChildren, includeMetadata bool) error { + fmt.Println("# Sessions") + fmt.Println() + + if len(sessions) == 0 { + fmt.Println("No sessions found.") + return nil + } + + for _, sess := range sessions { + printSessionMarkdown(sess, 0, includeMetadata) + } + + return nil +} + +func formatText(sessions []SessionWithChildren) error { + if len(sessions) == 0 { + fmt.Println("No sessions found.") + return nil + } + + for _, sess := range sessions { + printSessionText(sess, 0) + } + + return nil +} + +func printSessionMarkdown(sess SessionWithChildren, level int, includeMetadata bool) { + indent := strings.Repeat("#", level+2) + fmt.Printf("%s %s\n", indent, sess.Title) + fmt.Println() + + if includeMetadata { + fmt.Printf("- **ID**: %s\n", sess.ID) + if sess.ParentSessionID != "" { + fmt.Printf("- **Parent**: %s\n", sess.ParentSessionID) + } + fmt.Printf("- **Messages**: %d\n", sess.MessageCount) + fmt.Printf("- **Tokens**: %d prompt, %d completion\n", sess.PromptTokens, sess.CompletionTokens) + fmt.Printf("- **Cost**: $%.4f\n", sess.Cost) + fmt.Printf("- **Created**: %s\n", formatTimestamp(sess.CreatedAt)) + fmt.Printf("- **Updated**: %s\n", formatTimestamp(sess.UpdatedAt)) + fmt.Println() + } + + for _, child := range sess.Children { + printSessionMarkdown(child, level+1, includeMetadata) + } +} + +func printSessionText(sess SessionWithChildren, level int) { + indent := strings.Repeat(" ", level) + fmt.Printf("%s• %s (ID: %s, Messages: %d, Cost: $%.4f)\n", + indent, sess.Title, sess.ID, sess.MessageCount, sess.Cost) + + for _, child := range sess.Children { + printSessionText(child, level+1) + } +} + +func formatTimestamp(timestamp int64) string { + // Assuming timestamp is Unix seconds + return time.Unix(timestamp, 0).Format("2006-01-02 15:04:05") +} + +func formatConversation(sess session.Session, messages []message.Message, format string) error { + switch strings.ToLower(format) { + case "markdown", "md": + return formatConversationMarkdown(sess, messages) + case "json": + return formatConversationJSON(sess, messages) + case "yaml": + return formatConversationYAML(sess, messages) + default: + return fmt.Errorf("unsupported format: %s", format) + } +} + +func formatConversationMarkdown(sess session.Session, messages []message.Message) error { + fmt.Printf("# %s\n\n", sess.Title) + + // Session metadata + fmt.Printf("**Session ID:** %s \n", sess.ID) + fmt.Printf("**Created:** %s \n", formatTimestamp(sess.CreatedAt)) + fmt.Printf("**Messages:** %d \n", sess.MessageCount) + fmt.Printf("**Tokens:** %d prompt, %d completion \n", sess.PromptTokens, sess.CompletionTokens) + if sess.Cost > 0 { + fmt.Printf("**Cost:** $%.4f \n", sess.Cost) + } + fmt.Println() + fmt.Println("---") + fmt.Println() + + for i, msg := range messages { + formatMessageMarkdown(msg, i+1) + } + + return nil +} + +func formatMessageMarkdown(msg message.Message, index int) { + // Role header + switch msg.Role { + case message.User: + fmt.Printf("## šŸ‘¤ User\n\n") + case message.Assistant: + fmt.Printf("## šŸ¤– Assistant") + if msg.Model != "" { + fmt.Printf(" (%s)", msg.Model) + } + fmt.Printf("\n\n") + case message.System: + fmt.Printf("## āš™ļø System\n\n") + case message.Tool: + fmt.Printf("## šŸ”§ Tool\n\n") + } + + // Process each part + for _, part := range msg.Parts { + switch p := part.(type) { + case message.TextContent: + fmt.Printf("%s\n\n", p.Text) + case message.ReasoningContent: + if p.Thinking != "" { + fmt.Printf("### 🧠 Reasoning\n\n") + fmt.Printf("```\n%s\n```\n\n", p.Thinking) + } + case message.ToolCall: + fmt.Printf("### šŸ”§ Tool Call: %s\n\n", p.Name) + fmt.Printf("**ID:** %s \n", p.ID) + if p.Input != "" { + fmt.Printf("**Input:**\n```json\n%s\n```\n\n", p.Input) + } + case message.ToolResult: + fmt.Printf("### šŸ“ Tool Result: %s\n\n", p.Name) + if p.IsError { + fmt.Printf("**āŒ Error:**\n```\n%s\n```\n\n", p.Content) + } else { + fmt.Printf("**āœ… Result:**\n```\n%s\n```\n\n", p.Content) + } + case message.ImageURLContent: + fmt.Printf("![Image](%s)\n\n", p.URL) + case message.BinaryContent: + fmt.Printf("**File:** %s (%s)\n\n", p.Path, p.MIMEType) + case message.Finish: + if p.Reason != message.FinishReasonEndTurn { + fmt.Printf("**Finish Reason:** %s\n", p.Reason) + if p.Message != "" { + fmt.Printf("**Message:** %s\n", p.Message) + } + fmt.Println() + } + } + } + + fmt.Println("---") + fmt.Println() +} + +func formatConversationJSON(sess session.Session, messages []message.Message) error { + data := struct { + Session session.Session `json:"session"` + Messages []message.Message `json:"messages"` + }{ + Session: sess, + Messages: messages, + } + + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + fmt.Println(string(jsonData)) + return nil +} + +func formatConversationYAML(sess session.Session, messages []message.Message) error { + data := struct { + Session session.Session `yaml:"session"` + Messages []message.Message `yaml:"messages"` + }{ + Session: sess, + Messages: messages, + } + + yamlData, err := yaml.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal YAML: %w", err) + } + fmt.Println(string(yamlData)) + return nil +} diff --git a/internal/db/db.go b/internal/db/db.go index 62ebe0134c683f2a3f69d26ea3f826c9bbf02d14..3c83e53a8bae375f59ef32775e2511864eec3d7e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 package db @@ -60,6 +60,12 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil { return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err) } + if q.listAllSessionsStmt, err = db.PrepareContext(ctx, listAllSessions); err != nil { + return nil, fmt.Errorf("error preparing query ListAllSessions: %w", err) + } + if q.listChildSessionsStmt, err = db.PrepareContext(ctx, listChildSessions); err != nil { + return nil, fmt.Errorf("error preparing query ListChildSessions: %w", err) + } if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil { return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err) } @@ -149,6 +155,16 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr) } } + if q.listAllSessionsStmt != nil { + if cerr := q.listAllSessionsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listAllSessionsStmt: %w", cerr) + } + } + if q.listChildSessionsStmt != nil { + if cerr := q.listChildSessionsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listChildSessionsStmt: %w", cerr) + } + } if q.listFilesByPathStmt != nil { if cerr := q.listFilesByPathStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr) @@ -240,6 +256,8 @@ type Queries struct { getFileByPathAndSessionStmt *sql.Stmt getMessageStmt *sql.Stmt getSessionByIDStmt *sql.Stmt + listAllSessionsStmt *sql.Stmt + listChildSessionsStmt *sql.Stmt listFilesByPathStmt *sql.Stmt listFilesBySessionStmt *sql.Stmt listLatestSessionFilesStmt *sql.Stmt @@ -266,6 +284,8 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, getMessageStmt: q.getMessageStmt, getSessionByIDStmt: q.getSessionByIDStmt, + listAllSessionsStmt: q.listAllSessionsStmt, + listChildSessionsStmt: q.listChildSessionsStmt, listFilesByPathStmt: q.listFilesByPathStmt, listFilesBySessionStmt: q.listFilesBySessionStmt, listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index a52516d20edb189e476ad41bbc7486b2ea8cc18b..cebc8c4a638ed64a3210c2fd7bc53d471d3a1181 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 // source: files.sql package db diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 81f322921db87dde7ade48ce64322aa01004d255..31f546d6d6b28fd9a0b58582d10b85a3edaf10bb 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 // source: messages.sql package db diff --git a/internal/db/models.go b/internal/db/models.go index ec3e6e10ad990d0f1a3d03a7533c8b1aed184447..a814add2d8c572a4383ae2a1e802a5de913fde63 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 package db diff --git a/internal/db/querier.go b/internal/db/querier.go index 472137273387d85a83a27260037321adccc9230f..6d8a2a4f5f7479ce2e2e469728f3fa860d0479d6 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -1,11 +1,12 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 package db import ( "context" + "database/sql" ) type Querier interface { @@ -21,6 +22,8 @@ type Querier interface { GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) GetMessage(ctx context.Context, id string) (Message, error) GetSessionByID(ctx context.Context, id string) (Session, error) + ListAllSessions(ctx context.Context) ([]Session, error) + ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error) ListFilesByPath(ctx context.Context, path string) ([]File, error) ListFilesBySession(ctx context.Context, sessionID string) ([]File, error) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 76ef6480b8e435cff66f29f7a1912aa5db5b9e9d..f61adfa0c59727b3b16fe310fa5d2b8ef8d47823 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.22.0 // source: sessions.sql package db @@ -106,6 +106,87 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error return i, err } +const listAllSessions = `-- name: ListAllSessions :many +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id +FROM sessions +ORDER BY created_at DESC +` + +func (q *Queries) ListAllSessions(ctx context.Context) ([]Session, error) { + rows, err := q.query(ctx, q.listAllSessionsStmt, listAllSessions) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listChildSessions = `-- name: ListChildSessions :many +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id +FROM sessions +WHERE parent_session_id = ? +ORDER BY created_at ASC +` + +func (q *Queries) ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error) { + rows, err := q.query(ctx, q.listChildSessionsStmt, listChildSessions, parentSessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listSessions = `-- name: ListSessions :many SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id FROM sessions diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index ebeab90d39f641c0aee72152c1f60ef455d5dff4..b2a165c6fc9cf4979d5d9591fb51da5d7d1a6a23 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -49,3 +49,14 @@ RETURNING *; -- name: DeleteSession :exec DELETE FROM sessions WHERE id = ?; + +-- name: ListChildSessions :many +SELECT * +FROM sessions +WHERE parent_session_id = ? +ORDER BY created_at ASC; + +-- name: ListAllSessions :many +SELECT * +FROM sessions +ORDER BY created_at DESC; diff --git a/internal/session/session.go b/internal/session/session.go index d988dac3414fa7dd00d13b375e1309f8d6c515dd..d98444d3265fe82984493b4fd483b58d042bc766 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -29,6 +29,8 @@ type Service interface { CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) Get(ctx context.Context, id string) (Session, error) List(ctx context.Context) ([]Session, error) + ListAll(ctx context.Context) ([]Session, error) + ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) Delete(ctx context.Context, id string) error } @@ -132,6 +134,33 @@ func (s *service) List(ctx context.Context) ([]Session, error) { return sessions, nil } +func (s *service) ListAll(ctx context.Context) ([]Session, error) { + dbSessions, err := s.q.ListAllSessions(ctx) + if err != nil { + return nil, err + } + sessions := make([]Session, len(dbSessions)) + for i, dbSession := range dbSessions { + sessions[i] = s.fromDBItem(dbSession) + } + return sessions, nil +} + +func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) { + dbSessions, err := s.q.ListChildSessions(ctx, sql.NullString{ + String: parentSessionID, + Valid: true, + }) + if err != nil { + return nil, err + } + sessions := make([]Session, len(dbSessions)) + for i, dbSession := range dbSessions { + sessions[i] = s.fromDBItem(dbSession) + } + return sessions, nil +} + func (s service) fromDBItem(item db.Session) Session { return Session{ ID: item.ID,