diff --git a/go.mod b/go.mod index 9f91fda2fc86914d595bf8b5eaabaddcd7b05319..e9233ee972a421d3da705ca5641495abb7c03e63 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( charm.land/bubbles/v2 v2.0.0 charm.land/bubbletea/v2 v2.0.2 charm.land/catwalk v0.28.4 + charm.land/fang/v2 v2.0.1 charm.land/fantasy v0.12.1 charm.land/glamour/v2 v2.0.0 charm.land/lipgloss/v2 v2.0.1 @@ -21,7 +22,6 @@ require ( github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/charlievieth/fastwalk v1.0.14 github.com/charmbracelet/colorprofile v0.4.3 - github.com/charmbracelet/fang v1.0.0 github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 github.com/charmbracelet/x/ansi v0.11.6 github.com/charmbracelet/x/editor v0.2.0 diff --git a/go.sum b/go.sum index aff01decd2858103938d3c73f3584760189ccd91..61f7ac7cde8c1acba2ae775a764a09b45ad0f6f8 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0= charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ= charm.land/catwalk v0.28.4 h1:YaaXA1k0v7CKvvT+Gh1pDD7XrlUR93kROdaWqkkglRw= charm.land/catwalk v0.28.4/go.mod h1:+fqw/6YGNtvapvPy9vhwA/fAMxVjD2K8hVIKYov8Vhg= +charm.land/fang/v2 v2.0.1 h1:zQCM8JQJ1JnQX/66B5jlCYBUxL2as5JXQZ2KJ6EL0mY= +charm.land/fang/v2 v2.0.1/go.mod h1:S1GmkpcvK+OB5w9caywUnJcsMew45Ot8FXqoz8ALrII= charm.land/fantasy v0.12.1 h1:awszoi5O9FIjMEkfyCMiLJfVRNLckp/zQkFrA6IxQqc= charm.land/fantasy v0.12.1/go.mod h1:QeRVUeG1XNTWBszRAbhUtPyX1VWs6zjkCxwfcwnICdc= charm.land/glamour/v2 v2.0.0 h1:IDBoqLEy7Hdpb9VOXN+khLP/XSxtJy1VsHuW/yF87+U= @@ -100,8 +102,6 @@ github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1: github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8= github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= -github.com/charmbracelet/fang v1.0.0 h1:jESBY40agJOlLYnnv9jE0mLqDGTxEk0hkOnx7YGyRlQ= -github.com/charmbracelet/fang v1.0.0/go.mod h1:P5/DNb9DddQ0Z0dbc0P3ol4/ix5Po7Ofr2KMBfAqoCo= github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 h1:eyFRbAmexyt43hVfeyBofiGSEmJ7krjLOYt/9CF5NKA= github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8/go.mod h1:SQpCTRNBtzJkwku5ye4S3HEuthAlGy2n9VXZnWkEW98= github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 7d41339811b6f4ca1d74fc903f5058ec833d5b8d..f8838be7f3b9c5321f0a24a54d5130ed8aad293a 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -842,9 +842,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user // Welp, the large model didn't work either. Use the default // session name and return. slog.Error("Error generating title with large model", "err", err) - saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0) + saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName) if saveErr != nil { - slog.Error("Failed to save session title and usage", "error", saveErr) + slog.Error("Failed to save session title", "error", saveErr) } return } @@ -854,9 +854,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user // Actually, we didn't get a response so we can't. Use the default // session name and return. slog.Error("Response is nil; can't generate title") - saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0) + saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName) if saveErr != nil { - slog.Error("Failed to save session title and usage", "error", saveErr) + slog.Error("Failed to save session title", "error", saveErr) } return } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 6e1bc08f2f14e8af3d65b5dca7826b95d890b116..060d94c13a3d1cb9b927ebfbddc473aa1593e9dc 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -13,6 +13,7 @@ import ( "strings" tea "charm.land/bubbletea/v2" + "charm.land/fang/v2" "charm.land/lipgloss/v2" "github.com/charmbracelet/colorprofile" "github.com/charmbracelet/crush/internal/app" @@ -23,7 +24,6 @@ import ( "github.com/charmbracelet/crush/internal/ui/common" ui "github.com/charmbracelet/crush/internal/ui/model" "github.com/charmbracelet/crush/internal/version" - "github.com/charmbracelet/fang" uv "github.com/charmbracelet/ultraviolet" "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/x/exp/charmtone" @@ -47,6 +47,7 @@ func init() { schemaCmd, loginCmd, statsCmd, + sessionCmd, ) } diff --git a/internal/cmd/session.go b/internal/cmd/session.go new file mode 100644 index 0000000000000000000000000000000000000000..67f40a0cea7442efe5d36600dc27260ab66b2320 --- /dev/null +++ b/internal/cmd/session.go @@ -0,0 +1,641 @@ +package cmd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "runtime" + "strings" + "syscall" + "time" + + "charm.land/lipgloss/v2" + "github.com/charmbracelet/colorprofile" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/ui/chat" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" + "github.com/charmbracelet/x/exp/charmtone" + "github.com/charmbracelet/x/term" + "github.com/spf13/cobra" +) + +var sessionCmd = &cobra.Command{ + Use: "session", + Aliases: []string{"sessions"}, + Short: "Manage sessions", + Long: "Manage Crush sessions. Agents can use --json for machine-readable output.", +} + +var ( + sessionListJSON bool + sessionShowJSON bool + sessionLastJSON bool + sessionDeleteJSON bool + sessionRenameJSON bool +) + +var sessionListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List all sessions", + Long: "List all sessions. Use --json for machine-readable output.", + RunE: runSessionList, +} + +var sessionShowCmd = &cobra.Command{ + Use: "show ", + Short: "Show session details", + Long: "Show session details. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.", + Args: cobra.ExactArgs(1), + RunE: runSessionShow, +} + +var sessionLastCmd = &cobra.Command{ + Use: "last", + Short: "Show most recent session", + Long: "Show the last updated session. Use --json for machine-readable output.", + RunE: runSessionLast, +} + +var sessionDeleteCmd = &cobra.Command{ + Use: "delete ", + Aliases: []string{"rm"}, + Short: "Delete a session", + Long: "Delete a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.", + Args: cobra.ExactArgs(1), + RunE: runSessionDelete, +} + +var sessionRenameCmd = &cobra.Command{ + Use: "rename ", + Short: "Rename a session", + Long: "Rename a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.", + Args: cobra.MinimumNArgs(2), + RunE: runSessionRename, +} + +func init() { + sessionListCmd.Flags().BoolVar(&sessionListJSON, "json", false, "output in JSON format") + sessionShowCmd.Flags().BoolVar(&sessionShowJSON, "json", false, "output in JSON format") + sessionLastCmd.Flags().BoolVar(&sessionLastJSON, "json", false, "output in JSON format") + sessionDeleteCmd.Flags().BoolVar(&sessionDeleteJSON, "json", false, "output in JSON format") + sessionRenameCmd.Flags().BoolVar(&sessionRenameJSON, "json", false, "output in JSON format") + sessionCmd.AddCommand(sessionListCmd) + sessionCmd.AddCommand(sessionShowCmd) + sessionCmd.AddCommand(sessionLastCmd) + sessionCmd.AddCommand(sessionDeleteCmd) + sessionCmd.AddCommand(sessionRenameCmd) +} + +type sessionServices struct { + sessions session.Service + messages message.Service +} + +func sessionSetup(cmd *cobra.Command) (context.Context, *sessionServices, func(), error) { + dataDir, _ := cmd.Flags().GetString("data-dir") + ctx := cmd.Context() + + if dataDir == "" { + cfg, err := config.Init("", "", false) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to initialize config: %w", err) + } + dataDir = cfg.Config().Options.DataDirectory + } + + conn, err := db.Connect(ctx, dataDir) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to connect to database: %w", err) + } + + queries := db.New(conn) + svc := &sessionServices{ + sessions: session.NewService(queries, conn), + messages: message.NewService(queries), + } + return ctx, svc, func() { conn.Close() }, nil +} + +func runSessionList(cmd *cobra.Command, _ []string) error { + ctx, svc, cleanup, err := sessionSetup(cmd) + if err != nil { + return err + } + defer cleanup() + + list, err := svc.sessions.List(ctx) + if err != nil { + return fmt.Errorf("failed to list sessions: %w", err) + } + + if sessionListJSON { + out := cmd.OutOrStdout() + output := make([]sessionJSON, len(list)) + for i, s := range list { + output[i] = sessionJSON{ + ID: session.HashID(s.ID), + UUID: s.ID, + Title: s.Title, + Created: time.Unix(s.CreatedAt, 0).Format(time.RFC3339), + Modified: time.Unix(s.UpdatedAt, 0).Format(time.RFC3339), + } + } + enc := json.NewEncoder(out) + enc.SetEscapeHTML(false) + return enc.Encode(output) + } + + w, cleanup, usingPager := sessionWriter(ctx, len(list)) + defer cleanup() + + hashStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu) + dateStyle := lipgloss.NewStyle().Foreground(charmtone.Damson) + + width := sessionOutputWidth + if tw, _, err := term.GetSize(os.Stdout.Fd()); err == nil && tw > 0 { + width = tw + } + // 7 (hash) + 1 (space) + 25 (RFC3339 date) + 1 (space) = 34 chars prefix. + titleWidth := width - 34 + if titleWidth < 10 { + titleWidth = 10 + } + + var writeErr error + for _, s := range list { + hash := session.HashID(s.ID)[:7] + date := time.Unix(s.CreatedAt, 0).Format(time.RFC3339) + title := strings.ReplaceAll(s.Title, "\n", " ") + title = ansi.Truncate(title, titleWidth, "…") + _, writeErr = fmt.Fprintln(w, hashStyle.Render(hash), dateStyle.Render(date), title) + if writeErr != nil { + break + } + } + if writeErr != nil && usingPager && isBrokenPipe(writeErr) { + return nil + } + return writeErr +} + +type sessionJSON struct { + ID string `json:"id"` + UUID string `json:"uuid"` + Title string `json:"title"` + Created string `json:"created"` + Modified string `json:"modified"` +} + +type sessionMutationResult struct { + ID string `json:"id"` + UUID string `json:"uuid"` + Title string `json:"title"` + Deleted bool `json:"deleted,omitempty"` + Renamed bool `json:"renamed,omitempty"` +} + +// resolveSessionID resolves a session ID that can be a UUID, full hash, or hash prefix. +// Returns an error if the prefix is ambiguous (matches multiple sessions). +func resolveSessionID(ctx context.Context, svc session.Service, id string) (session.Session, error) { + // Try direct UUID lookup first + if s, err := svc.Get(ctx, id); err == nil { + return s, nil + } + + // List all sessions and check for hash matches + sessions, err := svc.List(ctx) + if err != nil { + return session.Session{}, err + } + + var matches []session.Session + for _, s := range sessions { + hash := session.HashID(s.ID) + if hash == id || strings.HasPrefix(hash, id) { + matches = append(matches, s) + } + } + + if len(matches) == 0 { + return session.Session{}, fmt.Errorf("session not found: %s", id) + } + + if len(matches) == 1 { + return matches[0], nil + } + + // Ambiguous - show matches like Git does + var sb strings.Builder + fmt.Fprintf(&sb, "session ID '%s' is ambiguous. Matches:\n\n", id) + for _, m := range matches { + hash := session.HashID(m.ID) + created := time.Unix(m.CreatedAt, 0).Format("2006-01-02") + // Keep title on one line by replacing newlines with spaces, and truncate. + title := strings.ReplaceAll(m.Title, "\n", " ") + title = ansi.Truncate(title, 50, "…") + fmt.Fprintf(&sb, " %s... %q (created %s)\n", hash[:12], title, created) + } + sb.WriteString("\nUse more characters or the full hash") + return session.Session{}, errors.New(sb.String()) +} + +func runSessionShow(cmd *cobra.Command, args []string) error { + ctx, svc, cleanup, err := sessionSetup(cmd) + if err != nil { + return err + } + defer cleanup() + + sess, err := resolveSessionID(ctx, svc.sessions, args[0]) + if err != nil { + return err + } + + msgs, err := svc.messages.List(ctx, sess.ID) + if err != nil { + return fmt.Errorf("failed to list messages: %w", err) + } + + msgPtrs := messagePtrs(msgs) + if sessionShowJSON { + return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs) + } + return outputSessionHuman(ctx, sess, msgPtrs) +} + +func runSessionDelete(cmd *cobra.Command, args []string) error { + ctx, svc, cleanup, err := sessionSetup(cmd) + if err != nil { + return err + } + defer cleanup() + + sess, err := resolveSessionID(ctx, svc.sessions, args[0]) + if err != nil { + return err + } + + if err := svc.sessions.Delete(ctx, sess.ID); err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + + out := cmd.OutOrStdout() + if sessionDeleteJSON { + enc := json.NewEncoder(out) + enc.SetEscapeHTML(false) + return enc.Encode(sessionMutationResult{ + ID: session.HashID(sess.ID), + UUID: sess.ID, + Title: sess.Title, + Deleted: true, + }) + } + + fmt.Fprintf(out, "Deleted session %s\n", session.HashID(sess.ID)[:12]) + return nil +} + +func runSessionRename(cmd *cobra.Command, args []string) error { + ctx, svc, cleanup, err := sessionSetup(cmd) + if err != nil { + return err + } + defer cleanup() + + sess, err := resolveSessionID(ctx, svc.sessions, args[0]) + if err != nil { + return err + } + + newTitle := strings.Join(args[1:], " ") + if err := svc.sessions.Rename(ctx, sess.ID, newTitle); err != nil { + return fmt.Errorf("failed to rename session: %w", err) + } + + out := cmd.OutOrStdout() + if sessionRenameJSON { + enc := json.NewEncoder(out) + enc.SetEscapeHTML(false) + return enc.Encode(sessionMutationResult{ + ID: session.HashID(sess.ID), + UUID: sess.ID, + Title: newTitle, + Renamed: true, + }) + } + + fmt.Fprintf(out, "Renamed session %s to %q\n", session.HashID(sess.ID)[:12], newTitle) + return nil +} + +func runSessionLast(cmd *cobra.Command, _ []string) error { + ctx, svc, cleanup, err := sessionSetup(cmd) + if err != nil { + return err + } + defer cleanup() + + list, err := svc.sessions.List(ctx) + if err != nil { + return fmt.Errorf("failed to list sessions: %w", err) + } + + if len(list) == 0 { + return fmt.Errorf("no sessions found") + } + + sess := list[0] + + msgs, err := svc.messages.List(ctx, sess.ID) + if err != nil { + return fmt.Errorf("failed to list messages: %w", err) + } + + msgPtrs := messagePtrs(msgs) + if sessionLastJSON { + return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs) + } + return outputSessionHuman(ctx, sess, msgPtrs) +} + +const ( + sessionOutputWidth = 80 + sessionMaxContentWidth = 120 +) + +func messagePtrs(msgs []message.Message) []*message.Message { + ptrs := make([]*message.Message, len(msgs)) + for i := range msgs { + ptrs[i] = &msgs[i] + } + return ptrs +} + +func outputSessionJSON(w io.Writer, sess session.Session, msgs []*message.Message) error { + output := sessionShowOutput{ + Meta: sessionShowMeta{ + ID: session.HashID(sess.ID), + UUID: sess.ID, + Title: sess.Title, + Created: time.Unix(sess.CreatedAt, 0).Format(time.RFC3339), + Modified: time.Unix(sess.UpdatedAt, 0).Format(time.RFC3339), + Cost: sess.Cost, + PromptTokens: sess.PromptTokens, + CompletionTokens: sess.CompletionTokens, + TotalTokens: sess.PromptTokens + sess.CompletionTokens, + }, + Messages: make([]sessionShowMessage, len(msgs)), + } + + for i, msg := range msgs { + output.Messages[i] = sessionShowMessage{ + ID: msg.ID, + Role: string(msg.Role), + Created: time.Unix(msg.CreatedAt, 0).Format(time.RFC3339), + Model: msg.Model, + Provider: msg.Provider, + Parts: convertParts(msg.Parts), + } + } + + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + return enc.Encode(output) +} + +func outputSessionHuman(ctx context.Context, sess session.Session, msgs []*message.Message) error { + sty := styles.DefaultStyles() + toolResults := chat.BuildToolResultMap(msgs) + + width := sessionOutputWidth + if w, _, err := term.GetSize(os.Stdout.Fd()); err == nil && w > 0 { + width = w + } + contentWidth := min(width, sessionMaxContentWidth) + + keyStyle := lipgloss.NewStyle().Foreground(charmtone.Damson) + valStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu) + + hash := session.HashID(sess.ID)[:12] + created := time.Unix(sess.CreatedAt, 0).Format("Mon Jan 2 15:04:05 2006 -0700") + + // Render to buffer to determine actual height + var buf strings.Builder + + fmt.Fprintln(&buf, keyStyle.Render("ID: ")+valStyle.Render(hash)) + fmt.Fprintln(&buf, keyStyle.Render("UUID: ")+valStyle.Render(sess.ID)) + fmt.Fprintln(&buf, keyStyle.Render("Title: ")+valStyle.Render(sess.Title)) + fmt.Fprintln(&buf, keyStyle.Render("Date: ")+valStyle.Render(created)) + fmt.Fprintln(&buf) + + first := true + for _, msg := range msgs { + items := chat.ExtractMessageItems(&sty, msg, toolResults) + for _, item := range items { + if !first { + fmt.Fprintln(&buf) + } + first = false + fmt.Fprintln(&buf, item.Render(contentWidth)) + } + } + fmt.Fprintln(&buf) + + contentHeight := strings.Count(buf.String(), "\n") + w, cleanup, usingPager := sessionWriter(ctx, contentHeight) + defer cleanup() + + _, err := io.WriteString(w, buf.String()) + // Ignore broken pipe errors when using a pager. This happens when the user + // exits the pager early (e.g., pressing 'q' in less), which closes the pipe + // and causes subsequent writes to fail. These errors are expected user behavior. + if err != nil && usingPager && isBrokenPipe(err) { + return nil + } + return err +} + +func isBrokenPipe(err error) bool { + if err == nil { + return false + } + // Check for syscall.EPIPE (broken pipe) + if errors.Is(err, syscall.EPIPE) { + return true + } + // Also check for "broken pipe" in the error message + return strings.Contains(err.Error(), "broken pipe") +} + +// sessionWriter returns a writer, cleanup function, and a bool indicating if a pager is used. +// When the content fits within the terminal (or stdout is not a TTY), it returns +// a colorprofile.Writer wrapping stdout. When content exceeds terminal height, +// it starts a pager process (respecting $PAGER, defaulting to "less -R"). +func sessionWriter(ctx context.Context, contentHeight int) (io.Writer, func(), bool) { + // Use NewWriter which automatically detects TTY and strips ANSI when redirected + if runtime.GOOS == "windows" || !term.IsTerminal(os.Stdout.Fd()) { + return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false + } + + _, termHeight, err := term.GetSize(os.Stdout.Fd()) + if err != nil || contentHeight <= termHeight { + return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false + } + + // Detect color profile from stderr since stdout is piped to the pager. + profile := colorprofile.Detect(os.Stderr, os.Environ()) + + pager := os.Getenv("PAGER") + if pager == "" { + pager = "less -R" + } + + parts := strings.Fields(pager) + cmd := exec.CommandContext(ctx, parts[0], parts[1:]...) //nolint:gosec + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + pipe, err := cmd.StdinPipe() + if err != nil { + return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false + } + + if err := cmd.Start(); err != nil { + return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false + } + + return &colorprofile.Writer{ + Forward: pipe, + Profile: profile, + }, func() { + pipe.Close() + _ = cmd.Wait() + }, true +} + +type sessionShowMeta struct { + ID string `json:"id"` + UUID string `json:"uuid"` + Title string `json:"title"` + Created string `json:"created"` + Modified string `json:"modified"` + Cost float64 `json:"cost"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +type sessionShowMessage struct { + ID string `json:"id"` + Role string `json:"role"` + Created string `json:"created"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + Parts []sessionShowPart `json:"parts"` +} + +type sessionShowPart struct { + Type string `json:"type"` + + // Text content + Text string `json:"text,omitempty"` + + // Reasoning + Thinking string `json:"thinking,omitempty"` + StartedAt int64 `json:"started_at,omitempty"` + FinishedAt int64 `json:"finished_at,omitempty"` + + // Tool call + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` + Input string `json:"input,omitempty"` + + // Tool result + Content string `json:"content,omitempty"` + IsError bool `json:"is_error,omitempty"` + MIMEType string `json:"mime_type,omitempty"` + + // Binary + Size int64 `json:"size,omitempty"` + + // Image URL + URL string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` + + // Finish + Reason string `json:"reason,omitempty"` + Time int64 `json:"time,omitempty"` +} + +func convertParts(parts []message.ContentPart) []sessionShowPart { + result := make([]sessionShowPart, 0, len(parts)) + for _, part := range parts { + switch p := part.(type) { + case message.TextContent: + result = append(result, sessionShowPart{ + Type: "text", + Text: p.Text, + }) + case message.ReasoningContent: + result = append(result, sessionShowPart{ + Type: "reasoning", + Thinking: p.Thinking, + StartedAt: p.StartedAt, + FinishedAt: p.FinishedAt, + }) + case message.ToolCall: + result = append(result, sessionShowPart{ + Type: "tool_call", + ToolCallID: p.ID, + Name: p.Name, + Input: p.Input, + }) + case message.ToolResult: + result = append(result, sessionShowPart{ + Type: "tool_result", + ToolCallID: p.ToolCallID, + Name: p.Name, + Content: p.Content, + IsError: p.IsError, + MIMEType: p.MIMEType, + }) + case message.BinaryContent: + result = append(result, sessionShowPart{ + Type: "binary", + MIMEType: p.MIMEType, + Size: int64(len(p.Data)), + }) + case message.ImageURLContent: + result = append(result, sessionShowPart{ + Type: "image_url", + URL: p.URL, + Detail: p.Detail, + }) + case message.Finish: + result = append(result, sessionShowPart{ + Type: "finish", + Reason: string(p.Reason), + Time: p.Time, + }) + default: + result = append(result, sessionShowPart{ + Type: "unknown", + }) + } + } + return result +} + +type sessionShowOutput struct { + Meta sessionShowMeta `json:"meta"` + Messages []sessionShowMessage `json:"messages"` +} diff --git a/internal/db/db.go b/internal/db/db.go index ec4e3807057bf4ac456ad9c066a4edb00c1771d5..dbde2e493eea4c262aef55ef7dcadd904a1b9d65 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -120,6 +120,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.recordFileReadStmt, err = db.PrepareContext(ctx, recordFileRead); err != nil { return nil, fmt.Errorf("error preparing query RecordFileRead: %w", err) } + if q.renameSessionStmt, err = db.PrepareContext(ctx, renameSession); err != nil { + return nil, fmt.Errorf("error preparing query RenameSession: %w", err) + } if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil { return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err) } @@ -294,6 +297,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing recordFileReadStmt: %w", cerr) } } + if q.renameSessionStmt != nil { + if cerr := q.renameSessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing renameSessionStmt: %w", cerr) + } + } if q.updateMessageStmt != nil { if cerr := q.updateMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateMessageStmt: %w", cerr) @@ -380,6 +388,7 @@ type Queries struct { listSessionsStmt *sql.Stmt listUserMessagesBySessionStmt *sql.Stmt recordFileReadStmt *sql.Stmt + renameSessionStmt *sql.Stmt updateMessageStmt *sql.Stmt updateSessionStmt *sql.Stmt updateSessionTitleAndUsageStmt *sql.Stmt @@ -421,6 +430,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { listSessionsStmt: q.listSessionsStmt, listUserMessagesBySessionStmt: q.listUserMessagesBySessionStmt, recordFileReadStmt: q.recordFileReadStmt, + renameSessionStmt: q.renameSessionStmt, updateMessageStmt: q.updateMessageStmt, updateSessionStmt: q.updateSessionStmt, updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt, diff --git a/internal/db/models.go b/internal/db/models.go index a105074ab9e6320bd92b90121e7694b1f8cd1e5a..20034fb00a935bed7c4cfe4906dba66dd380ed64 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -34,7 +34,7 @@ type Message struct { type ReadFile struct { SessionID string `json:"session_id"` Path string `json:"path"` - ReadAt int64 `json:"read_at"` // Unix timestamp when file was last read + ReadAt int64 `json:"read_at"` } type Session struct { diff --git a/internal/db/querier.go b/internal/db/querier.go index 9a72be02c12a2760a6ab2acef8765cabb0f6bd0c..ae91927aedf797f84f347e7e14a93327120a847e 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -41,6 +41,7 @@ type Querier interface { ListSessions(ctx context.Context) ([]Session, error) ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) RecordFileRead(ctx context.Context, arg RecordFileReadParams) error + RenameSession(ctx context.Context, arg RenameSessionParams) error UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error diff --git a/internal/db/read_files.sql.go b/internal/db/read_files.sql.go index c1cda5ee633ede07b2faebe38619292c994a9f50..21bda47ce9db1177417fce87677169b49ebd56dc 100644 --- a/internal/db/read_files.sql.go +++ b/internal/db/read_files.sql.go @@ -22,32 +22,10 @@ type GetFileReadParams struct { func (q *Queries) GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error) { row := q.queryRow(ctx, q.getFileReadStmt, getFileRead, arg.SessionID, arg.Path) var i ReadFile - err := row.Scan( - &i.SessionID, - &i.Path, - &i.ReadAt, - ) + err := row.Scan(&i.SessionID, &i.Path, &i.ReadAt) return i, err } -const recordFileRead = `-- name: RecordFileRead :exec -INSERT INTO read_files ( - session_id, - path, - read_at -) VALUES ( - ?, - ?, - strftime('%s', 'now') -) ON CONFLICT(path, session_id) DO UPDATE SET - read_at = excluded.read_at -` - -type RecordFileReadParams struct { - SessionID string `json:"session_id"` - Path string `json:"path"` -} - const listSessionReadFiles = `-- name: ListSessionReadFiles :many SELECT session_id, path, read_at FROM read_files WHERE session_id = ? @@ -63,11 +41,7 @@ func (q *Queries) ListSessionReadFiles(ctx context.Context, sessionID string) ([ items := []ReadFile{} for rows.Next() { var i ReadFile - if err := rows.Scan( - &i.SessionID, - &i.Path, - &i.ReadAt, - ); err != nil { + if err := rows.Scan(&i.SessionID, &i.Path, &i.ReadAt); err != nil { return nil, err } items = append(items, i) @@ -81,10 +55,25 @@ func (q *Queries) ListSessionReadFiles(ctx context.Context, sessionID string) ([ return items, nil } +const recordFileRead = `-- name: RecordFileRead :exec +INSERT INTO read_files ( + session_id, + path, + read_at +) VALUES ( + ?, + ?, + strftime('%s', 'now') +) ON CONFLICT(path, session_id) DO UPDATE SET + read_at = excluded.read_at +` + +type RecordFileReadParams struct { + SessionID string `json:"session_id"` + Path string `json:"path"` +} + func (q *Queries) RecordFileRead(ctx context.Context, arg RecordFileReadParams) error { - _, err := q.exec(ctx, q.recordFileReadStmt, recordFileRead, - arg.SessionID, - arg.Path, - ) + _, err := q.exec(ctx, q.recordFileReadStmt, recordFileRead, arg.SessionID, arg.Path) return err } diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 3b1ecbfecb3c5d947e84b1ec07f7a3f72b8d6139..bdcddd01d9bdb95034a9a669e6881eed661dee10 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -150,6 +150,23 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { return items, nil } +const renameSession = `-- name: RenameSession :exec +UPDATE sessions +SET + title = ? +WHERE id = ? +` + +type RenameSessionParams struct { + Title string `json:"title"` + ID string `json:"id"` +} + +func (q *Queries) RenameSession(ctx context.Context, arg RenameSessionParams) error { + _, err := q.exec(ctx, q.renameSessionStmt, renameSession, arg.Title, arg.ID) + return err +} + const updateSession = `-- name: UpdateSession :one UPDATE sessions SET @@ -206,7 +223,8 @@ SET title = ?, prompt_tokens = prompt_tokens + ?, completion_tokens = completion_tokens + ?, - cost = cost + ? + cost = cost + ?, + updated_at = strftime('%s', 'now') WHERE id = ? ` diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index 54bc072a0dcd7462d805f30cf832714e1f7d7705..0e170fdeb270041c035c7f2ea24aaa4b571b4387 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -52,10 +52,17 @@ SET title = ?, prompt_tokens = prompt_tokens + ?, completion_tokens = completion_tokens + ?, - cost = cost + ? + cost = cost + ?, + updated_at = strftime('%s', 'now') WHERE id = ?; +-- name: RenameSession :exec +UPDATE sessions +SET + title = ? +WHERE id = ?; + -- name: DeleteSession :exec DELETE FROM sessions WHERE id = ?; diff --git a/internal/session/session.go b/internal/session/session.go index f9279f9f4d45f8562fd868074721d27dca0901f6..834243b62aae6266290147ca0d0270a6069e34b3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -12,6 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" + "github.com/zeebo/xxh3" ) type TodoStatus string @@ -22,6 +23,13 @@ const ( TodoStatusCompleted TodoStatus = "completed" ) +// HashID returns the XXH3 hash of a session ID (UUID) as a hex string. +func HashID(id string) string { + h := xxh3.New() + h.WriteString(id) + return fmt.Sprintf("%x", h.Sum(nil)) +} + type Todo struct { Content string `json:"content"` Status TodoStatus `json:"status"` @@ -61,6 +69,7 @@ type Service interface { List(ctx context.Context) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error + Rename(ctx context.Context, id string, title string) error Delete(ctx context.Context, id string) error // Agent tool session management @@ -198,6 +207,15 @@ func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title stri }) } +// Rename updates only the title of a session without touching updated_at or +// usage fields. +func (s *service) Rename(ctx context.Context, id string, title string) error { + return s.q.RenameSession(ctx, db.RenameSessionParams{ + ID: id, + Title: title, + }) +} + func (s *service) List(ctx context.Context) ([]Session, error) { dbSessions, err := s.q.ListSessions(ctx) if err != nil {