1package cmd
2
3import (
4 "bytes"
5 "context"
6 "database/sql"
7 "encoding/json"
8 "fmt"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/db"
16 "github.com/charmbracelet/crush/internal/message"
17 "github.com/charmbracelet/crush/internal/session"
18 "github.com/google/uuid"
19 "github.com/spf13/cobra"
20 "gopkg.in/yaml.v3"
21)
22
// SessionWithChildren pairs a session with its child sessions, each of which
// may carry its own children. buildSessionTree/buildSessionWithChildren build
// this tree for the `sessions list` and `sessions export` commands.
//
// The embedded session.Session has its fields promoted; encoding/json writes
// them into the same object as "children".
// NOTE(review): yaml.v3 only inlines embedded structs tagged `yaml:",inline"`,
// so YAML output presumably nests these fields under a "session" key — confirm
// this is intended.
type SessionWithChildren struct {
	session.Session
	// Children is nil (and omitted from JSON/YAML) for leaf sessions.
	Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"`
}
28
// ImportSession is one node of the hierarchical session tree read by the
// `sessions import` command. Child sessions are nested in Children and must
// name this session's ID as their ParentSessionID
// (see validateImportSessionHierarchy).
//
// Every field carries a yaml tag that mirrors its json tag, so JSON and YAML
// import files share one snake_case schema. Without the yaml tags, yaml.v3
// falls back to the lowercased Go field name (e.g. "parentsessionid"). A YAML
// file using "parent_session_id", "message_count", ... would then silently
// decode those fields as zero values.
type ImportSession struct {
	ID               string          `json:"id" yaml:"id"`
	ParentSessionID  string          `json:"parent_session_id" yaml:"parent_session_id"`
	Title            string          `json:"title" yaml:"title"`
	MessageCount     int64           `json:"message_count" yaml:"message_count"`
	PromptTokens     int64           `json:"prompt_tokens" yaml:"prompt_tokens"`
	CompletionTokens int64           `json:"completion_tokens" yaml:"completion_tokens"`
	Cost             float64         `json:"cost" yaml:"cost"`
	CreatedAt        int64           `json:"created_at" yaml:"created_at"`
	UpdatedAt        int64           `json:"updated_at" yaml:"updated_at"`
	SummaryMessageID string          `json:"summary_message_id,omitempty" yaml:"summary_message_id,omitempty"`
	Children         []ImportSession `json:"children,omitempty" yaml:"children,omitempty"`
}
43
// ImportData is the top-level document accepted by `sessions import`
// (readImportFile decodes JSON or YAML into it).
//
// Only Sessions is used by the visible import path: validateImportData rejects
// an empty list, and importSessions walks it. Version, ExportedAt and
// TotalSessions are decoded but not checked in this file.
type ImportData struct {
	Version       string          `json:"version" yaml:"version"`
	ExportedAt    string          `json:"exported_at,omitempty" yaml:"exported_at,omitempty"`
	TotalSessions int             `json:"total_sessions,omitempty" yaml:"total_sessions,omitempty"`
	Sessions      []ImportSession `json:"sessions" yaml:"sessions"`
}
51
// ImportMessage is one message of a conversation read by
// `sessions import-conversation`.
//
// Parts is kept as generic decoded data. importConversation re-encodes it to
// JSON before storing it, so its element shape is whatever the file contains.
//
// Every field carries a yaml tag that mirrors its json tag. Without them,
// yaml.v3 expects lowercased Go field names ("sessionid", "createdat"), and
// YAML files using the documented snake_case keys would lose those fields.
type ImportMessage struct {
	ID        string        `json:"id" yaml:"id"`
	Role      string        `json:"role" yaml:"role"`
	SessionID string        `json:"session_id" yaml:"session_id"`
	Parts     []interface{} `json:"parts" yaml:"parts"`
	Model     string        `json:"model,omitempty" yaml:"model,omitempty"`
	Provider  string        `json:"provider,omitempty" yaml:"provider,omitempty"`
	CreatedAt int64         `json:"created_at" yaml:"created_at"`
}
62
// ImportSessionInfo is the session header of a conversation read by
// `sessions import-conversation`. Its Title is required
// (validateConversationData). importConversation copies the token/cost
// figures into the new session row.
//
// Every field carries a yaml tag that mirrors its json tag, including
// omitempty. Without them, yaml.v3 expects lowercased Go field names
// ("messagecount", "prompttokens"), and snake_case YAML keys would be
// silently ignored.
type ImportSessionInfo struct {
	ID               string  `json:"id" yaml:"id"`
	Title            string  `json:"title" yaml:"title"`
	MessageCount     int64   `json:"message_count" yaml:"message_count"`
	PromptTokens     int64   `json:"prompt_tokens,omitempty" yaml:"prompt_tokens,omitempty"`
	CompletionTokens int64   `json:"completion_tokens,omitempty" yaml:"completion_tokens,omitempty"`
	Cost             float64 `json:"cost,omitempty" yaml:"cost,omitempty"`
	CreatedAt        int64   `json:"created_at" yaml:"created_at"`
}
73
// ConversationData is the document accepted by `sessions import-conversation`:
// one session header plus its messages. readConversationFile decodes it from
// JSON or YAML. validateConversationData requires a session title and at
// least one message.
type ConversationData struct {
	Version  string            `json:"version" yaml:"version"`
	Session  ImportSessionInfo `json:"session" yaml:"session"`
	Messages []ImportMessage   `json:"messages" yaml:"messages"`
}
80
// ImportResult accumulates the outcome of a `sessions import` run.
// importSessions creates it and the recursive import helpers fill it in.
// runImport prints it as a summary.
//
// Per-session failures are recorded in Errors instead of aborting the run.
// NOTE(review): the session-tree import code visible in this file never
// increments ImportedMessages or SkippedSessions — confirm whether anything
// else does.
type ImportResult struct {
	TotalSessions    int      `json:"total_sessions"`
	ImportedSessions int      `json:"imported_sessions"`
	SkippedSessions  int      `json:"skipped_sessions"`
	ImportedMessages int      `json:"imported_messages"`
	Errors           []string `json:"errors,omitempty"`
	// SessionMapping maps each session ID from the import file to the freshly
	// generated ID it was stored under.
	SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
}
90
// SessionStats holds the overall totals reported by `sessions stats` when no
// --group-by period is given.
// NOTE(review): presumably filled in by runSessionsStats (not visible here),
// and AvgCostPerSession is presumably TotalCost / TotalSessions — confirm.
type SessionStats struct {
	TotalSessions         int64   `json:"total_sessions"`
	TotalMessages         int64   `json:"total_messages"`
	TotalPromptTokens     int64   `json:"total_prompt_tokens"`
	TotalCompletionTokens int64   `json:"total_completion_tokens"`
	TotalCost             float64 `json:"total_cost"`
	AvgCostPerSession     float64 `json:"avg_cost_per_session"`
}
100
// GroupedSessionStats holds the statistics for one time bucket of
// `sessions stats --group-by` (day, week or month, per the flag's help text).
// NOTE(review): the exact Period string format is decided by the stats query
// code, which is not visible here — confirm before relying on it.
type GroupedSessionStats struct {
	Period           string  `json:"period"`
	SessionCount     int64   `json:"session_count"`
	MessageCount     int64   `json:"message_count"`
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	TotalCost        float64 `json:"total_cost"`
	AvgCost          float64 `json:"avg_cost"`
}
111
112var sessionsCmd = &cobra.Command{
113 Use: "sessions",
114 Short: "Manage sessions",
115 Long: `List and export sessions and their nested subsessions`,
116}
117
118var listCmd = &cobra.Command{
119 Use: "list",
120 Short: "List sessions",
121 Long: `List all sessions in a hierarchical format`,
122 RunE: func(cmd *cobra.Command, args []string) error {
123 format, _ := cmd.Flags().GetString("format")
124 return runSessionsList(cmd.Context(), format)
125 },
126}
127
128var exportCmd = &cobra.Command{
129 Use: "export",
130 Short: "Export sessions",
131 Long: `Export all sessions and their nested subsessions to different formats`,
132 RunE: func(cmd *cobra.Command, args []string) error {
133 format, _ := cmd.Flags().GetString("format")
134 return runSessionsExport(cmd.Context(), format)
135 },
136}
137
138var exportConversationCmd = &cobra.Command{
139 Use: "export-conversation <session-id>",
140 Short: "Export a single conversation",
141 Long: `Export a single session with all its messages as markdown for sharing`,
142 Args: cobra.ExactArgs(1),
143 RunE: func(cmd *cobra.Command, args []string) error {
144 sessionID := args[0]
145 format, _ := cmd.Flags().GetString("format")
146 return runExportConversation(cmd.Context(), sessionID, format)
147 },
148}
149
150var importCmd = &cobra.Command{
151 Use: "import <file>",
152 Short: "Import sessions from a file",
153 Long: `Import sessions from a JSON or YAML file with hierarchical structure`,
154 Args: cobra.ExactArgs(1),
155 RunE: func(cmd *cobra.Command, args []string) error {
156 file := args[0]
157 format, _ := cmd.Flags().GetString("format")
158 dryRun, _ := cmd.Flags().GetBool("dry-run")
159 return runImport(cmd.Context(), file, format, dryRun)
160 },
161}
162
163var importConversationCmd = &cobra.Command{
164 Use: "import-conversation <file>",
165 Short: "Import a single conversation from a file",
166 Long: `Import a single conversation with messages from a JSON, YAML, or Markdown file`,
167 Args: cobra.ExactArgs(1),
168 RunE: func(cmd *cobra.Command, args []string) error {
169 file := args[0]
170 format, _ := cmd.Flags().GetString("format")
171 return runImportConversation(cmd.Context(), file, format)
172 },
173}
174
175var searchCmd = &cobra.Command{
176 Use: "search",
177 Short: "Search sessions by title or message content",
178 Long: `Search sessions by title pattern (case-insensitive) or message text content`,
179 RunE: func(cmd *cobra.Command, args []string) error {
180 titlePattern, _ := cmd.Flags().GetString("title")
181 textPattern, _ := cmd.Flags().GetString("text")
182 format, _ := cmd.Flags().GetString("format")
183
184 if titlePattern == "" && textPattern == "" {
185 return fmt.Errorf("at least one of --title or --text must be provided")
186 }
187
188 return runSessionsSearch(cmd.Context(), titlePattern, textPattern, format)
189 },
190}
191
192var statsCmd = &cobra.Command{
193 Use: "stats",
194 Short: "Show session statistics",
195 Long: `Display aggregated statistics about sessions including total counts, tokens, and costs`,
196 RunE: func(cmd *cobra.Command, args []string) error {
197 format, _ := cmd.Flags().GetString("format")
198 groupBy, _ := cmd.Flags().GetString("group-by")
199 return runSessionsStats(cmd.Context(), format, groupBy)
200 },
201}
202
203func init() {
204 rootCmd.AddCommand(sessionsCmd)
205 sessionsCmd.AddCommand(listCmd)
206 sessionsCmd.AddCommand(exportCmd)
207 sessionsCmd.AddCommand(exportConversationCmd)
208 sessionsCmd.AddCommand(importCmd)
209 sessionsCmd.AddCommand(importConversationCmd)
210 sessionsCmd.AddCommand(searchCmd)
211 sessionsCmd.AddCommand(statsCmd)
212
213 listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)")
214 exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)")
215 exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)")
216 importCmd.Flags().StringP("format", "f", "", "Import format (json, yaml) - auto-detected if not specified")
217 importCmd.Flags().Bool("dry-run", false, "Validate import data without persisting changes")
218 importConversationCmd.Flags().StringP("format", "f", "", "Import format (json, yaml, markdown) - auto-detected if not specified")
219 searchCmd.Flags().String("title", "", "Search by session title pattern (case-insensitive substring search)")
220 searchCmd.Flags().String("text", "", "Search by message text content")
221 searchCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
222 statsCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
223 statsCmd.Flags().String("group-by", "", "Group statistics by time period (day, week, month)")
224}
225
226func runSessionsList(ctx context.Context, format string) error {
227 sessionService, err := createSessionService(ctx)
228 if err != nil {
229 return err
230 }
231
232 sessions, err := buildSessionTree(ctx, sessionService)
233 if err != nil {
234 return err
235 }
236
237 return formatOutput(sessions, format, false)
238}
239
240func runSessionsExport(ctx context.Context, format string) error {
241 sessionService, err := createSessionService(ctx)
242 if err != nil {
243 return err
244 }
245
246 sessions, err := buildSessionTree(ctx, sessionService)
247 if err != nil {
248 return err
249 }
250
251 return formatOutput(sessions, format, true)
252}
253
254func runExportConversation(ctx context.Context, sessionID, format string) error {
255 sessionService, messageService, err := createServices(ctx)
256 if err != nil {
257 return err
258 }
259
260 // Get the session
261 sess, err := sessionService.Get(ctx, sessionID)
262 if err != nil {
263 return fmt.Errorf("failed to get session %s: %w", sessionID, err)
264 }
265
266 // Get all messages for the session
267 messages, err := messageService.List(ctx, sessionID)
268 if err != nil {
269 return fmt.Errorf("failed to get messages for session %s: %w", sessionID, err)
270 }
271
272 return formatConversation(sess, messages, format)
273}
274
275func createSessionService(ctx context.Context) (session.Service, error) {
276 cwd, err := getCwd()
277 if err != nil {
278 return nil, err
279 }
280
281 cfg, err := config.Init(cwd, false)
282 if err != nil {
283 return nil, err
284 }
285
286 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
287 if err != nil {
288 return nil, err
289 }
290
291 queries := db.New(conn)
292 return session.NewService(queries), nil
293}
294
295func createServices(ctx context.Context) (session.Service, message.Service, error) {
296 cwd, err := getCwd()
297 if err != nil {
298 return nil, nil, err
299 }
300
301 cfg, err := config.Init(cwd, false)
302 if err != nil {
303 return nil, nil, err
304 }
305
306 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
307 if err != nil {
308 return nil, nil, err
309 }
310
311 queries := db.New(conn)
312 sessionService := session.NewService(queries)
313 messageService := message.NewService(queries)
314 return sessionService, messageService, nil
315}
316
317func getCwd() (string, error) {
318 // This could be enhanced to use the same logic as root.go
319 cwd, err := getCwdFromFlags()
320 if err != nil {
321 return "", err
322 }
323 return cwd, nil
324}
325
// getCwdFromFlags returns the process working directory. Despite the name, no
// command-line flags are consulted yet.
func getCwdFromFlags() (string, error) {
	wd, err := os.Getwd()
	return wd, err
}
329
330func buildSessionTree(ctx context.Context, sessionService session.Service) ([]SessionWithChildren, error) {
331 // Get all top-level sessions (no parent)
332 topLevelSessions, err := sessionService.List(ctx)
333 if err != nil {
334 return nil, fmt.Errorf("failed to list sessions: %w", err)
335 }
336
337 var result []SessionWithChildren
338 for _, sess := range topLevelSessions {
339 sessionWithChildren, err := buildSessionWithChildren(ctx, sessionService, sess)
340 if err != nil {
341 return nil, err
342 }
343 result = append(result, sessionWithChildren)
344 }
345
346 return result, nil
347}
348
349func buildSessionWithChildren(ctx context.Context, sessionService session.Service, sess session.Session) (SessionWithChildren, error) {
350 children, err := sessionService.ListChildren(ctx, sess.ID)
351 if err != nil {
352 return SessionWithChildren{}, fmt.Errorf("failed to list children for session %s: %w", sess.ID, err)
353 }
354
355 var childrenWithChildren []SessionWithChildren
356 for _, child := range children {
357 childWithChildren, err := buildSessionWithChildren(ctx, sessionService, child)
358 if err != nil {
359 return SessionWithChildren{}, err
360 }
361 childrenWithChildren = append(childrenWithChildren, childWithChildren)
362 }
363
364 return SessionWithChildren{
365 Session: sess,
366 Children: childrenWithChildren,
367 }, nil
368}
369
370func formatOutput(sessions []SessionWithChildren, format string, includeMetadata bool) error {
371 switch strings.ToLower(format) {
372 case "json":
373 return formatJSON(sessions)
374 case "yaml":
375 return formatYAML(sessions)
376 case "markdown", "md":
377 return formatMarkdown(sessions, includeMetadata)
378 case "text":
379 return formatText(sessions)
380 default:
381 return fmt.Errorf("unsupported format: %s", format)
382 }
383}
384
385func formatJSON(sessions []SessionWithChildren) error {
386 data, err := json.MarshalIndent(sessions, "", " ")
387 if err != nil {
388 return fmt.Errorf("failed to marshal JSON: %w", err)
389 }
390 fmt.Println(string(data))
391 return nil
392}
393
394func formatYAML(sessions []SessionWithChildren) error {
395 data, err := yaml.Marshal(sessions)
396 if err != nil {
397 return fmt.Errorf("failed to marshal YAML: %w", err)
398 }
399 fmt.Println(string(data))
400 return nil
401}
402
403func formatMarkdown(sessions []SessionWithChildren, includeMetadata bool) error {
404 fmt.Println("# Sessions")
405 fmt.Println()
406
407 if len(sessions) == 0 {
408 fmt.Println("No sessions found.")
409 return nil
410 }
411
412 for _, sess := range sessions {
413 printSessionMarkdown(sess, 0, includeMetadata)
414 }
415
416 return nil
417}
418
419func formatText(sessions []SessionWithChildren) error {
420 if len(sessions) == 0 {
421 fmt.Println("No sessions found.")
422 return nil
423 }
424
425 for _, sess := range sessions {
426 printSessionText(sess, 0)
427 }
428
429 return nil
430}
431
432func printSessionMarkdown(sess SessionWithChildren, level int, includeMetadata bool) {
433 indent := strings.Repeat("#", level+2)
434 fmt.Printf("%s %s\n", indent, sess.Title)
435 fmt.Println()
436
437 if includeMetadata {
438 fmt.Printf("- **ID**: %s\n", sess.ID)
439 if sess.ParentSessionID != "" {
440 fmt.Printf("- **Parent**: %s\n", sess.ParentSessionID)
441 }
442 fmt.Printf("- **Messages**: %d\n", sess.MessageCount)
443 fmt.Printf("- **Tokens**: %d prompt, %d completion\n", sess.PromptTokens, sess.CompletionTokens)
444 fmt.Printf("- **Cost**: $%.4f\n", sess.Cost)
445 fmt.Printf("- **Created**: %s\n", formatTimestamp(sess.CreatedAt))
446 fmt.Printf("- **Updated**: %s\n", formatTimestamp(sess.UpdatedAt))
447 fmt.Println()
448 }
449
450 for _, child := range sess.Children {
451 printSessionMarkdown(child, level+1, includeMetadata)
452 }
453}
454
455func printSessionText(sess SessionWithChildren, level int) {
456 indent := strings.Repeat(" ", level)
457 fmt.Printf("%s• %s (ID: %s, Messages: %d, Cost: $%.4f)\n",
458 indent, sess.Title, sess.ID, sess.MessageCount, sess.Cost)
459
460 for _, child := range sess.Children {
461 printSessionText(child, level+1)
462 }
463}
464
// formatTimestamp renders a Unix timestamp, interpreted as seconds, as
// "YYYY-MM-DD hh:mm:ss" in the local time zone.
func formatTimestamp(timestamp int64) string {
	const layout = "2006-01-02 15:04:05"
	instant := time.Unix(timestamp, 0)
	return instant.Format(layout)
}
469
470func formatConversation(sess session.Session, messages []message.Message, format string) error {
471 switch strings.ToLower(format) {
472 case "markdown", "md":
473 return formatConversationMarkdown(sess, messages)
474 case "json":
475 return formatConversationJSON(sess, messages)
476 case "yaml":
477 return formatConversationYAML(sess, messages)
478 default:
479 return fmt.Errorf("unsupported format: %s", format)
480 }
481}
482
483func formatConversationMarkdown(sess session.Session, messages []message.Message) error {
484 fmt.Printf("# %s\n\n", sess.Title)
485
486 // Session metadata
487 fmt.Printf("**Session ID:** %s \n", sess.ID)
488 fmt.Printf("**Created:** %s \n", formatTimestamp(sess.CreatedAt))
489 fmt.Printf("**Messages:** %d \n", sess.MessageCount)
490 fmt.Printf("**Tokens:** %d prompt, %d completion \n", sess.PromptTokens, sess.CompletionTokens)
491 if sess.Cost > 0 {
492 fmt.Printf("**Cost:** $%.4f \n", sess.Cost)
493 }
494 fmt.Println()
495 fmt.Println("---")
496 fmt.Println()
497
498 for i, msg := range messages {
499 formatMessageMarkdown(msg, i+1)
500 }
501
502 return nil
503}
504
505func formatMessageMarkdown(msg message.Message, index int) {
506 // Role header
507 switch msg.Role {
508 case message.User:
509 fmt.Printf("## 👤 User\n\n")
510 case message.Assistant:
511 fmt.Printf("## 🤖 Assistant")
512 if msg.Model != "" {
513 fmt.Printf(" (%s)", msg.Model)
514 }
515 fmt.Printf("\n\n")
516 case message.System:
517 fmt.Printf("## ⚙️ System\n\n")
518 case message.Tool:
519 fmt.Printf("## 🔧 Tool\n\n")
520 }
521
522 // Process each part
523 for _, part := range msg.Parts {
524 switch p := part.(type) {
525 case message.TextContent:
526 fmt.Printf("%s\n\n", p.Text)
527 case message.ReasoningContent:
528 if p.Thinking != "" {
529 fmt.Printf("### 🧠 Reasoning\n\n")
530 fmt.Printf("```\n%s\n```\n\n", p.Thinking)
531 }
532 case message.ToolCall:
533 fmt.Printf("### 🔧 Tool Call: %s\n\n", p.Name)
534 fmt.Printf("**ID:** %s \n", p.ID)
535 if p.Input != "" {
536 fmt.Printf("**Input:**\n```json\n%s\n```\n\n", p.Input)
537 }
538 case message.ToolResult:
539 fmt.Printf("### 📝 Tool Result: %s\n\n", p.Name)
540 if p.IsError {
541 fmt.Printf("**❌ Error:**\n```\n%s\n```\n\n", p.Content)
542 } else {
543 fmt.Printf("**✅ Result:**\n```\n%s\n```\n\n", p.Content)
544 }
545 case message.ImageURLContent:
546 fmt.Printf("\n\n", p.URL)
547 case message.BinaryContent:
548 fmt.Printf("**File:** %s (%s)\n\n", p.Path, p.MIMEType)
549 case message.Finish:
550 if p.Reason != message.FinishReasonEndTurn {
551 fmt.Printf("**Finish Reason:** %s\n", p.Reason)
552 if p.Message != "" {
553 fmt.Printf("**Message:** %s\n", p.Message)
554 }
555 fmt.Println()
556 }
557 }
558 }
559
560 fmt.Println("---")
561 fmt.Println()
562}
563
564func formatConversationJSON(sess session.Session, messages []message.Message) error {
565 data := struct {
566 Session session.Session `json:"session"`
567 Messages []message.Message `json:"messages"`
568 }{
569 Session: sess,
570 Messages: messages,
571 }
572
573 jsonData, err := json.MarshalIndent(data, "", " ")
574 if err != nil {
575 return fmt.Errorf("failed to marshal JSON: %w", err)
576 }
577 fmt.Println(string(jsonData))
578 return nil
579}
580
581func formatConversationYAML(sess session.Session, messages []message.Message) error {
582 data := struct {
583 Session session.Session `yaml:"session"`
584 Messages []message.Message `yaml:"messages"`
585 }{
586 Session: sess,
587 Messages: messages,
588 }
589
590 yamlData, err := yaml.Marshal(data)
591 if err != nil {
592 return fmt.Errorf("failed to marshal YAML: %w", err)
593 }
594 fmt.Println(string(yamlData))
595 return nil
596}
597
598func runImport(ctx context.Context, file, format string, dryRun bool) error {
599 // Read the file
600 data, err := readImportFile(file, format)
601 if err != nil {
602 return fmt.Errorf("failed to read import file: %w", err)
603 }
604
605 // Validate the data structure
606 if err := validateImportData(data); err != nil {
607 return fmt.Errorf("invalid import data: %w", err)
608 }
609
610 if dryRun {
611 result := ImportResult{
612 TotalSessions: countTotalImportSessions(data.Sessions),
613 ImportedSessions: 0,
614 SkippedSessions: 0,
615 ImportedMessages: 0,
616 SessionMapping: make(map[string]string),
617 }
618 fmt.Printf("Dry run: Would import %d sessions\n", result.TotalSessions)
619 return nil
620 }
621
622 // Perform the actual import
623 sessionService, messageService, err := createServices(ctx)
624 if err != nil {
625 return err
626 }
627
628 result, err := importSessions(ctx, sessionService, messageService, data)
629 if err != nil {
630 return fmt.Errorf("import failed: %w", err)
631 }
632
633 // Print summary
634 fmt.Printf("Import completed successfully:\n")
635 fmt.Printf(" Total sessions processed: %d\n", result.TotalSessions)
636 fmt.Printf(" Sessions imported: %d\n", result.ImportedSessions)
637 fmt.Printf(" Sessions skipped: %d\n", result.SkippedSessions)
638 fmt.Printf(" Messages imported: %d\n", result.ImportedMessages)
639
640 if len(result.Errors) > 0 {
641 fmt.Printf(" Errors encountered: %d\n", len(result.Errors))
642 for _, errStr := range result.Errors {
643 fmt.Printf(" - %s\n", errStr)
644 }
645 }
646
647 return nil
648}
649
650func runImportConversation(ctx context.Context, file, format string) error {
651 // Read the conversation file
652 convData, err := readConversationFile(file, format)
653 if err != nil {
654 return fmt.Errorf("failed to read conversation file: %w", err)
655 }
656
657 // Validate the conversation data
658 if err := validateConversationData(convData); err != nil {
659 return fmt.Errorf("invalid conversation data: %w", err)
660 }
661
662 // Import the conversation
663 sessionService, messageService, err := createServices(ctx)
664 if err != nil {
665 return err
666 }
667
668 newSessionID, messageCount, err := importConversation(ctx, sessionService, messageService, convData)
669 if err != nil {
670 return fmt.Errorf("conversation import failed: %w", err)
671 }
672
673 fmt.Printf("Conversation imported successfully:\n")
674 fmt.Printf(" Session ID: %s\n", newSessionID)
675 fmt.Printf(" Title: %s\n", convData.Session.Title)
676 fmt.Printf(" Messages imported: %d\n", messageCount)
677
678 return nil
679}
680
681func readImportFile(file, format string) (*ImportData, error) {
682 fileData, err := os.ReadFile(file)
683 if err != nil {
684 return nil, fmt.Errorf("failed to read file %s: %w", file, err)
685 }
686
687 // Auto-detect format if not specified
688 if format == "" {
689 format = detectFormat(file, fileData)
690 }
691
692 var data ImportData
693 switch strings.ToLower(format) {
694 case "json":
695 if err := json.Unmarshal(fileData, &data); err != nil {
696 return nil, fmt.Errorf("failed to parse JSON: %w", err)
697 }
698 case "yaml", "yml":
699 if err := yaml.Unmarshal(fileData, &data); err != nil {
700 return nil, fmt.Errorf("failed to parse YAML: %w", err)
701 }
702 default:
703 return nil, fmt.Errorf("unsupported format: %s", format)
704 }
705
706 return &data, nil
707}
708
709func readConversationFile(file, format string) (*ConversationData, error) {
710 fileData, err := os.ReadFile(file)
711 if err != nil {
712 return nil, fmt.Errorf("failed to read file %s: %w", file, err)
713 }
714
715 // Auto-detect format if not specified
716 if format == "" {
717 format = detectFormat(file, fileData)
718 }
719
720 var data ConversationData
721 switch strings.ToLower(format) {
722 case "json":
723 if err := json.Unmarshal(fileData, &data); err != nil {
724 return nil, fmt.Errorf("failed to parse JSON: %w", err)
725 }
726 case "yaml", "yml":
727 if err := yaml.Unmarshal(fileData, &data); err != nil {
728 return nil, fmt.Errorf("failed to parse YAML: %w", err)
729 }
730 case "markdown", "md":
731 return nil, fmt.Errorf("markdown import for conversations is not yet implemented")
732 default:
733 return nil, fmt.Errorf("unsupported format: %s", format)
734 }
735
736 return &data, nil
737}
738
// detectFormat guesses an import format ("json", "yaml" or "markdown").
//
// The file extension wins when it is recognised. Otherwise the trimmed content
// is sniffed: a leading '{' or '[' means JSON, and a leading "---" or any ':'
// means YAML. Everything else, including empty content, falls back to JSON.
func detectFormat(filename string, data []byte) string {
	switch strings.ToLower(filepath.Ext(filename)) {
	case ".json":
		return "json"
	case ".yaml", ".yml":
		return "yaml"
	case ".md", ".markdown":
		return "markdown"
	}

	body := bytes.TrimSpace(data)
	if len(body) == 0 {
		return "json"
	}

	switch {
	case body[0] == '{', body[0] == '[':
		return "json"
	case bytes.HasPrefix(body, []byte("---")), bytes.IndexByte(body, ':') >= 0:
		return "yaml"
	}
	return "json"
}
764
765func validateImportData(data *ImportData) error {
766 if data == nil {
767 return fmt.Errorf("import data is nil")
768 }
769
770 if len(data.Sessions) == 0 {
771 return fmt.Errorf("no sessions to import")
772 }
773
774 // Validate session structure
775 for i, sess := range data.Sessions {
776 if err := validateImportSessionHierarchy(sess, ""); err != nil {
777 return fmt.Errorf("session %d validation failed: %w", i, err)
778 }
779 }
780
781 return nil
782}
783
784func validateConversationData(data *ConversationData) error {
785 if data == nil {
786 return fmt.Errorf("conversation data is nil")
787 }
788
789 if data.Session.Title == "" {
790 return fmt.Errorf("session title is required")
791 }
792
793 if len(data.Messages) == 0 {
794 return fmt.Errorf("no messages to import")
795 }
796
797 return nil
798}
799
800func validateImportSessionHierarchy(sess ImportSession, expectedParent string) error {
801 if sess.ID == "" {
802 return fmt.Errorf("session ID is required")
803 }
804
805 if sess.Title == "" {
806 return fmt.Errorf("session title is required")
807 }
808
809 // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
810 if expectedParent == "" {
811 if sess.ParentSessionID != "" {
812 return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
813 }
814 } else {
815 // For child sessions, parent should match expected parent
816 if sess.ParentSessionID != expectedParent {
817 return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
818 }
819 }
820
821 // Validate children
822 for _, child := range sess.Children {
823 if err := validateImportSessionHierarchy(child, sess.ID); err != nil {
824 return err
825 }
826 }
827
828 return nil
829}
830
831func validateSessionHierarchy(sess SessionWithChildren, expectedParent string) error {
832 if sess.ID == "" {
833 return fmt.Errorf("session ID is required")
834 }
835
836 if sess.Title == "" {
837 return fmt.Errorf("session title is required")
838 }
839
840 // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
841 if expectedParent == "" {
842 if sess.ParentSessionID != "" {
843 return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
844 }
845 } else {
846 // For child sessions, parent should match expected parent
847 if sess.ParentSessionID != expectedParent {
848 return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
849 }
850 }
851
852 // Validate children
853 for _, child := range sess.Children {
854 if err := validateSessionHierarchy(child, sess.ID); err != nil {
855 return err
856 }
857 }
858
859 return nil
860}
861
862func countTotalImportSessions(sessions []ImportSession) int {
863 count := len(sessions)
864 for _, sess := range sessions {
865 count += countTotalImportSessions(sess.Children)
866 }
867 return count
868}
869
870func countTotalSessions(sessions []SessionWithChildren) int {
871 count := len(sessions)
872 for _, sess := range sessions {
873 count += countTotalSessions(sess.Children)
874 }
875 return count
876}
877
878func importSessions(ctx context.Context, sessionService session.Service, messageService message.Service, data *ImportData) (ImportResult, error) {
879 result := ImportResult{
880 TotalSessions: countTotalImportSessions(data.Sessions),
881 SessionMapping: make(map[string]string),
882 }
883
884 // Import sessions recursively, starting with top-level sessions
885 for _, sess := range data.Sessions {
886 err := importImportSessionWithChildren(ctx, sessionService, messageService, sess, "", &result)
887 if err != nil {
888 result.Errors = append(result.Errors, fmt.Sprintf("failed to import session %s: %v", sess.ID, err))
889 }
890 }
891
892 return result, nil
893}
894
895func importConversation(ctx context.Context, sessionService session.Service, messageService message.Service, data *ConversationData) (string, int, error) {
896 // Generate new session ID
897 newSessionID := uuid.New().String()
898
899 // Create the session using the low-level database API
900 cwd, err := getCwd()
901 if err != nil {
902 return "", 0, err
903 }
904
905 cfg, err := config.Init(cwd, false)
906 if err != nil {
907 return "", 0, err
908 }
909
910 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
911 if err != nil {
912 return "", 0, err
913 }
914
915 queries := db.New(conn)
916
917 // Create session with all original metadata
918 _, err = queries.CreateSession(ctx, db.CreateSessionParams{
919 ID: newSessionID,
920 ParentSessionID: sql.NullString{Valid: false},
921 Title: data.Session.Title,
922 MessageCount: data.Session.MessageCount,
923 PromptTokens: data.Session.PromptTokens,
924 CompletionTokens: data.Session.CompletionTokens,
925 Cost: data.Session.Cost,
926 })
927 if err != nil {
928 return "", 0, fmt.Errorf("failed to create session: %w", err)
929 }
930
931 // Import messages
932 messageCount := 0
933 for _, msg := range data.Messages {
934 // Generate new message ID
935 newMessageID := uuid.New().String()
936
937 // Marshal message parts
938 partsJSON, err := json.Marshal(msg.Parts)
939 if err != nil {
940 return "", 0, fmt.Errorf("failed to marshal message parts: %w", err)
941 }
942
943 // Create message
944 _, err = queries.CreateMessage(ctx, db.CreateMessageParams{
945 ID: newMessageID,
946 SessionID: newSessionID,
947 Role: string(msg.Role),
948 Parts: string(partsJSON),
949 Model: sql.NullString{String: msg.Model, Valid: msg.Model != ""},
950 Provider: sql.NullString{String: msg.Provider, Valid: msg.Provider != ""},
951 })
952 if err != nil {
953 return "", 0, fmt.Errorf("failed to create message: %w", err)
954 }
955 messageCount++
956 }
957
958 return newSessionID, messageCount, nil
959}
960
961func importImportSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess ImportSession, parentID string, result *ImportResult) error {
962 // Generate new session ID
963 newSessionID := uuid.New().String()
964 result.SessionMapping[sess.ID] = newSessionID
965
966 // Create the session using the low-level database API to preserve metadata
967 cwd, err := getCwd()
968 if err != nil {
969 return err
970 }
971
972 cfg, err := config.Init(cwd, false)
973 if err != nil {
974 return err
975 }
976
977 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
978 if err != nil {
979 return err
980 }
981
982 queries := db.New(conn)
983
984 // Create session with all original metadata
985 parentSessionID := sql.NullString{Valid: false}
986 if parentID != "" {
987 parentSessionID = sql.NullString{String: parentID, Valid: true}
988 }
989
990 _, err = queries.CreateSession(ctx, db.CreateSessionParams{
991 ID: newSessionID,
992 ParentSessionID: parentSessionID,
993 Title: sess.Title,
994 MessageCount: sess.MessageCount,
995 PromptTokens: sess.PromptTokens,
996 CompletionTokens: sess.CompletionTokens,
997 Cost: sess.Cost,
998 })
999 if err != nil {
1000 return fmt.Errorf("failed to create session: %w", err)
1001 }
1002
1003 result.ImportedSessions++
1004
1005 // Import children recursively
1006 for _, child := range sess.Children {
1007 err := importImportSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
1008 if err != nil {
1009 result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
1010 }
1011 }
1012
1013 return nil
1014}
1015
1016func importSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess SessionWithChildren, parentID string, result *ImportResult) error {
1017 // Generate new session ID
1018 newSessionID := uuid.New().String()
1019 result.SessionMapping[sess.ID] = newSessionID
1020
1021 // Create the session using the low-level database API to preserve metadata
1022 cwd, err := getCwd()
1023 if err != nil {
1024 return err
1025 }
1026
1027 cfg, err := config.Init(cwd, false)
1028 if err != nil {
1029 return err
1030 }
1031
1032 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
1033 if err != nil {
1034 return err
1035 }
1036
1037 queries := db.New(conn)
1038
1039 // Create session with all original metadata
1040 parentSessionID := sql.NullString{Valid: false}
1041 if parentID != "" {
1042 parentSessionID = sql.NullString{String: parentID, Valid: true}
1043 }
1044
1045 _, err = queries.CreateSession(ctx, db.CreateSessionParams{
1046 ID: newSessionID,
1047 ParentSessionID: parentSessionID,
1048 Title: sess.Title,
1049 MessageCount: sess.MessageCount,
1050 PromptTokens: sess.PromptTokens,
1051 CompletionTokens: sess.CompletionTokens,
1052 Cost: sess.Cost,
1053 })
1054 if err != nil {
1055 return fmt.Errorf("failed to create session: %w", err)
1056 }
1057
1058 result.ImportedSessions++
1059
1060 // Import children recursively
1061 for _, child := range sess.Children {
1062 err := importSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
1063 if err != nil {
1064 result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
1065 }
1066 }
1067
1068 return nil
1069}
1070
1071func runSessionsSearch(ctx context.Context, titlePattern, textPattern, format string) error {
1072 sessionService, err := createSessionService(ctx)
1073 if err != nil {
1074 return err
1075 }
1076
1077 var sessions []session.Session
1078
1079 // Determine which search method to use based on provided patterns
1080 if titlePattern != "" && textPattern != "" {
1081 sessions, err = sessionService.SearchByTitleAndText(ctx, titlePattern, textPattern)
1082 } else if titlePattern != "" {
1083 sessions, err = sessionService.SearchByTitle(ctx, titlePattern)
1084 } else if textPattern != "" {
1085 sessions, err = sessionService.SearchByText(ctx, textPattern)
1086 }
1087
1088 if err != nil {
1089 return fmt.Errorf("search failed: %w", err)
1090 }
1091
1092 return formatSearchResults(sessions, format)
1093}
1094
1095func formatSearchResults(sessions []session.Session, format string) error {
1096 switch strings.ToLower(format) {
1097 case "json":
1098 return formatSearchResultsJSON(sessions)
1099 case "text":
1100 return formatSearchResultsText(sessions)
1101 default:
1102 return fmt.Errorf("unsupported format: %s", format)
1103 }
1104}
1105
1106func formatSearchResultsJSON(sessions []session.Session) error {
1107 data, err := json.MarshalIndent(sessions, "", " ")
1108 if err != nil {
1109 return fmt.Errorf("failed to marshal JSON: %w", err)
1110 }
1111 fmt.Println(string(data))
1112 return nil
1113}
1114
1115func formatSearchResultsText(sessions []session.Session) error {
1116 if len(sessions) == 0 {
1117 fmt.Println("No sessions found matching the search criteria.")
1118 return nil
1119 }
1120
1121 fmt.Printf("Found %d session(s):\n\n", len(sessions))
1122 for _, sess := range sessions {
1123 fmt.Printf("• %s (ID: %s)\n", sess.Title, sess.ID)
1124 fmt.Printf(" Messages: %d, Cost: $%.4f\n", sess.MessageCount, sess.Cost)
1125 fmt.Printf(" Created: %s\n", formatTimestamp(sess.CreatedAt))
1126 if sess.ParentSessionID != "" {
1127 fmt.Printf(" Parent: %s\n", sess.ParentSessionID)
1128 }
1129 fmt.Println()
1130 }
1131
1132 return nil
1133}
1134
1135func runSessionsStats(ctx context.Context, format, groupBy string) error {
1136 // Get database connection
1137 cwd, err := getCwd()
1138 if err != nil {
1139 return err
1140 }
1141
1142 cfg, err := config.Init(cwd, false)
1143 if err != nil {
1144 return err
1145 }
1146
1147 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
1148 if err != nil {
1149 return err
1150 }
1151
1152 queries := db.New(conn)
1153
1154 // Handle grouped statistics
1155 if groupBy != "" {
1156 return runGroupedStats(ctx, queries, format, groupBy)
1157 }
1158
1159 // Get overall statistics
1160 statsRow, err := queries.GetSessionStats(ctx)
1161 if err != nil {
1162 return fmt.Errorf("failed to get session stats: %w", err)
1163 }
1164
1165 // Convert to our struct, handling NULL values
1166 stats := SessionStats{
1167 TotalSessions: statsRow.TotalSessions,
1168 TotalMessages: convertNullFloat64ToInt64(statsRow.TotalMessages),
1169 TotalPromptTokens: convertNullFloat64ToInt64(statsRow.TotalPromptTokens),
1170 TotalCompletionTokens: convertNullFloat64ToInt64(statsRow.TotalCompletionTokens),
1171 TotalCost: convertNullFloat64(statsRow.TotalCost),
1172 AvgCostPerSession: convertNullFloat64(statsRow.AvgCostPerSession),
1173 }
1174
1175 return formatStats(stats, format)
1176}
1177
1178func runGroupedStats(ctx context.Context, queries *db.Queries, format, groupBy string) error {
1179 var groupedStats []GroupedSessionStats
1180
1181 switch strings.ToLower(groupBy) {
1182 case "day":
1183 rows, err := queries.GetSessionStatsByDay(ctx)
1184 if err != nil {
1185 return fmt.Errorf("failed to get daily stats: %w", err)
1186 }
1187 groupedStats = convertDayStatsRows(rows)
1188 case "week":
1189 rows, err := queries.GetSessionStatsByWeek(ctx)
1190 if err != nil {
1191 return fmt.Errorf("failed to get weekly stats: %w", err)
1192 }
1193 groupedStats = convertWeekStatsRows(rows)
1194 case "month":
1195 rows, err := queries.GetSessionStatsByMonth(ctx)
1196 if err != nil {
1197 return fmt.Errorf("failed to get monthly stats: %w", err)
1198 }
1199 groupedStats = convertMonthStatsRows(rows)
1200 default:
1201 return fmt.Errorf("unsupported group-by value: %s. Valid values are: day, week, month", groupBy)
1202 }
1203
1204 return formatGroupedStats(groupedStats, format, groupBy)
1205}
1206
1207func convertNullFloat64(val sql.NullFloat64) float64 {
1208 if val.Valid {
1209 return val.Float64
1210 }
1211 return 0.0
1212}
1213
1214func convertNullFloat64ToInt64(val sql.NullFloat64) int64 {
1215 if val.Valid {
1216 return int64(val.Float64)
1217 }
1218 return 0
1219}
1220
1221func convertDayStatsRows(rows []db.GetSessionStatsByDayRow) []GroupedSessionStats {
1222 result := make([]GroupedSessionStats, 0, len(rows))
1223 for _, row := range rows {
1224 stats := GroupedSessionStats{
1225 Period: fmt.Sprintf("%v", row.Day),
1226 SessionCount: row.SessionCount,
1227 MessageCount: convertNullFloat64ToInt64(row.MessageCount),
1228 PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
1229 CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
1230 TotalCost: convertNullFloat64(row.TotalCost),
1231 AvgCost: convertNullFloat64(row.AvgCost),
1232 }
1233 result = append(result, stats)
1234 }
1235 return result
1236}
1237
1238func convertWeekStatsRows(rows []db.GetSessionStatsByWeekRow) []GroupedSessionStats {
1239 result := make([]GroupedSessionStats, 0, len(rows))
1240 for _, row := range rows {
1241 stats := GroupedSessionStats{
1242 Period: fmt.Sprintf("%v", row.WeekStart),
1243 SessionCount: row.SessionCount,
1244 MessageCount: convertNullFloat64ToInt64(row.MessageCount),
1245 PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
1246 CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
1247 TotalCost: convertNullFloat64(row.TotalCost),
1248 AvgCost: convertNullFloat64(row.AvgCost),
1249 }
1250 result = append(result, stats)
1251 }
1252 return result
1253}
1254
1255func convertMonthStatsRows(rows []db.GetSessionStatsByMonthRow) []GroupedSessionStats {
1256 result := make([]GroupedSessionStats, 0, len(rows))
1257 for _, row := range rows {
1258 stats := GroupedSessionStats{
1259 Period: fmt.Sprintf("%v", row.Month),
1260 SessionCount: row.SessionCount,
1261 MessageCount: convertNullFloat64ToInt64(row.MessageCount),
1262 PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
1263 CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
1264 TotalCost: convertNullFloat64(row.TotalCost),
1265 AvgCost: convertNullFloat64(row.AvgCost),
1266 }
1267 result = append(result, stats)
1268 }
1269 return result
1270}
1271
1272func formatStats(stats SessionStats, format string) error {
1273 switch strings.ToLower(format) {
1274 case "json":
1275 return formatStatsJSON(stats)
1276 case "text":
1277 return formatStatsText(stats)
1278 default:
1279 return fmt.Errorf("unsupported format: %s", format)
1280 }
1281}
1282
1283func formatGroupedStats(stats []GroupedSessionStats, format, groupBy string) error {
1284 switch strings.ToLower(format) {
1285 case "json":
1286 return formatGroupedStatsJSON(stats)
1287 case "text":
1288 return formatGroupedStatsText(stats, groupBy)
1289 default:
1290 return fmt.Errorf("unsupported format: %s", format)
1291 }
1292}
1293
1294func formatStatsJSON(stats SessionStats) error {
1295 data, err := json.MarshalIndent(stats, "", " ")
1296 if err != nil {
1297 return fmt.Errorf("failed to marshal JSON: %w", err)
1298 }
1299 fmt.Println(string(data))
1300 return nil
1301}
1302
1303func formatStatsText(stats SessionStats) error {
1304 if stats.TotalSessions == 0 {
1305 fmt.Println("No sessions found.")
1306 return nil
1307 }
1308
1309 fmt.Println("Session Statistics")
1310 fmt.Println("==================")
1311 fmt.Printf("Total Sessions: %d\n", stats.TotalSessions)
1312 fmt.Printf("Total Messages: %d\n", stats.TotalMessages)
1313 fmt.Printf("Total Prompt Tokens: %d\n", stats.TotalPromptTokens)
1314 fmt.Printf("Total Completion Tokens: %d\n", stats.TotalCompletionTokens)
1315 fmt.Printf("Total Cost: $%.4f\n", stats.TotalCost)
1316 fmt.Printf("Average Cost/Session: $%.4f\n", stats.AvgCostPerSession)
1317
1318 totalTokens := stats.TotalPromptTokens + stats.TotalCompletionTokens
1319 if totalTokens > 0 {
1320 fmt.Printf("Total Tokens: %d\n", totalTokens)
1321 fmt.Printf("Average Tokens/Session: %.1f\n", float64(totalTokens)/float64(stats.TotalSessions))
1322 }
1323
1324 if stats.TotalSessions > 0 {
1325 fmt.Printf("Average Messages/Session: %.1f\n", float64(stats.TotalMessages)/float64(stats.TotalSessions))
1326 }
1327
1328 return nil
1329}
1330
1331func formatGroupedStatsJSON(stats []GroupedSessionStats) error {
1332 data, err := json.MarshalIndent(stats, "", " ")
1333 if err != nil {
1334 return fmt.Errorf("failed to marshal JSON: %w", err)
1335 }
1336 fmt.Println(string(data))
1337 return nil
1338}
1339
1340func formatGroupedStatsText(stats []GroupedSessionStats, groupBy string) error {
1341 if len(stats) == 0 {
1342 fmt.Printf("No sessions found for grouping by %s.\n", groupBy)
1343 return nil
1344 }
1345
1346 fmt.Printf("Session Statistics (Grouped by %s)\n", strings.ToUpper(groupBy[:1])+groupBy[1:])
1347 fmt.Println(strings.Repeat("=", 30+len(groupBy)))
1348 fmt.Println()
1349
1350 for _, stat := range stats {
1351 fmt.Printf("Period: %s\n", stat.Period)
1352 fmt.Printf(" Sessions: %d\n", stat.SessionCount)
1353 fmt.Printf(" Messages: %d\n", stat.MessageCount)
1354 fmt.Printf(" Prompt Tokens: %d\n", stat.PromptTokens)
1355 fmt.Printf(" Completion Tokens: %d\n", stat.CompletionTokens)
1356 fmt.Printf(" Total Cost: $%.4f\n", stat.TotalCost)
1357 fmt.Printf(" Average Cost: $%.4f\n", stat.AvgCost)
1358 totalTokens := stat.PromptTokens + stat.CompletionTokens
1359 if totalTokens > 0 {
1360 fmt.Printf(" Total Tokens: %d\n", totalTokens)
1361 }
1362 fmt.Println()
1363 }
1364
1365 return nil
1366}