sessions.go

  1package cmd
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"os"
  8	"strings"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/db"
 13	"github.com/charmbracelet/crush/internal/message"
 14	"github.com/charmbracelet/crush/internal/session"
 15	"github.com/spf13/cobra"
 16	"gopkg.in/yaml.v3"
 17)
 18
 19// SessionWithChildren represents a session with its nested children
 20type SessionWithChildren struct {
 21	session.Session
 22	Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"`
 23}
 24
 25var sessionsCmd = &cobra.Command{
 26	Use:   "sessions",
 27	Short: "Manage sessions",
 28	Long:  `List and export sessions and their nested subsessions`,
 29}
 30
 31var listCmd = &cobra.Command{
 32	Use:   "list",
 33	Short: "List sessions",
 34	Long:  `List all sessions in a hierarchical format`,
 35	RunE: func(cmd *cobra.Command, args []string) error {
 36		format, _ := cmd.Flags().GetString("format")
 37		return runSessionsList(cmd.Context(), format)
 38	},
 39}
 40
 41var exportCmd = &cobra.Command{
 42	Use:   "export",
 43	Short: "Export sessions",
 44	Long:  `Export all sessions and their nested subsessions to different formats`,
 45	RunE: func(cmd *cobra.Command, args []string) error {
 46		format, _ := cmd.Flags().GetString("format")
 47		return runSessionsExport(cmd.Context(), format)
 48	},
 49}
 50
 51var exportConversationCmd = &cobra.Command{
 52	Use:   "export-conversation <session-id>",
 53	Short: "Export a single conversation",
 54	Long:  `Export a single session with all its messages as markdown for sharing`,
 55	Args:  cobra.ExactArgs(1),
 56	RunE: func(cmd *cobra.Command, args []string) error {
 57		sessionID := args[0]
 58		format, _ := cmd.Flags().GetString("format")
 59		return runExportConversation(cmd.Context(), sessionID, format)
 60	},
 61}
 62
 63func init() {
 64	rootCmd.AddCommand(sessionsCmd)
 65	sessionsCmd.AddCommand(listCmd)
 66	sessionsCmd.AddCommand(exportCmd)
 67	sessionsCmd.AddCommand(exportConversationCmd)
 68
 69	listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)")
 70	exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)")
 71	exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)")
 72}
 73
 74func runSessionsList(ctx context.Context, format string) error {
 75	sessionService, err := createSessionService(ctx)
 76	if err != nil {
 77		return err
 78	}
 79
 80	sessions, err := buildSessionTree(ctx, sessionService)
 81	if err != nil {
 82		return err
 83	}
 84
 85	return formatOutput(sessions, format, false)
 86}
 87
 88func runSessionsExport(ctx context.Context, format string) error {
 89	sessionService, err := createSessionService(ctx)
 90	if err != nil {
 91		return err
 92	}
 93
 94	sessions, err := buildSessionTree(ctx, sessionService)
 95	if err != nil {
 96		return err
 97	}
 98
 99	return formatOutput(sessions, format, true)
100}
101
102func runExportConversation(ctx context.Context, sessionID, format string) error {
103	sessionService, messageService, err := createServices(ctx)
104	if err != nil {
105		return err
106	}
107
108	// Get the session
109	sess, err := sessionService.Get(ctx, sessionID)
110	if err != nil {
111		return fmt.Errorf("failed to get session %s: %w", sessionID, err)
112	}
113
114	// Get all messages for the session
115	messages, err := messageService.List(ctx, sessionID)
116	if err != nil {
117		return fmt.Errorf("failed to get messages for session %s: %w", sessionID, err)
118	}
119
120	return formatConversation(sess, messages, format)
121}
122
123func createSessionService(ctx context.Context) (session.Service, error) {
124	cwd, err := getCwd()
125	if err != nil {
126		return nil, err
127	}
128
129	cfg, err := config.Init(cwd, false)
130	if err != nil {
131		return nil, err
132	}
133
134	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
135	if err != nil {
136		return nil, err
137	}
138
139	queries := db.New(conn)
140	return session.NewService(queries), nil
141}
142
143func createServices(ctx context.Context) (session.Service, message.Service, error) {
144	cwd, err := getCwd()
145	if err != nil {
146		return nil, nil, err
147	}
148
149	cfg, err := config.Init(cwd, false)
150	if err != nil {
151		return nil, nil, err
152	}
153
154	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
155	if err != nil {
156		return nil, nil, err
157	}
158
159	queries := db.New(conn)
160	sessionService := session.NewService(queries)
161	messageService := message.NewService(queries)
162	return sessionService, messageService, nil
163}
164
165func getCwd() (string, error) {
166	// This could be enhanced to use the same logic as root.go
167	cwd, err := getCwdFromFlags()
168	if err != nil {
169		return "", err
170	}
171	return cwd, nil
172}
173
// getCwdFromFlags returns the process's current working directory as
// reported by os.Getwd.
//
// NOTE(review): despite the name, no command-line flag is consulted here.
func getCwdFromFlags() (string, error) {
	dir, err := os.Getwd()
	return dir, err
}
177
178func buildSessionTree(ctx context.Context, sessionService session.Service) ([]SessionWithChildren, error) {
179	// Get all top-level sessions (no parent)
180	topLevelSessions, err := sessionService.List(ctx)
181	if err != nil {
182		return nil, fmt.Errorf("failed to list sessions: %w", err)
183	}
184
185	var result []SessionWithChildren
186	for _, sess := range topLevelSessions {
187		sessionWithChildren, err := buildSessionWithChildren(ctx, sessionService, sess)
188		if err != nil {
189			return nil, err
190		}
191		result = append(result, sessionWithChildren)
192	}
193
194	return result, nil
195}
196
197func buildSessionWithChildren(ctx context.Context, sessionService session.Service, sess session.Session) (SessionWithChildren, error) {
198	children, err := sessionService.ListChildren(ctx, sess.ID)
199	if err != nil {
200		return SessionWithChildren{}, fmt.Errorf("failed to list children for session %s: %w", sess.ID, err)
201	}
202
203	var childrenWithChildren []SessionWithChildren
204	for _, child := range children {
205		childWithChildren, err := buildSessionWithChildren(ctx, sessionService, child)
206		if err != nil {
207			return SessionWithChildren{}, err
208		}
209		childrenWithChildren = append(childrenWithChildren, childWithChildren)
210	}
211
212	return SessionWithChildren{
213		Session:  sess,
214		Children: childrenWithChildren,
215	}, nil
216}
217
218func formatOutput(sessions []SessionWithChildren, format string, includeMetadata bool) error {
219	switch strings.ToLower(format) {
220	case "json":
221		return formatJSON(sessions)
222	case "yaml":
223		return formatYAML(sessions)
224	case "markdown", "md":
225		return formatMarkdown(sessions, includeMetadata)
226	case "text":
227		return formatText(sessions)
228	default:
229		return fmt.Errorf("unsupported format: %s", format)
230	}
231}
232
233func formatJSON(sessions []SessionWithChildren) error {
234	data, err := json.MarshalIndent(sessions, "", "  ")
235	if err != nil {
236		return fmt.Errorf("failed to marshal JSON: %w", err)
237	}
238	fmt.Println(string(data))
239	return nil
240}
241
242func formatYAML(sessions []SessionWithChildren) error {
243	data, err := yaml.Marshal(sessions)
244	if err != nil {
245		return fmt.Errorf("failed to marshal YAML: %w", err)
246	}
247	fmt.Println(string(data))
248	return nil
249}
250
251func formatMarkdown(sessions []SessionWithChildren, includeMetadata bool) error {
252	fmt.Println("# Sessions")
253	fmt.Println()
254
255	if len(sessions) == 0 {
256		fmt.Println("No sessions found.")
257		return nil
258	}
259
260	for _, sess := range sessions {
261		printSessionMarkdown(sess, 0, includeMetadata)
262	}
263
264	return nil
265}
266
267func formatText(sessions []SessionWithChildren) error {
268	if len(sessions) == 0 {
269		fmt.Println("No sessions found.")
270		return nil
271	}
272
273	for _, sess := range sessions {
274		printSessionText(sess, 0)
275	}
276
277	return nil
278}
279
280func printSessionMarkdown(sess SessionWithChildren, level int, includeMetadata bool) {
281	indent := strings.Repeat("#", level+2)
282	fmt.Printf("%s %s\n", indent, sess.Title)
283	fmt.Println()
284
285	if includeMetadata {
286		fmt.Printf("- **ID**: %s\n", sess.ID)
287		if sess.ParentSessionID != "" {
288			fmt.Printf("- **Parent**: %s\n", sess.ParentSessionID)
289		}
290		fmt.Printf("- **Messages**: %d\n", sess.MessageCount)
291		fmt.Printf("- **Tokens**: %d prompt, %d completion\n", sess.PromptTokens, sess.CompletionTokens)
292		fmt.Printf("- **Cost**: $%.4f\n", sess.Cost)
293		fmt.Printf("- **Created**: %s\n", formatTimestamp(sess.CreatedAt))
294		fmt.Printf("- **Updated**: %s\n", formatTimestamp(sess.UpdatedAt))
295		fmt.Println()
296	}
297
298	for _, child := range sess.Children {
299		printSessionMarkdown(child, level+1, includeMetadata)
300	}
301}
302
303func printSessionText(sess SessionWithChildren, level int) {
304	indent := strings.Repeat("  ", level)
305	fmt.Printf("%s• %s (ID: %s, Messages: %d, Cost: $%.4f)\n",
306		indent, sess.Title, sess.ID, sess.MessageCount, sess.Cost)
307
308	for _, child := range sess.Children {
309		printSessionText(child, level+1)
310	}
311}
312
// formatTimestamp renders timestamp as "YYYY-MM-DD hh:mm:ss" in the local
// time zone, treating it as whole seconds since the Unix epoch.
//
// NOTE(review): seconds are assumed; confirm the session store does not
// record milliseconds.
func formatTimestamp(timestamp int64) string {
	const layout = "2006-01-02 15:04:05"
	t := time.Unix(timestamp, 0)
	return t.Format(layout)
}
317
318func formatConversation(sess session.Session, messages []message.Message, format string) error {
319	switch strings.ToLower(format) {
320	case "markdown", "md":
321		return formatConversationMarkdown(sess, messages)
322	case "json":
323		return formatConversationJSON(sess, messages)
324	case "yaml":
325		return formatConversationYAML(sess, messages)
326	default:
327		return fmt.Errorf("unsupported format: %s", format)
328	}
329}
330
331func formatConversationMarkdown(sess session.Session, messages []message.Message) error {
332	fmt.Printf("# %s\n\n", sess.Title)
333
334	// Session metadata
335	fmt.Printf("**Session ID:** %s  \n", sess.ID)
336	fmt.Printf("**Created:** %s  \n", formatTimestamp(sess.CreatedAt))
337	fmt.Printf("**Messages:** %d  \n", sess.MessageCount)
338	fmt.Printf("**Tokens:** %d prompt, %d completion  \n", sess.PromptTokens, sess.CompletionTokens)
339	if sess.Cost > 0 {
340		fmt.Printf("**Cost:** $%.4f  \n", sess.Cost)
341	}
342	fmt.Println()
343	fmt.Println("---")
344	fmt.Println()
345
346	for i, msg := range messages {
347		formatMessageMarkdown(msg, i+1)
348	}
349
350	return nil
351}
352
353func formatMessageMarkdown(msg message.Message, index int) {
354	// Role header
355	switch msg.Role {
356	case message.User:
357		fmt.Printf("## šŸ‘¤ User\n\n")
358	case message.Assistant:
359		fmt.Printf("## šŸ¤– Assistant")
360		if msg.Model != "" {
361			fmt.Printf(" (%s)", msg.Model)
362		}
363		fmt.Printf("\n\n")
364	case message.System:
365		fmt.Printf("## āš™ļø System\n\n")
366	case message.Tool:
367		fmt.Printf("## šŸ”§ Tool\n\n")
368	}
369
370	// Process each part
371	for _, part := range msg.Parts {
372		switch p := part.(type) {
373		case message.TextContent:
374			fmt.Printf("%s\n\n", p.Text)
375		case message.ReasoningContent:
376			if p.Thinking != "" {
377				fmt.Printf("### 🧠 Reasoning\n\n")
378				fmt.Printf("```\n%s\n```\n\n", p.Thinking)
379			}
380		case message.ToolCall:
381			fmt.Printf("### šŸ”§ Tool Call: %s\n\n", p.Name)
382			fmt.Printf("**ID:** %s  \n", p.ID)
383			if p.Input != "" {
384				fmt.Printf("**Input:**\n```json\n%s\n```\n\n", p.Input)
385			}
386		case message.ToolResult:
387			fmt.Printf("### šŸ“ Tool Result: %s\n\n", p.Name)
388			if p.IsError {
389				fmt.Printf("**āŒ Error:**\n```\n%s\n```\n\n", p.Content)
390			} else {
391				fmt.Printf("**āœ… Result:**\n```\n%s\n```\n\n", p.Content)
392			}
393		case message.ImageURLContent:
394			fmt.Printf("![Image](%s)\n\n", p.URL)
395		case message.BinaryContent:
396			fmt.Printf("**File:** %s (%s)\n\n", p.Path, p.MIMEType)
397		case message.Finish:
398			if p.Reason != message.FinishReasonEndTurn {
399				fmt.Printf("**Finish Reason:** %s\n", p.Reason)
400				if p.Message != "" {
401					fmt.Printf("**Message:** %s\n", p.Message)
402				}
403				fmt.Println()
404			}
405		}
406	}
407
408	fmt.Println("---")
409	fmt.Println()
410}
411
412func formatConversationJSON(sess session.Session, messages []message.Message) error {
413	data := struct {
414		Session  session.Session   `json:"session"`
415		Messages []message.Message `json:"messages"`
416	}{
417		Session:  sess,
418		Messages: messages,
419	}
420
421	jsonData, err := json.MarshalIndent(data, "", "  ")
422	if err != nil {
423		return fmt.Errorf("failed to marshal JSON: %w", err)
424	}
425	fmt.Println(string(jsonData))
426	return nil
427}
428
429func formatConversationYAML(sess session.Session, messages []message.Message) error {
430	data := struct {
431		Session  session.Session   `yaml:"session"`
432		Messages []message.Message `yaml:"messages"`
433	}{
434		Session:  sess,
435		Messages: messages,
436	}
437
438	yamlData, err := yaml.Marshal(data)
439	if err != nil {
440		return fmt.Errorf("failed to marshal YAML: %w", err)
441	}
442	fmt.Println(string(yamlData))
443	return nil
444}