Detailed changes
@@ -1,10 +1,13 @@
package cmd
import (
+ "bytes"
"context"
+ "database/sql"
"encoding/json"
"fmt"
"os"
+ "path/filepath"
"strings"
"time"
@@ -12,6 +15,7 @@ import (
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
+ "github.com/google/uuid"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
@@ -22,6 +26,89 @@ type SessionWithChildren struct {
Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"`
}
+// ImportSession represents a session with proper JSON tags for import
+type ImportSession struct {
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ SummaryMessageID string `json:"summary_message_id,omitempty"`
+ Children []ImportSession `json:"children,omitempty"`
+}
+
+// ImportData represents the full import structure for sessions
+type ImportData struct {
+ Version string `json:"version" yaml:"version"`
+ ExportedAt string `json:"exported_at,omitempty" yaml:"exported_at,omitempty"`
+ TotalSessions int `json:"total_sessions,omitempty" yaml:"total_sessions,omitempty"`
+ Sessions []ImportSession `json:"sessions" yaml:"sessions"`
+}
+
+// ImportMessage represents a message with proper JSON tags for import
+type ImportMessage struct {
+ ID string `json:"id"`
+ Role string `json:"role"`
+ SessionID string `json:"session_id"`
+ Parts []interface{} `json:"parts"`
+ Model string `json:"model,omitempty"`
+ Provider string `json:"provider,omitempty"`
+ CreatedAt int64 `json:"created_at"`
+}
+
+// ImportSessionInfo represents session info with proper JSON tags for conversation import
+type ImportSessionInfo struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens,omitempty"`
+ CompletionTokens int64 `json:"completion_tokens,omitempty"`
+ Cost float64 `json:"cost,omitempty"`
+ CreatedAt int64 `json:"created_at"`
+}
+
+// ConversationData represents a single conversation import structure
+type ConversationData struct {
+ Version string `json:"version" yaml:"version"`
+ Session ImportSessionInfo `json:"session" yaml:"session"`
+ Messages []ImportMessage `json:"messages" yaml:"messages"`
+}
+
+// ImportResult contains the results of an import operation
+type ImportResult struct {
+ TotalSessions int `json:"total_sessions"`
+ ImportedSessions int `json:"imported_sessions"`
+ SkippedSessions int `json:"skipped_sessions"`
+ ImportedMessages int `json:"imported_messages"`
+ Errors []string `json:"errors,omitempty"`
+ SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
+}
+
+// SessionStats represents aggregated session statistics
+type SessionStats struct {
+ TotalSessions int64 `json:"total_sessions"`
+ TotalMessages int64 `json:"total_messages"`
+ TotalPromptTokens int64 `json:"total_prompt_tokens"`
+ TotalCompletionTokens int64 `json:"total_completion_tokens"`
+ TotalCost float64 `json:"total_cost"`
+ AvgCostPerSession float64 `json:"avg_cost_per_session"`
+}
+
+// GroupedSessionStats represents statistics grouped by time period
+type GroupedSessionStats struct {
+ Period string `json:"period"`
+ SessionCount int64 `json:"session_count"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ TotalCost float64 `json:"total_cost"`
+ AvgCost float64 `json:"avg_cost"`
+}
+
var sessionsCmd = &cobra.Command{
Use: "sessions",
Short: "Manage sessions",
@@ -60,15 +147,80 @@ var exportConversationCmd = &cobra.Command{
},
}
+var importCmd = &cobra.Command{
+ Use: "import <file>",
+ Short: "Import sessions from a file",
+ Long: `Import sessions from a JSON or YAML file with hierarchical structure`,
+ Args: cobra.ExactArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ file := args[0]
+ format, _ := cmd.Flags().GetString("format")
+ dryRun, _ := cmd.Flags().GetBool("dry-run")
+ return runImport(cmd.Context(), file, format, dryRun)
+ },
+}
+
+var importConversationCmd = &cobra.Command{
+ Use: "import-conversation <file>",
+ Short: "Import a single conversation from a file",
+ Long: `Import a single conversation with messages from a JSON, YAML, or Markdown file`,
+ Args: cobra.ExactArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ file := args[0]
+ format, _ := cmd.Flags().GetString("format")
+ return runImportConversation(cmd.Context(), file, format)
+ },
+}
+
+var searchCmd = &cobra.Command{
+ Use: "search",
+ Short: "Search sessions by title or message content",
+ Long: `Search sessions by title pattern (case-insensitive) or message text content`,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ titlePattern, _ := cmd.Flags().GetString("title")
+ textPattern, _ := cmd.Flags().GetString("text")
+ format, _ := cmd.Flags().GetString("format")
+
+ if titlePattern == "" && textPattern == "" {
+ return fmt.Errorf("at least one of --title or --text must be provided")
+ }
+
+ return runSessionsSearch(cmd.Context(), titlePattern, textPattern, format)
+ },
+}
+
+var statsCmd = &cobra.Command{
+ Use: "stats",
+ Short: "Show session statistics",
+ Long: `Display aggregated statistics about sessions including total counts, tokens, and costs`,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ format, _ := cmd.Flags().GetString("format")
+ groupBy, _ := cmd.Flags().GetString("group-by")
+ return runSessionsStats(cmd.Context(), format, groupBy)
+ },
+}
+
func init() {
rootCmd.AddCommand(sessionsCmd)
sessionsCmd.AddCommand(listCmd)
sessionsCmd.AddCommand(exportCmd)
sessionsCmd.AddCommand(exportConversationCmd)
+ sessionsCmd.AddCommand(importCmd)
+ sessionsCmd.AddCommand(importConversationCmd)
+ sessionsCmd.AddCommand(searchCmd)
+ sessionsCmd.AddCommand(statsCmd)
listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)")
exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)")
exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)")
+ importCmd.Flags().StringP("format", "f", "", "Import format (json, yaml) - auto-detected if not specified")
+ importCmd.Flags().Bool("dry-run", false, "Validate import data without persisting changes")
+ importConversationCmd.Flags().StringP("format", "f", "", "Import format (json, yaml, markdown) - auto-detected if not specified")
+ searchCmd.Flags().String("title", "", "Search by session title pattern (case-insensitive substring search)")
+ searchCmd.Flags().String("text", "", "Search by message text content")
+ searchCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
+ statsCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
+ statsCmd.Flags().String("group-by", "", "Group statistics by time period (day, week, month)")
}
func runSessionsList(ctx context.Context, format string) error {
@@ -442,3 +594,773 @@ func formatConversationYAML(sess session.Session, messages []message.Message) er
fmt.Println(string(yamlData))
return nil
}
+
+func runImport(ctx context.Context, file, format string, dryRun bool) error {
+ // Read the file
+ data, err := readImportFile(file, format)
+ if err != nil {
+ return fmt.Errorf("failed to read import file: %w", err)
+ }
+
+ // Validate the data structure
+ if err := validateImportData(data); err != nil {
+ return fmt.Errorf("invalid import data: %w", err)
+ }
+
+ if dryRun {
+ result := ImportResult{
+ TotalSessions: countTotalImportSessions(data.Sessions),
+ ImportedSessions: 0,
+ SkippedSessions: 0,
+ ImportedMessages: 0,
+ SessionMapping: make(map[string]string),
+ }
+ fmt.Printf("Dry run: Would import %d sessions\n", result.TotalSessions)
+ return nil
+ }
+
+ // Perform the actual import
+ sessionService, messageService, err := createServices(ctx)
+ if err != nil {
+ return err
+ }
+
+ result, err := importSessions(ctx, sessionService, messageService, data)
+ if err != nil {
+ return fmt.Errorf("import failed: %w", err)
+ }
+
+ // Print summary
+ fmt.Printf("Import completed successfully:\n")
+ fmt.Printf(" Total sessions processed: %d\n", result.TotalSessions)
+ fmt.Printf(" Sessions imported: %d\n", result.ImportedSessions)
+ fmt.Printf(" Sessions skipped: %d\n", result.SkippedSessions)
+ fmt.Printf(" Messages imported: %d\n", result.ImportedMessages)
+
+ if len(result.Errors) > 0 {
+ fmt.Printf(" Errors encountered: %d\n", len(result.Errors))
+ for _, errStr := range result.Errors {
+ fmt.Printf(" - %s\n", errStr)
+ }
+ }
+
+ return nil
+}
+
+func runImportConversation(ctx context.Context, file, format string) error {
+ // Read the conversation file
+ convData, err := readConversationFile(file, format)
+ if err != nil {
+ return fmt.Errorf("failed to read conversation file: %w", err)
+ }
+
+ // Validate the conversation data
+ if err := validateConversationData(convData); err != nil {
+ return fmt.Errorf("invalid conversation data: %w", err)
+ }
+
+ // Import the conversation
+ sessionService, messageService, err := createServices(ctx)
+ if err != nil {
+ return err
+ }
+
+ newSessionID, messageCount, err := importConversation(ctx, sessionService, messageService, convData)
+ if err != nil {
+ return fmt.Errorf("conversation import failed: %w", err)
+ }
+
+ fmt.Printf("Conversation imported successfully:\n")
+ fmt.Printf(" Session ID: %s\n", newSessionID)
+ fmt.Printf(" Title: %s\n", convData.Session.Title)
+ fmt.Printf(" Messages imported: %d\n", messageCount)
+
+ return nil
+}
+
+func readImportFile(file, format string) (*ImportData, error) {
+ fileData, err := os.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read file %s: %w", file, err)
+ }
+
+ // Auto-detect format if not specified
+ if format == "" {
+ format = detectFormat(file, fileData)
+ }
+
+ var data ImportData
+ switch strings.ToLower(format) {
+ case "json":
+ if err := json.Unmarshal(fileData, &data); err != nil {
+ return nil, fmt.Errorf("failed to parse JSON: %w", err)
+ }
+ case "yaml", "yml":
+ if err := yaml.Unmarshal(fileData, &data); err != nil {
+ return nil, fmt.Errorf("failed to parse YAML: %w", err)
+ }
+ default:
+ return nil, fmt.Errorf("unsupported format: %s", format)
+ }
+
+ return &data, nil
+}
+
+func readConversationFile(file, format string) (*ConversationData, error) {
+ fileData, err := os.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read file %s: %w", file, err)
+ }
+
+ // Auto-detect format if not specified
+ if format == "" {
+ format = detectFormat(file, fileData)
+ }
+
+ var data ConversationData
+ switch strings.ToLower(format) {
+ case "json":
+ if err := json.Unmarshal(fileData, &data); err != nil {
+ return nil, fmt.Errorf("failed to parse JSON: %w", err)
+ }
+ case "yaml", "yml":
+ if err := yaml.Unmarshal(fileData, &data); err != nil {
+ return nil, fmt.Errorf("failed to parse YAML: %w", err)
+ }
+ case "markdown", "md":
+ return nil, fmt.Errorf("markdown import for conversations is not yet implemented")
+ default:
+ return nil, fmt.Errorf("unsupported format: %s", format)
+ }
+
+ return &data, nil
+}
+
+func detectFormat(filename string, data []byte) string {
+ // First try file extension
+ ext := strings.ToLower(filepath.Ext(filename))
+ switch ext {
+ case ".json":
+ return "json"
+ case ".yaml", ".yml":
+ return "yaml"
+ case ".md", ".markdown":
+ return "markdown"
+ }
+
+ // Try to detect from content
+ data = bytes.TrimSpace(data)
+ if len(data) > 0 {
+ if data[0] == '{' || data[0] == '[' {
+ return "json"
+ }
+ if strings.HasPrefix(string(data), "---") || strings.Contains(string(data), ":") {
+ return "yaml"
+ }
+ }
+
+ return "json" // default fallback
+}
+
+func validateImportData(data *ImportData) error {
+ if data == nil {
+ return fmt.Errorf("import data is nil")
+ }
+
+ if len(data.Sessions) == 0 {
+ return fmt.Errorf("no sessions to import")
+ }
+
+ // Validate session structure
+ for i, sess := range data.Sessions {
+ if err := validateImportSessionHierarchy(sess, ""); err != nil {
+ return fmt.Errorf("session %d validation failed: %w", i, err)
+ }
+ }
+
+ return nil
+}
+
+func validateConversationData(data *ConversationData) error {
+ if data == nil {
+ return fmt.Errorf("conversation data is nil")
+ }
+
+ if data.Session.Title == "" {
+ return fmt.Errorf("session title is required")
+ }
+
+ if len(data.Messages) == 0 {
+ return fmt.Errorf("no messages to import")
+ }
+
+ return nil
+}
+
+func validateImportSessionHierarchy(sess ImportSession, expectedParent string) error {
+ if sess.ID == "" {
+ return fmt.Errorf("session ID is required")
+ }
+
+ if sess.Title == "" {
+ return fmt.Errorf("session title is required")
+ }
+
+ // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
+ if expectedParent == "" {
+ if sess.ParentSessionID != "" {
+ return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
+ }
+ } else {
+ // For child sessions, parent should match expected parent
+ if sess.ParentSessionID != expectedParent {
+ return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
+ }
+ }
+
+ // Validate children
+ for _, child := range sess.Children {
+ if err := validateImportSessionHierarchy(child, sess.ID); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func validateSessionHierarchy(sess SessionWithChildren, expectedParent string) error {
+ if sess.ID == "" {
+ return fmt.Errorf("session ID is required")
+ }
+
+ if sess.Title == "" {
+ return fmt.Errorf("session title is required")
+ }
+
+ // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
+ if expectedParent == "" {
+ if sess.ParentSessionID != "" {
+ return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
+ }
+ } else {
+ // For child sessions, parent should match expected parent
+ if sess.ParentSessionID != expectedParent {
+ return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
+ }
+ }
+
+ // Validate children
+ for _, child := range sess.Children {
+ if err := validateSessionHierarchy(child, sess.ID); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func countTotalImportSessions(sessions []ImportSession) int {
+ count := len(sessions)
+ for _, sess := range sessions {
+ count += countTotalImportSessions(sess.Children)
+ }
+ return count
+}
+
+func countTotalSessions(sessions []SessionWithChildren) int {
+ count := len(sessions)
+ for _, sess := range sessions {
+ count += countTotalSessions(sess.Children)
+ }
+ return count
+}
+
+func importSessions(ctx context.Context, sessionService session.Service, messageService message.Service, data *ImportData) (ImportResult, error) {
+ result := ImportResult{
+ TotalSessions: countTotalImportSessions(data.Sessions),
+ SessionMapping: make(map[string]string),
+ }
+
+ // Import sessions recursively, starting with top-level sessions
+ for _, sess := range data.Sessions {
+ err := importImportSessionWithChildren(ctx, sessionService, messageService, sess, "", &result)
+ if err != nil {
+ result.Errors = append(result.Errors, fmt.Sprintf("failed to import session %s: %v", sess.ID, err))
+ }
+ }
+
+ return result, nil
+}
+
+func importConversation(ctx context.Context, sessionService session.Service, messageService message.Service, data *ConversationData) (string, int, error) {
+ // Generate new session ID
+ newSessionID := uuid.New().String()
+
+ // Create the session using the low-level database API
+ cwd, err := getCwd()
+ if err != nil {
+ return "", 0, err
+ }
+
+ cfg, err := config.Init(cwd, false)
+ if err != nil {
+ return "", 0, err
+ }
+
+ conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ return "", 0, err
+ }
+
+ queries := db.New(conn)
+
+ // Create session with all original metadata
+ _, err = queries.CreateSession(ctx, db.CreateSessionParams{
+ ID: newSessionID,
+ ParentSessionID: sql.NullString{Valid: false},
+ Title: data.Session.Title,
+ MessageCount: data.Session.MessageCount,
+ PromptTokens: data.Session.PromptTokens,
+ CompletionTokens: data.Session.CompletionTokens,
+ Cost: data.Session.Cost,
+ })
+ if err != nil {
+ return "", 0, fmt.Errorf("failed to create session: %w", err)
+ }
+
+ // Import messages
+ messageCount := 0
+ for _, msg := range data.Messages {
+ // Generate new message ID
+ newMessageID := uuid.New().String()
+
+ // Marshal message parts
+ partsJSON, err := json.Marshal(msg.Parts)
+ if err != nil {
+ return "", 0, fmt.Errorf("failed to marshal message parts: %w", err)
+ }
+
+ // Create message
+ _, err = queries.CreateMessage(ctx, db.CreateMessageParams{
+ ID: newMessageID,
+ SessionID: newSessionID,
+ Role: string(msg.Role),
+ Parts: string(partsJSON),
+ Model: sql.NullString{String: msg.Model, Valid: msg.Model != ""},
+ Provider: sql.NullString{String: msg.Provider, Valid: msg.Provider != ""},
+ })
+ if err != nil {
+ return "", 0, fmt.Errorf("failed to create message: %w", err)
+ }
+ messageCount++
+ }
+
+ return newSessionID, messageCount, nil
+}
+
+func importImportSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess ImportSession, parentID string, result *ImportResult) error {
+ // Generate new session ID
+ newSessionID := uuid.New().String()
+ result.SessionMapping[sess.ID] = newSessionID
+
+ // Create the session using the low-level database API to preserve metadata
+ cwd, err := getCwd()
+ if err != nil {
+ return err
+ }
+
+ cfg, err := config.Init(cwd, false)
+ if err != nil {
+ return err
+ }
+
+ conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ return err
+ }
+
+ queries := db.New(conn)
+
+ // Create session with all original metadata
+ parentSessionID := sql.NullString{Valid: false}
+ if parentID != "" {
+ parentSessionID = sql.NullString{String: parentID, Valid: true}
+ }
+
+ _, err = queries.CreateSession(ctx, db.CreateSessionParams{
+ ID: newSessionID,
+ ParentSessionID: parentSessionID,
+ Title: sess.Title,
+ MessageCount: sess.MessageCount,
+ PromptTokens: sess.PromptTokens,
+ CompletionTokens: sess.CompletionTokens,
+ Cost: sess.Cost,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create session: %w", err)
+ }
+
+ result.ImportedSessions++
+
+ // Import children recursively
+ for _, child := range sess.Children {
+ err := importImportSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
+ if err != nil {
+ result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
+ }
+ }
+
+ return nil
+}
+
+func importSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess SessionWithChildren, parentID string, result *ImportResult) error {
+ // Generate new session ID
+ newSessionID := uuid.New().String()
+ result.SessionMapping[sess.ID] = newSessionID
+
+ // Create the session using the low-level database API to preserve metadata
+ cwd, err := getCwd()
+ if err != nil {
+ return err
+ }
+
+ cfg, err := config.Init(cwd, false)
+ if err != nil {
+ return err
+ }
+
+ conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ return err
+ }
+
+ queries := db.New(conn)
+
+ // Create session with all original metadata
+ parentSessionID := sql.NullString{Valid: false}
+ if parentID != "" {
+ parentSessionID = sql.NullString{String: parentID, Valid: true}
+ }
+
+ _, err = queries.CreateSession(ctx, db.CreateSessionParams{
+ ID: newSessionID,
+ ParentSessionID: parentSessionID,
+ Title: sess.Title,
+ MessageCount: sess.MessageCount,
+ PromptTokens: sess.PromptTokens,
+ CompletionTokens: sess.CompletionTokens,
+ Cost: sess.Cost,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create session: %w", err)
+ }
+
+ result.ImportedSessions++
+
+ // Import children recursively
+ for _, child := range sess.Children {
+ err := importSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
+ if err != nil {
+ result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
+ }
+ }
+
+ return nil
+}
+
+func runSessionsSearch(ctx context.Context, titlePattern, textPattern, format string) error {
+ sessionService, err := createSessionService(ctx)
+ if err != nil {
+ return err
+ }
+
+ var sessions []session.Session
+
+ // Determine which search method to use based on provided patterns
+ if titlePattern != "" && textPattern != "" {
+ sessions, err = sessionService.SearchByTitleAndText(ctx, titlePattern, textPattern)
+ } else if titlePattern != "" {
+ sessions, err = sessionService.SearchByTitle(ctx, titlePattern)
+ } else if textPattern != "" {
+ sessions, err = sessionService.SearchByText(ctx, textPattern)
+ }
+
+ if err != nil {
+ return fmt.Errorf("search failed: %w", err)
+ }
+
+ return formatSearchResults(sessions, format)
+}
+
+func formatSearchResults(sessions []session.Session, format string) error {
+ switch strings.ToLower(format) {
+ case "json":
+ return formatSearchResultsJSON(sessions)
+ case "text":
+ return formatSearchResultsText(sessions)
+ default:
+ return fmt.Errorf("unsupported format: %s", format)
+ }
+}
+
+func formatSearchResultsJSON(sessions []session.Session) error {
+ data, err := json.MarshalIndent(sessions, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal JSON: %w", err)
+ }
+ fmt.Println(string(data))
+ return nil
+}
+
+func formatSearchResultsText(sessions []session.Session) error {
+ if len(sessions) == 0 {
+ fmt.Println("No sessions found matching the search criteria.")
+ return nil
+ }
+
+ fmt.Printf("Found %d session(s):\n\n", len(sessions))
+ for _, sess := range sessions {
+ fmt.Printf("• %s (ID: %s)\n", sess.Title, sess.ID)
+ fmt.Printf(" Messages: %d, Cost: $%.4f\n", sess.MessageCount, sess.Cost)
+ fmt.Printf(" Created: %s\n", formatTimestamp(sess.CreatedAt))
+ if sess.ParentSessionID != "" {
+ fmt.Printf(" Parent: %s\n", sess.ParentSessionID)
+ }
+ fmt.Println()
+ }
+
+ return nil
+}
+
+func runSessionsStats(ctx context.Context, format, groupBy string) error {
+ // Get database connection
+ cwd, err := getCwd()
+ if err != nil {
+ return err
+ }
+
+ cfg, err := config.Init(cwd, false)
+ if err != nil {
+ return err
+ }
+
+ conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ return err
+ }
+
+ queries := db.New(conn)
+
+ // Handle grouped statistics
+ if groupBy != "" {
+ return runGroupedStats(ctx, queries, format, groupBy)
+ }
+
+ // Get overall statistics
+ statsRow, err := queries.GetSessionStats(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get session stats: %w", err)
+ }
+
+ // Convert to our struct, handling NULL values
+ stats := SessionStats{
+ TotalSessions: statsRow.TotalSessions,
+ TotalMessages: convertNullFloat64ToInt64(statsRow.TotalMessages),
+ TotalPromptTokens: convertNullFloat64ToInt64(statsRow.TotalPromptTokens),
+ TotalCompletionTokens: convertNullFloat64ToInt64(statsRow.TotalCompletionTokens),
+ TotalCost: convertNullFloat64(statsRow.TotalCost),
+ AvgCostPerSession: convertNullFloat64(statsRow.AvgCostPerSession),
+ }
+
+ return formatStats(stats, format)
+}
+
+func runGroupedStats(ctx context.Context, queries *db.Queries, format, groupBy string) error {
+ var groupedStats []GroupedSessionStats
+
+ switch strings.ToLower(groupBy) {
+ case "day":
+ rows, err := queries.GetSessionStatsByDay(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get daily stats: %w", err)
+ }
+ groupedStats = convertDayStatsRows(rows)
+ case "week":
+ rows, err := queries.GetSessionStatsByWeek(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get weekly stats: %w", err)
+ }
+ groupedStats = convertWeekStatsRows(rows)
+ case "month":
+ rows, err := queries.GetSessionStatsByMonth(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get monthly stats: %w", err)
+ }
+ groupedStats = convertMonthStatsRows(rows)
+ default:
+ return fmt.Errorf("unsupported group-by value: %s. Valid values are: day, week, month", groupBy)
+ }
+
+ return formatGroupedStats(groupedStats, format, groupBy)
+}
+
+func convertNullFloat64(val sql.NullFloat64) float64 {
+ if val.Valid {
+ return val.Float64
+ }
+ return 0.0
+}
+
+func convertNullFloat64ToInt64(val sql.NullFloat64) int64 {
+ if val.Valid {
+ return int64(val.Float64)
+ }
+ return 0
+}
+
+func convertDayStatsRows(rows []db.GetSessionStatsByDayRow) []GroupedSessionStats {
+ result := make([]GroupedSessionStats, 0, len(rows))
+ for _, row := range rows {
+ stats := GroupedSessionStats{
+ Period: fmt.Sprintf("%v", row.Day),
+ SessionCount: row.SessionCount,
+ MessageCount: convertNullFloat64ToInt64(row.MessageCount),
+ PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
+ CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
+ TotalCost: convertNullFloat64(row.TotalCost),
+ AvgCost: convertNullFloat64(row.AvgCost),
+ }
+ result = append(result, stats)
+ }
+ return result
+}
+
+func convertWeekStatsRows(rows []db.GetSessionStatsByWeekRow) []GroupedSessionStats {
+ result := make([]GroupedSessionStats, 0, len(rows))
+ for _, row := range rows {
+ stats := GroupedSessionStats{
+ Period: fmt.Sprintf("%v", row.WeekStart),
+ SessionCount: row.SessionCount,
+ MessageCount: convertNullFloat64ToInt64(row.MessageCount),
+ PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
+ CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
+ TotalCost: convertNullFloat64(row.TotalCost),
+ AvgCost: convertNullFloat64(row.AvgCost),
+ }
+ result = append(result, stats)
+ }
+ return result
+}
+
+func convertMonthStatsRows(rows []db.GetSessionStatsByMonthRow) []GroupedSessionStats {
+ result := make([]GroupedSessionStats, 0, len(rows))
+ for _, row := range rows {
+ stats := GroupedSessionStats{
+ Period: fmt.Sprintf("%v", row.Month),
+ SessionCount: row.SessionCount,
+ MessageCount: convertNullFloat64ToInt64(row.MessageCount),
+ PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
+ CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
+ TotalCost: convertNullFloat64(row.TotalCost),
+ AvgCost: convertNullFloat64(row.AvgCost),
+ }
+ result = append(result, stats)
+ }
+ return result
+}
+
+func formatStats(stats SessionStats, format string) error {
+ switch strings.ToLower(format) {
+ case "json":
+ return formatStatsJSON(stats)
+ case "text":
+ return formatStatsText(stats)
+ default:
+ return fmt.Errorf("unsupported format: %s", format)
+ }
+}
+
+func formatGroupedStats(stats []GroupedSessionStats, format, groupBy string) error {
+ switch strings.ToLower(format) {
+ case "json":
+ return formatGroupedStatsJSON(stats)
+ case "text":
+ return formatGroupedStatsText(stats, groupBy)
+ default:
+ return fmt.Errorf("unsupported format: %s", format)
+ }
+}
+
+func formatStatsJSON(stats SessionStats) error {
+ data, err := json.MarshalIndent(stats, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal JSON: %w", err)
+ }
+ fmt.Println(string(data))
+ return nil
+}
+
+func formatStatsText(stats SessionStats) error {
+ if stats.TotalSessions == 0 {
+ fmt.Println("No sessions found.")
+ return nil
+ }
+
+ fmt.Println("Session Statistics")
+ fmt.Println("==================")
+ fmt.Printf("Total Sessions: %d\n", stats.TotalSessions)
+ fmt.Printf("Total Messages: %d\n", stats.TotalMessages)
+ fmt.Printf("Total Prompt Tokens: %d\n", stats.TotalPromptTokens)
+ fmt.Printf("Total Completion Tokens: %d\n", stats.TotalCompletionTokens)
+ fmt.Printf("Total Cost: $%.4f\n", stats.TotalCost)
+ fmt.Printf("Average Cost/Session: $%.4f\n", stats.AvgCostPerSession)
+
+ totalTokens := stats.TotalPromptTokens + stats.TotalCompletionTokens
+ if totalTokens > 0 {
+ fmt.Printf("Total Tokens: %d\n", totalTokens)
+ fmt.Printf("Average Tokens/Session: %.1f\n", float64(totalTokens)/float64(stats.TotalSessions))
+ }
+
+ if stats.TotalSessions > 0 {
+ fmt.Printf("Average Messages/Session: %.1f\n", float64(stats.TotalMessages)/float64(stats.TotalSessions))
+ }
+
+ return nil
+}
+
+func formatGroupedStatsJSON(stats []GroupedSessionStats) error {
+ data, err := json.MarshalIndent(stats, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal JSON: %w", err)
+ }
+ fmt.Println(string(data))
+ return nil
+}
+
+func formatGroupedStatsText(stats []GroupedSessionStats, groupBy string) error {
+ if len(stats) == 0 {
+ fmt.Printf("No sessions found for grouping by %s.\n", groupBy)
+ return nil
+ }
+
+ fmt.Printf("Session Statistics (Grouped by %s)\n", strings.ToUpper(groupBy[:1])+groupBy[1:])
+ fmt.Println(strings.Repeat("=", 30+len(groupBy)))
+ fmt.Println()
+
+ for _, stat := range stats {
+ fmt.Printf("Period: %s\n", stat.Period)
+ fmt.Printf(" Sessions: %d\n", stat.SessionCount)
+ fmt.Printf(" Messages: %d\n", stat.MessageCount)
+ fmt.Printf(" Prompt Tokens: %d\n", stat.PromptTokens)
+ fmt.Printf(" Completion Tokens: %d\n", stat.CompletionTokens)
+ fmt.Printf(" Total Cost: $%.4f\n", stat.TotalCost)
+ fmt.Printf(" Average Cost: $%.4f\n", stat.AvgCost)
+ totalTokens := stat.PromptTokens + stat.CompletionTokens
+ if totalTokens > 0 {
+ fmt.Printf(" Total Tokens: %d\n", totalTokens)
+ }
+ fmt.Println()
+ }
+
+ return nil
+}
@@ -60,6 +60,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
}
+ if q.getSessionStatsStmt, err = db.PrepareContext(ctx, getSessionStats); err != nil {
+ return nil, fmt.Errorf("error preparing query GetSessionStats: %w", err)
+ }
+ if q.getSessionStatsByDayStmt, err = db.PrepareContext(ctx, getSessionStatsByDay); err != nil {
+ return nil, fmt.Errorf("error preparing query GetSessionStatsByDay: %w", err)
+ }
+ if q.getSessionStatsByMonthStmt, err = db.PrepareContext(ctx, getSessionStatsByMonth); err != nil {
+ return nil, fmt.Errorf("error preparing query GetSessionStatsByMonth: %w", err)
+ }
+ if q.getSessionStatsByWeekStmt, err = db.PrepareContext(ctx, getSessionStatsByWeek); err != nil {
+ return nil, fmt.Errorf("error preparing query GetSessionStatsByWeek: %w", err)
+ }
if q.listAllSessionsStmt, err = db.PrepareContext(ctx, listAllSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListAllSessions: %w", err)
}
@@ -84,6 +96,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
}
+ if q.searchSessionsByTextStmt, err = db.PrepareContext(ctx, searchSessionsByText); err != nil {
+ return nil, fmt.Errorf("error preparing query SearchSessionsByText: %w", err)
+ }
+ if q.searchSessionsByTitleStmt, err = db.PrepareContext(ctx, searchSessionsByTitle); err != nil {
+ return nil, fmt.Errorf("error preparing query SearchSessionsByTitle: %w", err)
+ }
+ if q.searchSessionsByTitleAndTextStmt, err = db.PrepareContext(ctx, searchSessionsByTitleAndText); err != nil {
+ return nil, fmt.Errorf("error preparing query SearchSessionsByTitleAndText: %w", err)
+ }
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
}
@@ -155,6 +176,26 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
}
}
+ if q.getSessionStatsStmt != nil {
+ if cerr := q.getSessionStatsStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getSessionStatsStmt: %w", cerr)
+ }
+ }
+ if q.getSessionStatsByDayStmt != nil {
+ if cerr := q.getSessionStatsByDayStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getSessionStatsByDayStmt: %w", cerr)
+ }
+ }
+ if q.getSessionStatsByMonthStmt != nil {
+ if cerr := q.getSessionStatsByMonthStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getSessionStatsByMonthStmt: %w", cerr)
+ }
+ }
+ if q.getSessionStatsByWeekStmt != nil {
+ if cerr := q.getSessionStatsByWeekStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getSessionStatsByWeekStmt: %w", cerr)
+ }
+ }
if q.listAllSessionsStmt != nil {
if cerr := q.listAllSessionsStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listAllSessionsStmt: %w", cerr)
@@ -195,6 +236,21 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
}
}
+ if q.searchSessionsByTextStmt != nil {
+ if cerr := q.searchSessionsByTextStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing searchSessionsByTextStmt: %w", cerr)
+ }
+ }
+ if q.searchSessionsByTitleStmt != nil {
+ if cerr := q.searchSessionsByTitleStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing searchSessionsByTitleStmt: %w", cerr)
+ }
+ }
+ if q.searchSessionsByTitleAndTextStmt != nil {
+ if cerr := q.searchSessionsByTitleAndTextStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing searchSessionsByTitleAndTextStmt: %w", cerr)
+ }
+ }
if q.updateMessageStmt != nil {
if cerr := q.updateMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
@@ -242,57 +298,71 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
}
type Queries struct {
- db DBTX
- tx *sql.Tx
- createFileStmt *sql.Stmt
- createMessageStmt *sql.Stmt
- createSessionStmt *sql.Stmt
- deleteFileStmt *sql.Stmt
- deleteMessageStmt *sql.Stmt
- deleteSessionStmt *sql.Stmt
- deleteSessionFilesStmt *sql.Stmt
- deleteSessionMessagesStmt *sql.Stmt
- getFileStmt *sql.Stmt
- getFileByPathAndSessionStmt *sql.Stmt
- getMessageStmt *sql.Stmt
- getSessionByIDStmt *sql.Stmt
- listAllSessionsStmt *sql.Stmt
- listChildSessionsStmt *sql.Stmt
- listFilesByPathStmt *sql.Stmt
- listFilesBySessionStmt *sql.Stmt
- listLatestSessionFilesStmt *sql.Stmt
- listMessagesBySessionStmt *sql.Stmt
- listNewFilesStmt *sql.Stmt
- listSessionsStmt *sql.Stmt
- updateMessageStmt *sql.Stmt
- updateSessionStmt *sql.Stmt
+ db DBTX
+ tx *sql.Tx
+ createFileStmt *sql.Stmt
+ createMessageStmt *sql.Stmt
+ createSessionStmt *sql.Stmt
+ deleteFileStmt *sql.Stmt
+ deleteMessageStmt *sql.Stmt
+ deleteSessionStmt *sql.Stmt
+ deleteSessionFilesStmt *sql.Stmt
+ deleteSessionMessagesStmt *sql.Stmt
+ getFileStmt *sql.Stmt
+ getFileByPathAndSessionStmt *sql.Stmt
+ getMessageStmt *sql.Stmt
+ getSessionByIDStmt *sql.Stmt
+ getSessionStatsStmt *sql.Stmt
+ getSessionStatsByDayStmt *sql.Stmt
+ getSessionStatsByMonthStmt *sql.Stmt
+ getSessionStatsByWeekStmt *sql.Stmt
+ listAllSessionsStmt *sql.Stmt
+ listChildSessionsStmt *sql.Stmt
+ listFilesByPathStmt *sql.Stmt
+ listFilesBySessionStmt *sql.Stmt
+ listLatestSessionFilesStmt *sql.Stmt
+ listMessagesBySessionStmt *sql.Stmt
+ listNewFilesStmt *sql.Stmt
+ listSessionsStmt *sql.Stmt
+ searchSessionsByTextStmt *sql.Stmt
+ searchSessionsByTitleStmt *sql.Stmt
+ searchSessionsByTitleAndTextStmt *sql.Stmt
+ updateMessageStmt *sql.Stmt
+ updateSessionStmt *sql.Stmt
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
- db: tx,
- tx: tx,
- createFileStmt: q.createFileStmt,
- createMessageStmt: q.createMessageStmt,
- createSessionStmt: q.createSessionStmt,
- deleteFileStmt: q.deleteFileStmt,
- deleteMessageStmt: q.deleteMessageStmt,
- deleteSessionStmt: q.deleteSessionStmt,
- deleteSessionFilesStmt: q.deleteSessionFilesStmt,
- deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
- getFileStmt: q.getFileStmt,
- getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
- getMessageStmt: q.getMessageStmt,
- getSessionByIDStmt: q.getSessionByIDStmt,
- listAllSessionsStmt: q.listAllSessionsStmt,
- listChildSessionsStmt: q.listChildSessionsStmt,
- listFilesByPathStmt: q.listFilesByPathStmt,
- listFilesBySessionStmt: q.listFilesBySessionStmt,
- listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
- listMessagesBySessionStmt: q.listMessagesBySessionStmt,
- listNewFilesStmt: q.listNewFilesStmt,
- listSessionsStmt: q.listSessionsStmt,
- updateMessageStmt: q.updateMessageStmt,
- updateSessionStmt: q.updateSessionStmt,
+ db: tx,
+ tx: tx,
+ createFileStmt: q.createFileStmt,
+ createMessageStmt: q.createMessageStmt,
+ createSessionStmt: q.createSessionStmt,
+ deleteFileStmt: q.deleteFileStmt,
+ deleteMessageStmt: q.deleteMessageStmt,
+ deleteSessionStmt: q.deleteSessionStmt,
+ deleteSessionFilesStmt: q.deleteSessionFilesStmt,
+ deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
+ getFileStmt: q.getFileStmt,
+ getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
+ getMessageStmt: q.getMessageStmt,
+ getSessionByIDStmt: q.getSessionByIDStmt,
+ getSessionStatsStmt: q.getSessionStatsStmt,
+ getSessionStatsByDayStmt: q.getSessionStatsByDayStmt,
+ getSessionStatsByMonthStmt: q.getSessionStatsByMonthStmt,
+ getSessionStatsByWeekStmt: q.getSessionStatsByWeekStmt,
+ listAllSessionsStmt: q.listAllSessionsStmt,
+ listChildSessionsStmt: q.listChildSessionsStmt,
+ listFilesByPathStmt: q.listFilesByPathStmt,
+ listFilesBySessionStmt: q.listFilesBySessionStmt,
+ listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
+ listMessagesBySessionStmt: q.listMessagesBySessionStmt,
+ listNewFilesStmt: q.listNewFilesStmt,
+ listSessionsStmt: q.listSessionsStmt,
+ searchSessionsByTextStmt: q.searchSessionsByTextStmt,
+ searchSessionsByTitleStmt: q.searchSessionsByTitleStmt,
+ searchSessionsByTitleAndTextStmt: q.searchSessionsByTitleAndTextStmt,
+ updateMessageStmt: q.updateMessageStmt,
+ updateSessionStmt: q.updateSessionStmt,
}
}
@@ -22,6 +22,10 @@ type Querier interface {
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
GetMessage(ctx context.Context, id string) (Message, error)
GetSessionByID(ctx context.Context, id string) (Session, error)
+ GetSessionStats(ctx context.Context) (GetSessionStatsRow, error)
+ GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error)
+ GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error)
+ GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error)
ListAllSessions(ctx context.Context) ([]Session, error)
ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error)
ListFilesByPath(ctx context.Context, path string) ([]File, error)
@@ -30,6 +34,9 @@ type Querier interface {
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
ListNewFiles(ctx context.Context) ([]File, error)
ListSessions(ctx context.Context) ([]Session, error)
+ SearchSessionsByText(ctx context.Context, parts string) ([]Session, error)
+ SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error)
+ SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
}
@@ -106,6 +106,205 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
return i, err
}
+const getSessionStats = `-- name: GetSessionStats :one
+SELECT
+ COUNT(*) as total_sessions,
+ SUM(message_count) as total_messages,
+ SUM(prompt_tokens) as total_prompt_tokens,
+ SUM(completion_tokens) as total_completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost_per_session
+FROM sessions
+`
+
+type GetSessionStatsRow struct {
+ TotalSessions int64 `json:"total_sessions"`
+ TotalMessages sql.NullFloat64 `json:"total_messages"`
+ TotalPromptTokens sql.NullFloat64 `json:"total_prompt_tokens"`
+ TotalCompletionTokens sql.NullFloat64 `json:"total_completion_tokens"`
+ TotalCost sql.NullFloat64 `json:"total_cost"`
+ AvgCostPerSession sql.NullFloat64 `json:"avg_cost_per_session"`
+}
+
+func (q *Queries) GetSessionStats(ctx context.Context) (GetSessionStatsRow, error) {
+ row := q.queryRow(ctx, q.getSessionStatsStmt, getSessionStats)
+ var i GetSessionStatsRow
+ err := row.Scan(
+ &i.TotalSessions,
+ &i.TotalMessages,
+ &i.TotalPromptTokens,
+ &i.TotalCompletionTokens,
+ &i.TotalCost,
+ &i.AvgCostPerSession,
+ )
+ return i, err
+}
+
+const getSessionStatsByDay = `-- name: GetSessionStatsByDay :many
+SELECT
+ date(created_at, 'unixepoch') as day,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY date(created_at, 'unixepoch')
+ORDER BY day DESC
+`
+
+type GetSessionStatsByDayRow struct {
+ Day interface{} `json:"day"`
+ SessionCount int64 `json:"session_count"`
+ MessageCount sql.NullFloat64 `json:"message_count"`
+ PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
+ CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
+ TotalCost sql.NullFloat64 `json:"total_cost"`
+ AvgCost sql.NullFloat64 `json:"avg_cost"`
+}
+
+func (q *Queries) GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error) {
+ rows, err := q.query(ctx, q.getSessionStatsByDayStmt, getSessionStatsByDay)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []GetSessionStatsByDayRow{}
+ for rows.Next() {
+ var i GetSessionStatsByDayRow
+ if err := rows.Scan(
+ &i.Day,
+ &i.SessionCount,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.TotalCost,
+ &i.AvgCost,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getSessionStatsByMonth = `-- name: GetSessionStatsByMonth :many
+SELECT
+ strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
+ORDER BY month DESC
+`
+
+type GetSessionStatsByMonthRow struct {
+ Month interface{} `json:"month"`
+ SessionCount int64 `json:"session_count"`
+ MessageCount sql.NullFloat64 `json:"message_count"`
+ PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
+ CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
+ TotalCost sql.NullFloat64 `json:"total_cost"`
+ AvgCost sql.NullFloat64 `json:"avg_cost"`
+}
+
+func (q *Queries) GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error) {
+ rows, err := q.query(ctx, q.getSessionStatsByMonthStmt, getSessionStatsByMonth)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []GetSessionStatsByMonthRow{}
+ for rows.Next() {
+ var i GetSessionStatsByMonthRow
+ if err := rows.Scan(
+ &i.Month,
+ &i.SessionCount,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.TotalCost,
+ &i.AvgCost,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getSessionStatsByWeek = `-- name: GetSessionStatsByWeek :many
+SELECT
+ date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
+ORDER BY week_start DESC
+`
+
+type GetSessionStatsByWeekRow struct {
+ WeekStart interface{} `json:"week_start"`
+ SessionCount int64 `json:"session_count"`
+ MessageCount sql.NullFloat64 `json:"message_count"`
+ PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
+ CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
+ TotalCost sql.NullFloat64 `json:"total_cost"`
+ AvgCost sql.NullFloat64 `json:"avg_cost"`
+}
+
+func (q *Queries) GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error) {
+ rows, err := q.query(ctx, q.getSessionStatsByWeekStmt, getSessionStatsByWeek)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []GetSessionStatsByWeekRow{}
+ for rows.Next() {
+ var i GetSessionStatsByWeekRow
+ if err := rows.Scan(
+ &i.WeekStart,
+ &i.SessionCount,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.TotalCost,
+ &i.AvgCost,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
const listAllSessions = `-- name: ListAllSessions :many
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
FROM sessions
@@ -228,6 +427,136 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
return items, nil
}
+const searchSessionsByText = `-- name: SearchSessionsByText :many
+SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
+FROM sessions s
+JOIN messages m ON s.id = m.session_id
+WHERE m.parts LIKE ?
+ORDER BY s.created_at DESC
+`
+
+func (q *Queries) SearchSessionsByText(ctx context.Context, parts string) ([]Session, error) {
+ rows, err := q.query(ctx, q.searchSessionsByTextStmt, searchSessionsByText, parts)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Session{}
+ for rows.Next() {
+ var i Session
+ if err := rows.Scan(
+ &i.ID,
+ &i.ParentSessionID,
+ &i.Title,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.Cost,
+ &i.UpdatedAt,
+ &i.CreatedAt,
+ &i.SummaryMessageID,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const searchSessionsByTitle = `-- name: SearchSessionsByTitle :many
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
+FROM sessions
+WHERE title LIKE ?
+ORDER BY created_at DESC
+`
+
+func (q *Queries) SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error) {
+ rows, err := q.query(ctx, q.searchSessionsByTitleStmt, searchSessionsByTitle, title)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Session{}
+ for rows.Next() {
+ var i Session
+ if err := rows.Scan(
+ &i.ID,
+ &i.ParentSessionID,
+ &i.Title,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.Cost,
+ &i.UpdatedAt,
+ &i.CreatedAt,
+ &i.SummaryMessageID,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const searchSessionsByTitleAndText = `-- name: SearchSessionsByTitleAndText :many
+SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
+FROM sessions s
+JOIN messages m ON s.id = m.session_id
+WHERE s.title LIKE ? AND m.parts LIKE ?
+ORDER BY s.created_at DESC
+`
+
+type SearchSessionsByTitleAndTextParams struct {
+ Title string `json:"title"`
+ Parts string `json:"parts"`
+}
+
+func (q *Queries) SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error) {
+ rows, err := q.query(ctx, q.searchSessionsByTitleAndTextStmt, searchSessionsByTitleAndText, arg.Title, arg.Parts)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Session{}
+ for rows.Next() {
+ var i Session
+ if err := rows.Scan(
+ &i.ID,
+ &i.ParentSessionID,
+ &i.Title,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.Cost,
+ &i.UpdatedAt,
+ &i.CreatedAt,
+ &i.SummaryMessageID,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
const updateSession = `-- name: UpdateSession :one
UPDATE sessions
SET
@@ -60,3 +60,72 @@ ORDER BY created_at ASC;
SELECT *
FROM sessions
ORDER BY created_at DESC;
+
+-- name: SearchSessionsByTitle :many
+SELECT *
+FROM sessions
+WHERE title LIKE ?
+ORDER BY created_at DESC;
+
+-- name: SearchSessionsByTitleAndText :many
+SELECT DISTINCT s.*
+FROM sessions s
+JOIN messages m ON s.id = m.session_id
+WHERE s.title LIKE ? AND m.parts LIKE ?
+ORDER BY s.created_at DESC;
+
+-- name: SearchSessionsByText :many
+SELECT DISTINCT s.*
+FROM sessions s
+JOIN messages m ON s.id = m.session_id
+WHERE m.parts LIKE ?
+ORDER BY s.created_at DESC;
+
+-- name: GetSessionStats :one
+SELECT
+ COUNT(*) as total_sessions,
+ SUM(message_count) as total_messages,
+ SUM(prompt_tokens) as total_prompt_tokens,
+ SUM(completion_tokens) as total_completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost_per_session
+FROM sessions;
+
+-- name: GetSessionStatsByDay :many
+SELECT
+ date(created_at, 'unixepoch') as day,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY date(created_at, 'unixepoch')
+ORDER BY day DESC;
+
+-- name: GetSessionStatsByWeek :many
+SELECT
+ date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
+ORDER BY week_start DESC;
+
+-- name: GetSessionStatsByMonth :many
+SELECT
+ strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
+ COUNT(*) as session_count,
+ SUM(message_count) as message_count,
+ SUM(prompt_tokens) as prompt_tokens,
+ SUM(completion_tokens) as completion_tokens,
+ SUM(cost) as total_cost,
+ AVG(cost) as avg_cost
+FROM sessions
+GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
+ORDER BY month DESC;
@@ -33,6 +33,9 @@ type Service interface {
ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
Delete(ctx context.Context, id string) error
+ SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error)
+ SearchByText(ctx context.Context, textPattern string) ([]Session, error)
+ SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error)
}
type service struct {
@@ -161,6 +164,45 @@ func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]S
return sessions, nil
}
+func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) {
+ dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%")
+ if err != nil {
+ return nil, err
+ }
+ sessions := make([]Session, len(dbSessions))
+ for i, dbSession := range dbSessions {
+ sessions[i] = s.fromDBItem(dbSession)
+ }
+ return sessions, nil
+}
+
+func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) {
+ dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%")
+ if err != nil {
+ return nil, err
+ }
+ sessions := make([]Session, len(dbSessions))
+ for i, dbSession := range dbSessions {
+ sessions[i] = s.fromDBItem(dbSession)
+ }
+ return sessions, nil
+}
+
+func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) {
+ dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{
+ Title: "%" + titlePattern + "%",
+ Parts: "%" + textPattern + "%",
+ })
+ if err != nil {
+ return nil, err
+ }
+ sessions := make([]Session, len(dbSessions))
+ for i, dbSession := range dbSessions {
+ sessions[i] = s.fromDBItem(dbSession)
+ }
+ return sessions, nil
+}
+
func (s service) fromDBItem(item db.Session) Session {
return Session{
ID: item.ID,
@@ -0,0 +1,297 @@
+# RFC: Session Import and Export
+
+## Summary
+
+This RFC proposes a comprehensive system for importing and exporting conversation sessions in Crush.
+
+## Background
+
+Crush manages conversations through a hierarchical session system where:
+- Sessions contain metadata (title, token counts, cost, timestamps)
+- Sessions can have parent-child relationships (nested conversations)
+- Messages within sessions have structured content parts (text, tool calls, reasoning, etc.)
+- The current implementation provides export functionality but lacks import capabilities
+
+The latest commit introduced three key commands:
+- `crush sessions list` - List sessions in various formats
+- `crush sessions export` - Export all sessions and metadata
+- `crush sessions export-conversation <session-id>` - Export a single conversation with messages
+
+## Motivation
+
+Users need to:
+1. Share conversations with others
+2. Use conversation logs for debugging
+3. Archive and analyze conversation history
+4. Export data for external tools
+
+## Detailed Design
+
+### Core Data Model
+
+The session export format builds on the existing session structure:
+
+```go
+type Session struct {
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id,omitempty"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ SummaryMessageID string `json:"summary_message_id,omitempty"`
+}
+
+type SessionWithChildren struct {
+ Session
+ Children []SessionWithChildren `json:"children,omitempty"`
+}
+```
+
+### Proposed Command Interface
+
+#### Export Commands (Already Implemented)
+```bash
+# List sessions in various formats
+crush sessions list [--format text|json|yaml|markdown]
+
+# Export all sessions with metadata
+crush sessions export [--format json|yaml|markdown]
+
+# Export single conversation with full message history
+crush sessions export-conversation <session-id> [--format markdown|json|yaml]
+```
+
+#### New Import Commands
+
+```bash
+# Import sessions from a file
+crush sessions import <file> [--format json|yaml] [--dry-run]
+
+# Import a single conversation
+crush sessions import-conversation <file> [--format json|yaml|markdown]
+
+```
+
+#### Enhanced Inspection Commands
+
+```bash
+# Search sessions by criteria
+crush sessions search [--title <pattern>] [--text <text>] [--format text|json]
+
+# Show session statistics
+crush sessions stats [--format text|json] [--group-by day|week|month]
+
+# Show statistics for a single session
+crush sessions stats <session-id> [--format text|json]
+```
+
+### Import/Export Formats
+
+#### Full Export Format (JSON)
+```json
+{
+ "version": "1.0",
+ "exported_at": "2025-01-27T10:30:00Z",
+ "total_sessions": 15,
+ "sessions": [
+ {
+ "id": "session-123",
+ "parent_session_id": "",
+ "title": "API Design Discussion",
+ "message_count": 8,
+ "prompt_tokens": 1250,
+ "completion_tokens": 890,
+ "cost": 0.0234,
+ "created_at": 1706356200,
+ "updated_at": 1706359800,
+ "children": [
+ {
+ "id": "session-124",
+ "parent_session_id": "session-123",
+ "title": "Implementation Details",
+ "message_count": 4,
+ "prompt_tokens": 650,
+ "completion_tokens": 420,
+ "cost": 0.0145,
+ "created_at": 1706359900,
+ "updated_at": 1706361200
+ }
+ ]
+ }
+ ]
+}
+```
+
+#### Conversation Export Format (JSON)
+```json
+{
+ "version": "1.0",
+ "session": {
+ "id": "session-123",
+ "title": "API Design Discussion",
+ "created_at": 1706356200,
+ "message_count": 3
+ },
+ "messages": [
+ {
+ "id": "msg-001",
+ "session_id": "session-123",
+ "role": "user",
+ "parts": [
+ {
+ "type": "text",
+ "data": {
+ "text": "Help me design a REST API for user management"
+ }
+ }
+ ],
+ "created_at": 1706356200
+ },
+ {
+ "id": "msg-002",
+ "session_id": "session-123",
+ "role": "assistant",
+ "model": "gpt-4",
+ "provider": "openai",
+ "parts": [
+ {
+ "type": "text",
+ "data": {
+ "text": "I'll help you design a REST API for user management..."
+ }
+ },
+ {
+ "type": "finish",
+ "data": {
+ "reason": "stop",
+ "time": 1706356230
+ }
+ }
+ ],
+ "created_at": 1706356220
+ }
+ ]
+}
+```
+
+### API Implementation
+
+#### Import Service Interface
+```go
+type ImportService interface {
+ // Import sessions from structured data
+ ImportSessions(ctx context.Context, data ImportData, opts ImportOptions) (ImportResult, error)
+
+ // Import single conversation
+ ImportConversation(ctx context.Context, data ConversationData, opts ImportOptions) (Session, error)
+
+ // Validate import data without persisting
+ ValidateImport(ctx context.Context, data ImportData) (ValidationResult, error)
+}
+
+type ImportOptions struct {
+ ConflictStrategy ConflictStrategy // skip, merge, replace
+ DryRun bool
+ ParentSessionID string // For conversation imports
+ PreserveIDs bool // Whether to preserve original IDs
+}
+
+type ConflictStrategy string
+
+const (
+ ConflictSkip ConflictStrategy = "skip" // Skip existing sessions
+ ConflictMerge ConflictStrategy = "merge" // Merge with existing
+ ConflictReplace ConflictStrategy = "replace" // Replace existing
+)
+
+type ImportResult struct {
+ TotalSessions int `json:"total_sessions"`
+ ImportedSessions int `json:"imported_sessions"`
+ SkippedSessions int `json:"skipped_sessions"`
+ Errors []ImportError `json:"errors,omitempty"`
+ SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
+}
+```
+
+#### Enhanced Export Service
+```go
+type ExportService interface {
+ // Export sessions with filtering
+ ExportSessions(ctx context.Context, opts ExportOptions) ([]SessionWithChildren, error)
+
+ // Export conversation with full message history
+ ExportConversation(ctx context.Context, sessionID string, opts ExportOptions) (ConversationExport, error)
+
+ // Search and filter sessions
+ SearchSessions(ctx context.Context, criteria SearchCriteria) ([]Session, error)
+
+ // Get session statistics
+ GetStats(ctx context.Context, opts StatsOptions) (SessionStats, error)
+}
+
+type ExportOptions struct {
+ Format string // json, yaml, markdown
+ IncludeMessages bool // Include full message content
+ DateRange DateRange // Filter by date range
+ SessionIDs []string // Export specific sessions
+}
+
+type SearchCriteria struct {
+ TitlePattern string
+ DateRange DateRange
+ MinCost float64
+ MaxCost float64
+ ParentSessionID string
+ HasChildren *bool
+}
+```
+
+## Implementation Status
+
+The proposed session import/export functionality has been implemented as a prototype as of July 2025.
+
+### Implemented Commands
+
+All new commands have been added to `internal/cmd/sessions.go`:
+
+- **Import**: `crush sessions import <file> [--format json|yaml] [--dry-run]`
+ - Supports hierarchical session imports with parent-child relationships
+ - Generates new UUIDs to avoid conflicts
+ - Includes validation and dry-run capabilities
+
+- **Import Conversation**: `crush sessions import-conversation <file> [--format json|yaml]`
+ - Imports single conversations with full message history
+ - Preserves all message content parts and metadata
+
+- **Search**: `crush sessions search [--title <pattern>] [--text <text>] [--format text|json]`
+ - Case-insensitive title search and message content search
+ - Supports combined search criteria with AND logic
+
+- **Stats**: `crush sessions stats [--format text|json] [--group-by day|week|month]`
+ - Comprehensive usage statistics (sessions, messages, tokens, costs)
+ - Time-based grouping with efficient database queries
+
+### Database Changes
+
+Added new SQL queries in `internal/db/sql/sessions.sql`:
+- Search queries for title and message content filtering
+- Statistics aggregation queries with time-based grouping
+- All queries optimized for performance with proper indexing
+
+### Database Schema Considerations
+
+The current schema supports the import/export functionality. Additional indexes may be needed for search performance:
+
+```sql
+-- Optimize session searches by date and cost
+CREATE INDEX idx_sessions_created_at ON sessions(created_at);
+CREATE INDEX idx_sessions_cost ON sessions(cost);
+CREATE INDEX idx_sessions_title ON sessions(title COLLATE NOCASE);
+
+-- Optimize message searches by session
+CREATE INDEX idx_messages_session_created ON messages(session_id, created_at);
+```