diff --git a/internal/cmd/sessions.go b/internal/cmd/sessions.go index 2937e70c5f5db9fae6cfc1d77574a81925049b1b..76ba27095fc772388b61ade3a98e2a2d6e2fdd5b 100644 --- a/internal/cmd/sessions.go +++ b/internal/cmd/sessions.go @@ -1,10 +1,13 @@ package cmd import ( + "bytes" "context" + "database/sql" "encoding/json" "fmt" "os" + "path/filepath" "strings" "time" @@ -12,6 +15,7 @@ import ( "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" + "github.com/google/uuid" "github.com/spf13/cobra" "gopkg.in/yaml.v3" ) @@ -22,6 +26,89 @@ type SessionWithChildren struct { Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"` } +// ImportSession represents a session with proper JSON tags for import +type ImportSession struct { + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + SummaryMessageID string `json:"summary_message_id,omitempty"` + Children []ImportSession `json:"children,omitempty"` +} + +// ImportData represents the full import structure for sessions +type ImportData struct { + Version string `json:"version" yaml:"version"` + ExportedAt string `json:"exported_at,omitempty" yaml:"exported_at,omitempty"` + TotalSessions int `json:"total_sessions,omitempty" yaml:"total_sessions,omitempty"` + Sessions []ImportSession `json:"sessions" yaml:"sessions"` +} + +// ImportMessage represents a message with proper JSON tags for import +type ImportMessage struct { + ID string `json:"id"` + Role string `json:"role"` + SessionID string `json:"session_id"` + Parts []interface{} `json:"parts"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + CreatedAt int64 `json:"created_at"` +} + +// ImportSessionInfo represents session info with proper JSON tags for conversation import +type ImportSessionInfo struct { + ID string `json:"id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens,omitempty"` + CompletionTokens int64 `json:"completion_tokens,omitempty"` + Cost float64 `json:"cost,omitempty"` + CreatedAt int64 `json:"created_at"` +} + +// ConversationData represents a single conversation import structure +type ConversationData struct { + Version string `json:"version" yaml:"version"` + Session ImportSessionInfo `json:"session" yaml:"session"` + Messages []ImportMessage `json:"messages" yaml:"messages"` +} + +// ImportResult contains the results of an import operation +type ImportResult struct { + TotalSessions int `json:"total_sessions"` + ImportedSessions int `json:"imported_sessions"` + SkippedSessions int `json:"skipped_sessions"` + ImportedMessages int `json:"imported_messages"` + Errors []string `json:"errors,omitempty"` + SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id +} + +// SessionStats represents aggregated session statistics +type SessionStats struct { + TotalSessions int64 `json:"total_sessions"` + TotalMessages int64 `json:"total_messages"` + TotalPromptTokens int64 `json:"total_prompt_tokens"` + TotalCompletionTokens int64 `json:"total_completion_tokens"` + TotalCost float64 `json:"total_cost"` + AvgCostPerSession float64 `json:"avg_cost_per_session"` +} + +// GroupedSessionStats represents statistics grouped by time period +type GroupedSessionStats struct { + Period string `json:"period"` + SessionCount int64 `json:"session_count"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalCost float64 `json:"total_cost"` + AvgCost float64 `json:"avg_cost"` +} + var sessionsCmd = &cobra.Command{ Use: "sessions", Short: "Manage sessions", @@ -60,15 +147,80 @@ var exportConversationCmd = &cobra.Command{ }, } +var importCmd = &cobra.Command{ + Use: "import ", + Short: "Import sessions from a file", + Long: `Import sessions from a JSON or YAML file with hierarchical structure`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + file := args[0] + format, _ := cmd.Flags().GetString("format") + dryRun, _ := cmd.Flags().GetBool("dry-run") + return runImport(cmd.Context(), file, format, dryRun) + }, +} + +var importConversationCmd = &cobra.Command{ + Use: "import-conversation ", + Short: "Import a single conversation from a file", + Long: `Import a single conversation with messages from a JSON, YAML, or Markdown file`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + file := args[0] + format, _ := cmd.Flags().GetString("format") + return runImportConversation(cmd.Context(), file, format) + }, +} + +var searchCmd = &cobra.Command{ + Use: "search", + Short: "Search sessions by title or message content", + Long: `Search sessions by title pattern (case-insensitive) or message text content`, + RunE: func(cmd *cobra.Command, args []string) error { + titlePattern, _ := cmd.Flags().GetString("title") + textPattern, _ := cmd.Flags().GetString("text") + format, _ := cmd.Flags().GetString("format") + + if titlePattern == "" && textPattern == "" { + return fmt.Errorf("at least one of --title or --text must be provided") + } + + return runSessionsSearch(cmd.Context(), titlePattern, textPattern, format) + }, +} + +var statsCmd = &cobra.Command{ + Use: "stats", + Short: "Show session statistics", + Long: `Display aggregated statistics about sessions including total counts, tokens, and costs`, + RunE: func(cmd *cobra.Command, args []string) error { + format, _ := cmd.Flags().GetString("format") + groupBy, _ := cmd.Flags().GetString("group-by") + return runSessionsStats(cmd.Context(), format, groupBy) + }, +} + func init() { rootCmd.AddCommand(sessionsCmd) sessionsCmd.AddCommand(listCmd) sessionsCmd.AddCommand(exportCmd) sessionsCmd.AddCommand(exportConversationCmd) + sessionsCmd.AddCommand(importCmd) + sessionsCmd.AddCommand(importConversationCmd) + sessionsCmd.AddCommand(searchCmd) + sessionsCmd.AddCommand(statsCmd) listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)") exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)") exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)") + importCmd.Flags().StringP("format", "f", "", "Import format (json, yaml) - auto-detected if not specified") + importCmd.Flags().Bool("dry-run", false, "Validate import data without persisting changes") + importConversationCmd.Flags().StringP("format", "f", "", "Import format (json, yaml, markdown) - auto-detected if not specified") + searchCmd.Flags().String("title", "", "Search by session title pattern (case-insensitive substring search)") + searchCmd.Flags().String("text", "", "Search by message text content") + searchCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") + statsCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") + statsCmd.Flags().String("group-by", "", "Group statistics by time period (day, week, month)") } func runSessionsList(ctx context.Context, format string) error { @@ -442,3 +594,773 @@ func formatConversationYAML(sess session.Session, messages []message.Message) er fmt.Println(string(yamlData)) return nil } + +func runImport(ctx context.Context, file, format string, dryRun bool) error { + // Read the file + data, err := readImportFile(file, format) + if err != nil { + return fmt.Errorf("failed to read import file: %w", err) + } + + // Validate the data structure + if err := validateImportData(data); err != nil { + return fmt.Errorf("invalid import data: %w", err) + } + + if dryRun { + result := ImportResult{ + TotalSessions: countTotalImportSessions(data.Sessions), + ImportedSessions: 0, + SkippedSessions: 0, + ImportedMessages: 0, + SessionMapping: make(map[string]string), + } + fmt.Printf("Dry run: Would import %d sessions\n", result.TotalSessions) + return nil + } + + // Perform the actual import + sessionService, messageService, err := createServices(ctx) + if err != nil { + return err + } + + result, err := importSessions(ctx, sessionService, messageService, data) + if err != nil { + return fmt.Errorf("import failed: %w", err) + } + + // Print summary + fmt.Printf("Import completed successfully:\n") + fmt.Printf(" Total sessions processed: %d\n", result.TotalSessions) + fmt.Printf(" Sessions imported: %d\n", result.ImportedSessions) + fmt.Printf(" Sessions skipped: %d\n", result.SkippedSessions) + fmt.Printf(" Messages imported: %d\n", result.ImportedMessages) + + if len(result.Errors) > 0 { + fmt.Printf(" Errors encountered: %d\n", len(result.Errors)) + for _, errStr := range result.Errors { + fmt.Printf(" - %s\n", errStr) + } + } + + return nil +} + +func runImportConversation(ctx context.Context, file, format string) error { + // Read the conversation file + convData, err := readConversationFile(file, format) + if err != nil { + return fmt.Errorf("failed to read conversation file: %w", err) + } + + // Validate the conversation data + if err := validateConversationData(convData); err != nil { + return fmt.Errorf("invalid conversation data: %w", err) + } + + // Import the conversation + sessionService, messageService, err := createServices(ctx) + if err != nil { + return err + } + + newSessionID, messageCount, err := importConversation(ctx, sessionService, messageService, convData) + if err != nil { + return fmt.Errorf("conversation import failed: %w", err) + } + + fmt.Printf("Conversation imported successfully:\n") + fmt.Printf(" Session ID: %s\n", newSessionID) + fmt.Printf(" Title: %s\n", convData.Session.Title) + fmt.Printf(" Messages imported: %d\n", messageCount) + + return nil +} + +func readImportFile(file, format string) (*ImportData, error) { + fileData, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", file, err) + } + + // Auto-detect format if not specified + if format == "" { + format = detectFormat(file, fileData) + } + + var data ImportData + switch strings.ToLower(format) { + case "json": + if err := json.Unmarshal(fileData, &data); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + case "yaml", "yml": + if err := yaml.Unmarshal(fileData, &data); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + default: + return nil, fmt.Errorf("unsupported format: %s", format) + } + + return &data, nil +} + +func readConversationFile(file, format string) (*ConversationData, error) { + fileData, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", file, err) + } + + // Auto-detect format if not specified + if format == "" { + format = detectFormat(file, fileData) + } + + var data ConversationData + switch strings.ToLower(format) { + case "json": + if err := json.Unmarshal(fileData, &data); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + case "yaml", "yml": + if err := yaml.Unmarshal(fileData, &data); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + case "markdown", "md": + return nil, fmt.Errorf("markdown import for conversations is not yet implemented") + default: + return nil, fmt.Errorf("unsupported format: %s", format) + } + + return &data, nil +} + +func detectFormat(filename string, data []byte) string { + // First try file extension + ext := strings.ToLower(filepath.Ext(filename)) + switch ext { + case ".json": + return "json" + case ".yaml", ".yml": + return "yaml" + case ".md", ".markdown": + return "markdown" + } + + // Try to detect from content + data = bytes.TrimSpace(data) + if len(data) > 0 { + if data[0] == '{' || data[0] == '[' { + return "json" + } + if strings.HasPrefix(string(data), "---") || strings.Contains(string(data), ":") { + return "yaml" + } + } + + return "json" // default fallback +} + +func validateImportData(data *ImportData) error { + if data == nil { + return fmt.Errorf("import data is nil") + } + + if len(data.Sessions) == 0 { + return fmt.Errorf("no sessions to import") + } + + // Validate session structure + for i, sess := range data.Sessions { + if err := validateImportSessionHierarchy(sess, ""); err != nil { + return fmt.Errorf("session %d validation failed: %w", i, err) + } + } + + return nil +} + +func validateConversationData(data *ConversationData) error { + if data == nil { + return fmt.Errorf("conversation data is nil") + } + + if data.Session.Title == "" { + return fmt.Errorf("session title is required") + } + + if len(data.Messages) == 0 { + return fmt.Errorf("no messages to import") + } + + return nil +} + +func validateImportSessionHierarchy(sess ImportSession, expectedParent string) error { + if sess.ID == "" { + return fmt.Errorf("session ID is required") + } + + if sess.Title == "" { + return fmt.Errorf("session title is required") + } + + // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent + if expectedParent == "" { + if sess.ParentSessionID != "" { + return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID) + } + } else { + // For child sessions, parent should match expected parent + if sess.ParentSessionID != expectedParent { + return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID) + } + } + + // Validate children + for _, child := range sess.Children { + if err := validateImportSessionHierarchy(child, sess.ID); err != nil { + return err + } + } + + return nil +} + +func validateSessionHierarchy(sess SessionWithChildren, expectedParent string) error { + if sess.ID == "" { + return fmt.Errorf("session ID is required") + } + + if sess.Title == "" { + return fmt.Errorf("session title is required") + } + + // For top-level sessions, expectedParent should be empty and session should have no parent or empty parent + if expectedParent == "" { + if sess.ParentSessionID != "" { + return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID) + } + } else { + // For child sessions, parent should match expected parent + if sess.ParentSessionID != expectedParent { + return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID) + } + } + + // Validate children + for _, child := range sess.Children { + if err := validateSessionHierarchy(child, sess.ID); err != nil { + return err + } + } + + return nil +} + +func countTotalImportSessions(sessions []ImportSession) int { + count := len(sessions) + for _, sess := range sessions { + count += countTotalImportSessions(sess.Children) + } + return count +} + +func countTotalSessions(sessions []SessionWithChildren) int { + count := len(sessions) + for _, sess := range sessions { + count += countTotalSessions(sess.Children) + } + return count +} + +func importSessions(ctx context.Context, sessionService session.Service, messageService message.Service, data *ImportData) (ImportResult, error) { + result := ImportResult{ + TotalSessions: countTotalImportSessions(data.Sessions), + SessionMapping: make(map[string]string), + } + + // Import sessions recursively, starting with top-level sessions + for _, sess := range data.Sessions { + err := importImportSessionWithChildren(ctx, sessionService, messageService, sess, "", &result) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to import session %s: %v", sess.ID, err)) + } + } + + return result, nil +} + +func importConversation(ctx context.Context, sessionService session.Service, messageService message.Service, data *ConversationData) (string, int, error) { + // Generate new session ID + newSessionID := uuid.New().String() + + // Create the session using the low-level database API + cwd, err := getCwd() + if err != nil { + return "", 0, err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return "", 0, err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return "", 0, err + } + + queries := db.New(conn) + + // Create session with all original metadata + _, err = queries.CreateSession(ctx, db.CreateSessionParams{ + ID: newSessionID, + ParentSessionID: sql.NullString{Valid: false}, + Title: data.Session.Title, + MessageCount: data.Session.MessageCount, + PromptTokens: data.Session.PromptTokens, + CompletionTokens: data.Session.CompletionTokens, + Cost: data.Session.Cost, + }) + if err != nil { + return "", 0, fmt.Errorf("failed to create session: %w", err) + } + + // Import messages + messageCount := 0 + for _, msg := range data.Messages { + // Generate new message ID + newMessageID := uuid.New().String() + + // Marshal message parts + partsJSON, err := json.Marshal(msg.Parts) + if err != nil { + return "", 0, fmt.Errorf("failed to marshal message parts: %w", err) + } + + // Create message + _, err = queries.CreateMessage(ctx, db.CreateMessageParams{ + ID: newMessageID, + SessionID: newSessionID, + Role: string(msg.Role), + Parts: string(partsJSON), + Model: sql.NullString{String: msg.Model, Valid: msg.Model != ""}, + Provider: sql.NullString{String: msg.Provider, Valid: msg.Provider != ""}, + }) + if err != nil { + return "", 0, fmt.Errorf("failed to create message: %w", err) + } + messageCount++ + } + + return newSessionID, messageCount, nil +} + +func importImportSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess ImportSession, parentID string, result *ImportResult) error { + // Generate new session ID + newSessionID := uuid.New().String() + result.SessionMapping[sess.ID] = newSessionID + + // Create the session using the low-level database API to preserve metadata + cwd, err := getCwd() + if err != nil { + return err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return err + } + + queries := db.New(conn) + + // Create session with all original metadata + parentSessionID := sql.NullString{Valid: false} + if parentID != "" { + parentSessionID = sql.NullString{String: parentID, Valid: true} + } + + _, err = queries.CreateSession(ctx, db.CreateSessionParams{ + ID: newSessionID, + ParentSessionID: parentSessionID, + Title: sess.Title, + MessageCount: sess.MessageCount, + PromptTokens: sess.PromptTokens, + CompletionTokens: sess.CompletionTokens, + Cost: sess.Cost, + }) + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + + result.ImportedSessions++ + + // Import children recursively + for _, child := range sess.Children { + err := importImportSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err)) + } + } + + return nil +} + +func importSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess SessionWithChildren, parentID string, result *ImportResult) error { + // Generate new session ID + newSessionID := uuid.New().String() + result.SessionMapping[sess.ID] = newSessionID + + // Create the session using the low-level database API to preserve metadata + cwd, err := getCwd() + if err != nil { + return err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return err + } + + queries := db.New(conn) + + // Create session with all original metadata + parentSessionID := sql.NullString{Valid: false} + if parentID != "" { + parentSessionID = sql.NullString{String: parentID, Valid: true} + } + + _, err = queries.CreateSession(ctx, db.CreateSessionParams{ + ID: newSessionID, + ParentSessionID: parentSessionID, + Title: sess.Title, + MessageCount: sess.MessageCount, + PromptTokens: sess.PromptTokens, + CompletionTokens: sess.CompletionTokens, + Cost: sess.Cost, + }) + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + + result.ImportedSessions++ + + // Import children recursively + for _, child := range sess.Children { + err := importSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result) + if err != nil { + result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err)) + } + } + + return nil +} + +func runSessionsSearch(ctx context.Context, titlePattern, textPattern, format string) error { + sessionService, err := createSessionService(ctx) + if err != nil { + return err + } + + var sessions []session.Session + + // Determine which search method to use based on provided patterns + if titlePattern != "" && textPattern != "" { + sessions, err = sessionService.SearchByTitleAndText(ctx, titlePattern, textPattern) + } else if titlePattern != "" { + sessions, err = sessionService.SearchByTitle(ctx, titlePattern) + } else if textPattern != "" { + sessions, err = sessionService.SearchByText(ctx, textPattern) + } + + if err != nil { + return fmt.Errorf("search failed: %w", err) + } + + return formatSearchResults(sessions, format) +} + +func formatSearchResults(sessions []session.Session, format string) error { + switch strings.ToLower(format) { + case "json": + return formatSearchResultsJSON(sessions) + case "text": + return formatSearchResultsText(sessions) + default: + return fmt.Errorf("unsupported format: %s", format) + } +} + +func formatSearchResultsJSON(sessions []session.Session) error { + data, err := json.MarshalIndent(sessions, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func formatSearchResultsText(sessions []session.Session) error { + if len(sessions) == 0 { + fmt.Println("No sessions found matching the search criteria.") + return nil + } + + fmt.Printf("Found %d session(s):\n\n", len(sessions)) + for _, sess := range sessions { + fmt.Printf("• %s (ID: %s)\n", sess.Title, sess.ID) + fmt.Printf(" Messages: %d, Cost: $%.4f\n", sess.MessageCount, sess.Cost) + fmt.Printf(" Created: %s\n", formatTimestamp(sess.CreatedAt)) + if sess.ParentSessionID != "" { + fmt.Printf(" Parent: %s\n", sess.ParentSessionID) + } + fmt.Println() + } + + return nil +} + +func runSessionsStats(ctx context.Context, format, groupBy string) error { + // Get database connection + cwd, err := getCwd() + if err != nil { + return err + } + + cfg, err := config.Init(cwd, false) + if err != nil { + return err + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return err + } + + queries := db.New(conn) + + // Handle grouped statistics + if groupBy != "" { + return runGroupedStats(ctx, queries, format, groupBy) + } + + // Get overall statistics + statsRow, err := queries.GetSessionStats(ctx) + if err != nil { + return fmt.Errorf("failed to get session stats: %w", err) + } + + // Convert to our struct, handling NULL values + stats := SessionStats{ + TotalSessions: statsRow.TotalSessions, + TotalMessages: convertNullFloat64ToInt64(statsRow.TotalMessages), + TotalPromptTokens: convertNullFloat64ToInt64(statsRow.TotalPromptTokens), + TotalCompletionTokens: convertNullFloat64ToInt64(statsRow.TotalCompletionTokens), + TotalCost: convertNullFloat64(statsRow.TotalCost), + AvgCostPerSession: convertNullFloat64(statsRow.AvgCostPerSession), + } + + return formatStats(stats, format) +} + +func runGroupedStats(ctx context.Context, queries *db.Queries, format, groupBy string) error { + var groupedStats []GroupedSessionStats + + switch strings.ToLower(groupBy) { + case "day": + rows, err := queries.GetSessionStatsByDay(ctx) + if err != nil { + return fmt.Errorf("failed to get daily stats: %w", err) + } + groupedStats = convertDayStatsRows(rows) + case "week": + rows, err := queries.GetSessionStatsByWeek(ctx) + if err != nil { + return fmt.Errorf("failed to get weekly stats: %w", err) + } + groupedStats = convertWeekStatsRows(rows) + case "month": + rows, err := queries.GetSessionStatsByMonth(ctx) + if err != nil { + return fmt.Errorf("failed to get monthly stats: %w", err) + } + groupedStats = convertMonthStatsRows(rows) + default: + return fmt.Errorf("unsupported group-by value: %s. Valid values are: day, week, month", groupBy) + } + + return formatGroupedStats(groupedStats, format, groupBy) +} + +func convertNullFloat64(val sql.NullFloat64) float64 { + if val.Valid { + return val.Float64 + } + return 0.0 +} + +func convertNullFloat64ToInt64(val sql.NullFloat64) int64 { + if val.Valid { + return int64(val.Float64) + } + return 0 +} + +func convertDayStatsRows(rows []db.GetSessionStatsByDayRow) []GroupedSessionStats { + result := make([]GroupedSessionStats, 0, len(rows)) + for _, row := range rows { + stats := GroupedSessionStats{ + Period: fmt.Sprintf("%v", row.Day), + SessionCount: row.SessionCount, + MessageCount: convertNullFloat64ToInt64(row.MessageCount), + PromptTokens: convertNullFloat64ToInt64(row.PromptTokens), + CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens), + TotalCost: convertNullFloat64(row.TotalCost), + AvgCost: convertNullFloat64(row.AvgCost), + } + result = append(result, stats) + } + return result +} + +func convertWeekStatsRows(rows []db.GetSessionStatsByWeekRow) []GroupedSessionStats { + result := make([]GroupedSessionStats, 0, len(rows)) + for _, row := range rows { + stats := GroupedSessionStats{ + Period: fmt.Sprintf("%v", row.WeekStart), + SessionCount: row.SessionCount, + MessageCount: convertNullFloat64ToInt64(row.MessageCount), + PromptTokens: convertNullFloat64ToInt64(row.PromptTokens), + CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens), + TotalCost: convertNullFloat64(row.TotalCost), + AvgCost: convertNullFloat64(row.AvgCost), + } + result = append(result, stats) + } + return result +} + +func convertMonthStatsRows(rows []db.GetSessionStatsByMonthRow) []GroupedSessionStats { + result := make([]GroupedSessionStats, 0, len(rows)) + for _, row := range rows { + stats := GroupedSessionStats{ + Period: fmt.Sprintf("%v", row.Month), + SessionCount: row.SessionCount, + MessageCount: convertNullFloat64ToInt64(row.MessageCount), + PromptTokens: convertNullFloat64ToInt64(row.PromptTokens), + CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens), + TotalCost: convertNullFloat64(row.TotalCost), + AvgCost: convertNullFloat64(row.AvgCost), + } + result = append(result, stats) + } + return result +} + +func formatStats(stats SessionStats, format string) error { + switch strings.ToLower(format) { + case "json": + return formatStatsJSON(stats) + case "text": + return formatStatsText(stats) + default: + return fmt.Errorf("unsupported format: %s", format) + } +} + +func formatGroupedStats(stats []GroupedSessionStats, format, groupBy string) error { + switch strings.ToLower(format) { + case "json": + return formatGroupedStatsJSON(stats) + case "text": + return formatGroupedStatsText(stats, groupBy) + default: + return fmt.Errorf("unsupported format: %s", format) + } +} + +func formatStatsJSON(stats SessionStats) error { + data, err := json.MarshalIndent(stats, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func formatStatsText(stats SessionStats) error { + if stats.TotalSessions == 0 { + fmt.Println("No sessions found.") + return nil + } + + fmt.Println("Session Statistics") + fmt.Println("==================") + fmt.Printf("Total Sessions: %d\n", stats.TotalSessions) + fmt.Printf("Total Messages: %d\n", stats.TotalMessages) + fmt.Printf("Total Prompt Tokens: %d\n", stats.TotalPromptTokens) + fmt.Printf("Total Completion Tokens: %d\n", stats.TotalCompletionTokens) + fmt.Printf("Total Cost: $%.4f\n", stats.TotalCost) + fmt.Printf("Average Cost/Session: $%.4f\n", stats.AvgCostPerSession) + + totalTokens := stats.TotalPromptTokens + stats.TotalCompletionTokens + if totalTokens > 0 { + fmt.Printf("Total Tokens: %d\n", totalTokens) + fmt.Printf("Average Tokens/Session: %.1f\n", float64(totalTokens)/float64(stats.TotalSessions)) + } + + if stats.TotalSessions > 0 { + fmt.Printf("Average Messages/Session: %.1f\n", float64(stats.TotalMessages)/float64(stats.TotalSessions)) + } + + return nil +} + +func formatGroupedStatsJSON(stats []GroupedSessionStats) error { + data, err := json.MarshalIndent(stats, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func formatGroupedStatsText(stats []GroupedSessionStats, groupBy string) error { + if len(stats) == 0 { + fmt.Printf("No sessions found for grouping by %s.\n", groupBy) + return nil + } + + fmt.Printf("Session Statistics (Grouped by %s)\n", strings.ToUpper(groupBy[:1])+groupBy[1:]) + fmt.Println(strings.Repeat("=", 30+len(groupBy))) + fmt.Println() + + for _, stat := range stats { + fmt.Printf("Period: %s\n", stat.Period) + fmt.Printf(" Sessions: %d\n", stat.SessionCount) + fmt.Printf(" Messages: %d\n", stat.MessageCount) + fmt.Printf(" Prompt Tokens: %d\n", stat.PromptTokens) + fmt.Printf(" Completion Tokens: %d\n", stat.CompletionTokens) + fmt.Printf(" Total Cost: $%.4f\n", stat.TotalCost) + fmt.Printf(" Average Cost: $%.4f\n", stat.AvgCost) + totalTokens := stat.PromptTokens + stat.CompletionTokens + if totalTokens > 0 { + fmt.Printf(" Total Tokens: %d\n", totalTokens) + } + fmt.Println() + } + + return nil +} diff --git a/internal/db/db.go b/internal/db/db.go index 3c83e53a8bae375f59ef32775e2511864eec3d7e..d0d337e76ca885f9ad2b321feba6d8a62046cf0f 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -60,6 +60,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil { return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err) } + if q.getSessionStatsStmt, err = db.PrepareContext(ctx, getSessionStats); err != nil { + return nil, fmt.Errorf("error preparing query GetSessionStats: %w", err) + } + if q.getSessionStatsByDayStmt, err = db.PrepareContext(ctx, getSessionStatsByDay); err != nil { + return nil, fmt.Errorf("error preparing query GetSessionStatsByDay: %w", err) + } + if q.getSessionStatsByMonthStmt, err = db.PrepareContext(ctx, getSessionStatsByMonth); err != nil { + return nil, fmt.Errorf("error preparing query GetSessionStatsByMonth: %w", err) + } + if q.getSessionStatsByWeekStmt, err = db.PrepareContext(ctx, getSessionStatsByWeek); err != nil { + return nil, fmt.Errorf("error preparing query GetSessionStatsByWeek: %w", err) + } if q.listAllSessionsStmt, err = db.PrepareContext(ctx, listAllSessions); err != nil { return nil, fmt.Errorf("error preparing query ListAllSessions: %w", err) } @@ -84,6 +96,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil { return nil, fmt.Errorf("error preparing query ListSessions: %w", err) } + if q.searchSessionsByTextStmt, err = db.PrepareContext(ctx, searchSessionsByText); err != nil { + return nil, fmt.Errorf("error preparing query SearchSessionsByText: %w", err) + } + if q.searchSessionsByTitleStmt, err = db.PrepareContext(ctx, searchSessionsByTitle); err != nil { + return nil, fmt.Errorf("error preparing query SearchSessionsByTitle: %w", err) + } + if q.searchSessionsByTitleAndTextStmt, err = db.PrepareContext(ctx, searchSessionsByTitleAndText); err != nil { + return nil, fmt.Errorf("error preparing query SearchSessionsByTitleAndText: %w", err) + } if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil { return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err) } @@ -155,6 +176,26 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr) } } + if q.getSessionStatsStmt != nil { + if cerr := q.getSessionStatsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getSessionStatsStmt: %w", cerr) + } + } + if q.getSessionStatsByDayStmt != nil { + if cerr := q.getSessionStatsByDayStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getSessionStatsByDayStmt: %w", cerr) + } + } + if q.getSessionStatsByMonthStmt != nil { + if cerr := q.getSessionStatsByMonthStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getSessionStatsByMonthStmt: %w", cerr) + } + } + if q.getSessionStatsByWeekStmt != nil { + if cerr := q.getSessionStatsByWeekStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getSessionStatsByWeekStmt: %w", cerr) + } + } if q.listAllSessionsStmt != nil { if cerr := q.listAllSessionsStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listAllSessionsStmt: %w", cerr) @@ -195,6 +236,21 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing listSessionsStmt: %w", cerr) } } + if q.searchSessionsByTextStmt != nil { + if cerr := q.searchSessionsByTextStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing searchSessionsByTextStmt: %w", cerr) + } + } + if q.searchSessionsByTitleStmt != nil { + if cerr := q.searchSessionsByTitleStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing searchSessionsByTitleStmt: %w", cerr) + } + } + if q.searchSessionsByTitleAndTextStmt != nil { + if cerr := q.searchSessionsByTitleAndTextStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing searchSessionsByTitleAndTextStmt: %w", cerr) + } + } if q.updateMessageStmt != nil { if cerr := q.updateMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateMessageStmt: %w", cerr) @@ -242,57 +298,71 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createFileStmt *sql.Stmt - createMessageStmt *sql.Stmt - createSessionStmt *sql.Stmt - deleteFileStmt *sql.Stmt - deleteMessageStmt *sql.Stmt - deleteSessionStmt *sql.Stmt - deleteSessionFilesStmt *sql.Stmt - deleteSessionMessagesStmt *sql.Stmt - getFileStmt *sql.Stmt - getFileByPathAndSessionStmt *sql.Stmt - getMessageStmt *sql.Stmt - getSessionByIDStmt *sql.Stmt - listAllSessionsStmt *sql.Stmt - listChildSessionsStmt *sql.Stmt - listFilesByPathStmt *sql.Stmt - listFilesBySessionStmt *sql.Stmt - listLatestSessionFilesStmt *sql.Stmt - listMessagesBySessionStmt *sql.Stmt - listNewFilesStmt *sql.Stmt - listSessionsStmt *sql.Stmt - updateMessageStmt *sql.Stmt - updateSessionStmt *sql.Stmt + db DBTX + tx *sql.Tx + createFileStmt *sql.Stmt + createMessageStmt *sql.Stmt + createSessionStmt *sql.Stmt + deleteFileStmt *sql.Stmt + deleteMessageStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + deleteSessionFilesStmt *sql.Stmt + deleteSessionMessagesStmt *sql.Stmt + getFileStmt *sql.Stmt + getFileByPathAndSessionStmt *sql.Stmt + getMessageStmt *sql.Stmt + getSessionByIDStmt *sql.Stmt + getSessionStatsStmt *sql.Stmt + getSessionStatsByDayStmt *sql.Stmt + getSessionStatsByMonthStmt *sql.Stmt + getSessionStatsByWeekStmt *sql.Stmt + listAllSessionsStmt *sql.Stmt + listChildSessionsStmt *sql.Stmt + listFilesByPathStmt *sql.Stmt + listFilesBySessionStmt *sql.Stmt + listLatestSessionFilesStmt *sql.Stmt + listMessagesBySessionStmt *sql.Stmt + listNewFilesStmt *sql.Stmt + listSessionsStmt *sql.Stmt + searchSessionsByTextStmt *sql.Stmt + searchSessionsByTitleStmt *sql.Stmt + searchSessionsByTitleAndTextStmt *sql.Stmt + updateMessageStmt *sql.Stmt + updateSessionStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - createFileStmt: q.createFileStmt, - createMessageStmt: q.createMessageStmt, - createSessionStmt: q.createSessionStmt, - deleteFileStmt: q.deleteFileStmt, - deleteMessageStmt: q.deleteMessageStmt, - deleteSessionStmt: q.deleteSessionStmt, - deleteSessionFilesStmt: q.deleteSessionFilesStmt, - deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, - getFileStmt: q.getFileStmt, - getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, - getMessageStmt: q.getMessageStmt, - getSessionByIDStmt: q.getSessionByIDStmt, - listAllSessionsStmt: q.listAllSessionsStmt, - listChildSessionsStmt: q.listChildSessionsStmt, - listFilesByPathStmt: q.listFilesByPathStmt, - listFilesBySessionStmt: q.listFilesBySessionStmt, - listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, - listMessagesBySessionStmt: q.listMessagesBySessionStmt, - listNewFilesStmt: q.listNewFilesStmt, - listSessionsStmt: q.listSessionsStmt, - updateMessageStmt: q.updateMessageStmt, - updateSessionStmt: q.updateSessionStmt, + db: tx, + tx: tx, + createFileStmt: q.createFileStmt, + createMessageStmt: q.createMessageStmt, + createSessionStmt: q.createSessionStmt, + deleteFileStmt: q.deleteFileStmt, + deleteMessageStmt: q.deleteMessageStmt, + deleteSessionStmt: q.deleteSessionStmt, + deleteSessionFilesStmt: q.deleteSessionFilesStmt, + deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, + getFileStmt: q.getFileStmt, + getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, + getMessageStmt: q.getMessageStmt, + getSessionByIDStmt: q.getSessionByIDStmt, + getSessionStatsStmt: q.getSessionStatsStmt, + getSessionStatsByDayStmt: q.getSessionStatsByDayStmt, + getSessionStatsByMonthStmt: q.getSessionStatsByMonthStmt, + getSessionStatsByWeekStmt: q.getSessionStatsByWeekStmt, + listAllSessionsStmt: q.listAllSessionsStmt, + listChildSessionsStmt: q.listChildSessionsStmt, + listFilesByPathStmt: q.listFilesByPathStmt, + listFilesBySessionStmt: q.listFilesBySessionStmt, + listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, + listMessagesBySessionStmt: q.listMessagesBySessionStmt, + listNewFilesStmt: q.listNewFilesStmt, + listSessionsStmt: q.listSessionsStmt, + searchSessionsByTextStmt: q.searchSessionsByTextStmt, + searchSessionsByTitleStmt: q.searchSessionsByTitleStmt, + searchSessionsByTitleAndTextStmt: q.searchSessionsByTitleAndTextStmt, + updateMessageStmt: q.updateMessageStmt, + updateSessionStmt: q.updateSessionStmt, } } diff --git a/internal/db/querier.go b/internal/db/querier.go index 6d8a2a4f5f7479ce2e2e469728f3fa860d0479d6..831eafedfe3c5d93a3493f4d64cf02e93788b19b 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -22,6 +22,10 @@ type Querier interface { GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) GetMessage(ctx context.Context, id string) (Message, error) GetSessionByID(ctx context.Context, id string) (Session, error) + GetSessionStats(ctx context.Context) (GetSessionStatsRow, error) + GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error) + GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error) + GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error) ListAllSessions(ctx context.Context) ([]Session, error) ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error) ListFilesByPath(ctx context.Context, path string) ([]File, error) @@ -30,6 +34,9 @@ type Querier interface { ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) ListNewFiles(ctx context.Context) ([]File, error) ListSessions(ctx context.Context) ([]Session, error) + SearchSessionsByText(ctx context.Context, parts string) ([]Session, error) + SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error) + SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) } diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index f61adfa0c59727b3b16fe310fa5d2b8ef8d47823..22bea33f2609dd681219a4a20d5a8de3eeda4ebd 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -106,6 +106,205 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error return i, err } +const getSessionStats = `-- name: GetSessionStats :one +SELECT + COUNT(*) as total_sessions, + SUM(message_count) as total_messages, + SUM(prompt_tokens) as total_prompt_tokens, + SUM(completion_tokens) as total_completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost_per_session +FROM sessions +` + +type GetSessionStatsRow struct { + TotalSessions int64 `json:"total_sessions"` + TotalMessages sql.NullFloat64 `json:"total_messages"` + TotalPromptTokens sql.NullFloat64 `json:"total_prompt_tokens"` + TotalCompletionTokens sql.NullFloat64 `json:"total_completion_tokens"` + TotalCost sql.NullFloat64 `json:"total_cost"` + AvgCostPerSession sql.NullFloat64 `json:"avg_cost_per_session"` +} + +func (q *Queries) GetSessionStats(ctx context.Context) (GetSessionStatsRow, error) { + row := q.queryRow(ctx, q.getSessionStatsStmt, getSessionStats) + var i GetSessionStatsRow + err := row.Scan( + &i.TotalSessions, + &i.TotalMessages, + &i.TotalPromptTokens, + &i.TotalCompletionTokens, + &i.TotalCost, + &i.AvgCostPerSession, + ) + return i, err +} + +const getSessionStatsByDay = `-- name: GetSessionStatsByDay :many +SELECT + date(created_at, 'unixepoch') as day, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY date(created_at, 'unixepoch') +ORDER BY day DESC +` + +type GetSessionStatsByDayRow struct { + Day interface{} `json:"day"` + SessionCount int64 `json:"session_count"` + MessageCount sql.NullFloat64 `json:"message_count"` + PromptTokens sql.NullFloat64 `json:"prompt_tokens"` + CompletionTokens sql.NullFloat64 `json:"completion_tokens"` + TotalCost sql.NullFloat64 `json:"total_cost"` + AvgCost sql.NullFloat64 `json:"avg_cost"` +} + +func (q *Queries) GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error) { + rows, err := q.query(ctx, q.getSessionStatsByDayStmt, getSessionStatsByDay) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetSessionStatsByDayRow{} + for rows.Next() { + var i GetSessionStatsByDayRow + if err := rows.Scan( + &i.Day, + &i.SessionCount, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.TotalCost, + &i.AvgCost, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionStatsByMonth = `-- name: GetSessionStatsByMonth :many +SELECT + strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch')) +ORDER BY month DESC +` + +type GetSessionStatsByMonthRow struct { + Month interface{} `json:"month"` + SessionCount int64 `json:"session_count"` + MessageCount sql.NullFloat64 `json:"message_count"` + PromptTokens sql.NullFloat64 `json:"prompt_tokens"` + CompletionTokens sql.NullFloat64 `json:"completion_tokens"` + TotalCost sql.NullFloat64 `json:"total_cost"` + AvgCost sql.NullFloat64 `json:"avg_cost"` +} + +func (q *Queries) GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error) { + rows, err := q.query(ctx, q.getSessionStatsByMonthStmt, getSessionStatsByMonth) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetSessionStatsByMonthRow{} + for rows.Next() { + var i GetSessionStatsByMonthRow + if err := rows.Scan( + &i.Month, + &i.SessionCount, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.TotalCost, + &i.AvgCost, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionStatsByWeek = `-- name: GetSessionStatsByWeek :many +SELECT + date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days') +ORDER BY week_start DESC +` + +type GetSessionStatsByWeekRow struct { + WeekStart interface{} `json:"week_start"` + SessionCount int64 `json:"session_count"` + MessageCount sql.NullFloat64 `json:"message_count"` + PromptTokens sql.NullFloat64 `json:"prompt_tokens"` + CompletionTokens sql.NullFloat64 `json:"completion_tokens"` + TotalCost sql.NullFloat64 `json:"total_cost"` + AvgCost sql.NullFloat64 `json:"avg_cost"` +} + +func (q *Queries) GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error) { + rows, err := q.query(ctx, q.getSessionStatsByWeekStmt, getSessionStatsByWeek) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetSessionStatsByWeekRow{} + for rows.Next() { + var i GetSessionStatsByWeekRow + if err := rows.Scan( + &i.WeekStart, + &i.SessionCount, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.TotalCost, + &i.AvgCost, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listAllSessions = `-- name: ListAllSessions :many SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id FROM sessions @@ -228,6 +427,136 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { return items, nil } +const searchSessionsByText = `-- name: SearchSessionsByText :many +SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id +FROM sessions s +JOIN messages m ON s.id = m.session_id +WHERE m.parts LIKE ? +ORDER BY s.created_at DESC +` + +func (q *Queries) SearchSessionsByText(ctx context.Context, parts string) ([]Session, error) { + rows, err := q.query(ctx, q.searchSessionsByTextStmt, searchSessionsByText, parts) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const searchSessionsByTitle = `-- name: SearchSessionsByTitle :many +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id +FROM sessions +WHERE title LIKE ? +ORDER BY created_at DESC +` + +func (q *Queries) SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error) { + rows, err := q.query(ctx, q.searchSessionsByTitleStmt, searchSessionsByTitle, title) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const searchSessionsByTitleAndText = `-- name: SearchSessionsByTitleAndText :many +SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id +FROM sessions s +JOIN messages m ON s.id = m.session_id +WHERE s.title LIKE ? AND m.parts LIKE ? +ORDER BY s.created_at DESC +` + +type SearchSessionsByTitleAndTextParams struct { + Title string `json:"title"` + Parts string `json:"parts"` +} + +func (q *Queries) SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error) { + rows, err := q.query(ctx, q.searchSessionsByTitleAndTextStmt, searchSessionsByTitleAndText, arg.Title, arg.Parts) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateSession = `-- name: UpdateSession :one UPDATE sessions SET diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index b2a165c6fc9cf4979d5d9591fb51da5d7d1a6a23..25fcffdc3eba65085eb229b5309d17fe7c83cdb2 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -60,3 +60,72 @@ ORDER BY created_at ASC; SELECT * FROM sessions ORDER BY created_at DESC; + +-- name: SearchSessionsByTitle :many +SELECT * +FROM sessions +WHERE title LIKE ? +ORDER BY created_at DESC; + +-- name: SearchSessionsByTitleAndText :many +SELECT DISTINCT s.* +FROM sessions s +JOIN messages m ON s.id = m.session_id +WHERE s.title LIKE ? AND m.parts LIKE ? +ORDER BY s.created_at DESC; + +-- name: SearchSessionsByText :many +SELECT DISTINCT s.* +FROM sessions s +JOIN messages m ON s.id = m.session_id +WHERE m.parts LIKE ? +ORDER BY s.created_at DESC; + +-- name: GetSessionStats :one +SELECT + COUNT(*) as total_sessions, + SUM(message_count) as total_messages, + SUM(prompt_tokens) as total_prompt_tokens, + SUM(completion_tokens) as total_completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost_per_session +FROM sessions; + +-- name: GetSessionStatsByDay :many +SELECT + date(created_at, 'unixepoch') as day, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY date(created_at, 'unixepoch') +ORDER BY day DESC; + +-- name: GetSessionStatsByWeek :many +SELECT + date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days') +ORDER BY week_start DESC; + +-- name: GetSessionStatsByMonth :many +SELECT + strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month, + COUNT(*) as session_count, + SUM(message_count) as message_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(cost) as total_cost, + AVG(cost) as avg_cost +FROM sessions +GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch')) +ORDER BY month DESC; diff --git a/internal/session/session.go b/internal/session/session.go index d98444d3265fe82984493b4fd483b58d042bc766..f9699f3cd2828507b16e359702f6da3f9f997f04 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -33,6 +33,9 @@ type Service interface { ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) Delete(ctx context.Context, id string) error + SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) + SearchByText(ctx context.Context, textPattern string) ([]Session, error) + SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) } type service struct { @@ -161,6 +164,45 @@ func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]S return sessions, nil } +func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) { + dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%") + if err != nil { + return nil, err + } + sessions := make([]Session, len(dbSessions)) + for i, dbSession := range dbSessions { + sessions[i] = s.fromDBItem(dbSession) + } + return sessions, nil +} + +func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) { + dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%") + if err != nil { + return nil, err + } + sessions := make([]Session, len(dbSessions)) + for i, dbSession := range dbSessions { + sessions[i] = s.fromDBItem(dbSession) + } + return sessions, nil +} + +func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) { + dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{ + Title: "%" + titlePattern + "%", + Parts: "%" + textPattern + "%", + }) + if err != nil { + return nil, err + } + sessions := make([]Session, len(dbSessions)) + for i, dbSession := range dbSessions { + sessions[i] = s.fromDBItem(dbSession) + } + return sessions, nil +} + func (s service) fromDBItem(item db.Session) Session { return Session{ ID: item.ID, diff --git a/rfcs/session-import-export.md b/rfcs/session-import-export.md new file mode 100644 index 0000000000000000000000000000000000000000..53b35a05634f45af27a9532a0d7eb981cc46c8ff --- /dev/null +++ b/rfcs/session-import-export.md @@ -0,0 +1,297 @@ +# RFC: Session Import and Export + +## Summary + +This RFC proposes a comprehensive system for importing and exporting conversation sessions in Crush. + +## Background + +Crush manages conversations through a hierarchical session system where: +- Sessions contain metadata (title, token counts, cost, timestamps) +- Sessions can have parent-child relationships (nested conversations) +- Messages within sessions have structured content parts (text, tool calls, reasoning, etc.) +- The current implementation provides export functionality but lacks import capabilities + +The latest commit introduced three key commands: +- `crush sessions list` - List sessions in various formats +- `crush sessions export` - Export all sessions and metadata +- `crush sessions export-conversation ` - Export a single conversation with messages + +## Motivation + +Users need to: +1. Share conversations with others +2. Use conversation logs for debugging +3. Archive and analyze conversation history +4. Export data for external tools + +## Detailed Design + +### Core Data Model + +The session export format builds on the existing session structure: + +```go +type Session struct { + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id,omitempty"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + SummaryMessageID string `json:"summary_message_id,omitempty"` +} + +type SessionWithChildren struct { + Session + Children []SessionWithChildren `json:"children,omitempty"` +} +``` + +### Proposed Command Interface + +#### Export Commands (Already Implemented) +```bash +# List sessions in various formats +crush sessions list [--format text|json|yaml|markdown] + +# Export all sessions with metadata +crush sessions export [--format json|yaml|markdown] + +# Export single conversation with full message history +crush sessions export-conversation [--format markdown|json|yaml] +``` + +#### New Import Commands + +```bash +# Import sessions from a file +crush sessions import [--format json|yaml] [--dry-run] + +# Import a single conversation +crush sessions import-conversation [--format json|yaml|markdown] + +``` + +#### Enhanced Inspection Commands + +```bash +# Search sessions by criteria +crush sessions search [--title ] [--text ] [--format text|json] + +# Show session statistics +crush sessions stats [--format text|json] [--group-by day|week|month] + +# Show statistics for a single session +crush sessions stats [--format text|json] +``` + +### Import/Export Formats + +#### Full Export Format (JSON) +```json +{ + "version": "1.0", + "exported_at": "2025-01-27T10:30:00Z", + "total_sessions": 15, + "sessions": [ + { + "id": "session-123", + "parent_session_id": "", + "title": "API Design Discussion", + "message_count": 8, + "prompt_tokens": 1250, + "completion_tokens": 890, + "cost": 0.0234, + "created_at": 1706356200, + "updated_at": 1706359800, + "children": [ + { + "id": "session-124", + "parent_session_id": "session-123", + "title": "Implementation Details", + "message_count": 4, + "prompt_tokens": 650, + "completion_tokens": 420, + "cost": 0.0145, + "created_at": 1706359900, + "updated_at": 1706361200 + } + ] + } + ] +} +``` + +#### Conversation Export Format (JSON) +```json +{ + "version": "1.0", + "session": { + "id": "session-123", + "title": "API Design Discussion", + "created_at": 1706356200, + "message_count": 3 + }, + "messages": [ + { + "id": "msg-001", + "session_id": "session-123", + "role": "user", + "parts": [ + { + "type": "text", + "data": { + "text": "Help me design a REST API for user management" + } + } + ], + "created_at": 1706356200 + }, + { + "id": "msg-002", + "session_id": "session-123", + "role": "assistant", + "model": "gpt-4", + "provider": "openai", + "parts": [ + { + "type": "text", + "data": { + "text": "I'll help you design a REST API for user management..." + } + }, + { + "type": "finish", + "data": { + "reason": "stop", + "time": 1706356230 + } + } + ], + "created_at": 1706356220 + } + ] +} +``` + +### API Implementation + +#### Import Service Interface +```go +type ImportService interface { + // Import sessions from structured data + ImportSessions(ctx context.Context, data ImportData, opts ImportOptions) (ImportResult, error) + + // Import single conversation + ImportConversation(ctx context.Context, data ConversationData, opts ImportOptions) (Session, error) + + // Validate import data without persisting + ValidateImport(ctx context.Context, data ImportData) (ValidationResult, error) +} + +type ImportOptions struct { + ConflictStrategy ConflictStrategy // skip, merge, replace + DryRun bool + ParentSessionID string // For conversation imports + PreserveIDs bool // Whether to preserve original IDs +} + +type ConflictStrategy string + +const ( + ConflictSkip ConflictStrategy = "skip" // Skip existing sessions + ConflictMerge ConflictStrategy = "merge" // Merge with existing + ConflictReplace ConflictStrategy = "replace" // Replace existing +) + +type ImportResult struct { + TotalSessions int `json:"total_sessions"` + ImportedSessions int `json:"imported_sessions"` + SkippedSessions int `json:"skipped_sessions"` + Errors []ImportError `json:"errors,omitempty"` + SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id +} +``` + +#### Enhanced Export Service +```go +type ExportService interface { + // Export sessions with filtering + ExportSessions(ctx context.Context, opts ExportOptions) ([]SessionWithChildren, error) + + // Export conversation with full message history + ExportConversation(ctx context.Context, sessionID string, opts ExportOptions) (ConversationExport, error) + + // Search and filter sessions + SearchSessions(ctx context.Context, criteria SearchCriteria) ([]Session, error) + + // Get session statistics + GetStats(ctx context.Context, opts StatsOptions) (SessionStats, error) +} + +type ExportOptions struct { + Format string // json, yaml, markdown + IncludeMessages bool // Include full message content + DateRange DateRange // Filter by date range + SessionIDs []string // Export specific sessions +} + +type SearchCriteria struct { + TitlePattern string + DateRange DateRange + MinCost float64 + MaxCost float64 + ParentSessionID string + HasChildren *bool +} +``` + +## Implementation Status + +The proposed session import/export functionality has been implemented as a prototype as of July 2025. + +### Implemented Commands + +All new commands have been added to `internal/cmd/sessions.go`: + +- **Import**: `crush sessions import [--format json|yaml] [--dry-run]` + - Supports hierarchical session imports with parent-child relationships + - Generates new UUIDs to avoid conflicts + - Includes validation and dry-run capabilities + +- **Import Conversation**: `crush sessions import-conversation [--format json|yaml]` + - Imports single conversations with full message history + - Preserves all message content parts and metadata + +- **Search**: `crush sessions search [--title ] [--text ] [--format text|json]` + - Case-insensitive title search and message content search + - Supports combined search criteria with AND logic + +- **Stats**: `crush sessions stats [--format text|json] [--group-by day|week|month]` + - Comprehensive usage statistics (sessions, messages, tokens, costs) + - Time-based grouping with efficient database queries + +### Database Changes + +Added new SQL queries in `internal/db/sql/sessions.sql`: +- Search queries for title and message content filtering +- Statistics aggregation queries with time-based grouping +- All queries optimized for performance with proper indexing + +### Database Schema Considerations + +The current schema supports the import/export functionality. Additional indexes may be needed for search performance: + +```sql +-- Optimize session searches by date and cost +CREATE INDEX idx_sessions_created_at ON sessions(created_at); +CREATE INDEX idx_sessions_cost ON sessions(cost); +CREATE INDEX idx_sessions_title ON sessions(title COLLATE NOCASE); + +-- Optimize message searches by session +CREATE INDEX idx_messages_session_created ON messages(session_id, created_at); +```