session.go

  1package cmd
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"runtime"
 12	"strings"
 13	"syscall"
 14	"time"
 15
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/db"
 20	"github.com/charmbracelet/crush/internal/message"
 21	"github.com/charmbracelet/crush/internal/session"
 22	"github.com/charmbracelet/crush/internal/ui/chat"
 23	"github.com/charmbracelet/crush/internal/ui/styles"
 24	"github.com/charmbracelet/x/ansi"
 25	"github.com/charmbracelet/x/exp/charmtone"
 26	"github.com/charmbracelet/x/term"
 27	"github.com/spf13/cobra"
 28)
 29
 30var sessionCmd = &cobra.Command{
 31	Use:     "session",
 32	Aliases: []string{"sessions"},
 33	Short:   "Manage sessions",
 34	Long:    "Manage Crush sessions. Agents can use --json for machine-readable output.",
 35}
 36
 37var (
 38	sessionListJSON   bool
 39	sessionShowJSON   bool
 40	sessionLastJSON   bool
 41	sessionDeleteJSON bool
 42	sessionRenameJSON bool
 43)
 44
 45var sessionListCmd = &cobra.Command{
 46	Use:     "list",
 47	Aliases: []string{"ls"},
 48	Short:   "List all sessions",
 49	Long:    "List all sessions. Use --json for machine-readable output.",
 50	RunE:    runSessionList,
 51}
 52
 53var sessionShowCmd = &cobra.Command{
 54	Use:   "show <id>",
 55	Short: "Show session details",
 56	Long:  "Show session details. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 57	Args:  cobra.ExactArgs(1),
 58	RunE:  runSessionShow,
 59}
 60
 61var sessionLastCmd = &cobra.Command{
 62	Use:   "last",
 63	Short: "Show most recent session",
 64	Long:  "Show the last updated session. Use --json for machine-readable output.",
 65	RunE:  runSessionLast,
 66}
 67
 68var sessionDeleteCmd = &cobra.Command{
 69	Use:     "delete <id>",
 70	Aliases: []string{"rm"},
 71	Short:   "Delete a session",
 72	Long:    "Delete a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 73	Args:    cobra.ExactArgs(1),
 74	RunE:    runSessionDelete,
 75}
 76
 77var sessionRenameCmd = &cobra.Command{
 78	Use:   "rename <id> <title>",
 79	Short: "Rename a session",
 80	Long:  "Rename a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 81	Args:  cobra.MinimumNArgs(2),
 82	RunE:  runSessionRename,
 83}
 84
 85func init() {
 86	sessionListCmd.Flags().BoolVar(&sessionListJSON, "json", false, "output in JSON format")
 87	sessionShowCmd.Flags().BoolVar(&sessionShowJSON, "json", false, "output in JSON format")
 88	sessionLastCmd.Flags().BoolVar(&sessionLastJSON, "json", false, "output in JSON format")
 89	sessionDeleteCmd.Flags().BoolVar(&sessionDeleteJSON, "json", false, "output in JSON format")
 90	sessionRenameCmd.Flags().BoolVar(&sessionRenameJSON, "json", false, "output in JSON format")
 91	sessionCmd.AddCommand(sessionListCmd)
 92	sessionCmd.AddCommand(sessionShowCmd)
 93	sessionCmd.AddCommand(sessionLastCmd)
 94	sessionCmd.AddCommand(sessionDeleteCmd)
 95	sessionCmd.AddCommand(sessionRenameCmd)
 96}
 97
 98type sessionServices struct {
 99	sessions session.Service
100	messages message.Service
101}
102
103func sessionSetup(cmd *cobra.Command) (context.Context, *sessionServices, func(), error) {
104	dataDir, _ := cmd.Flags().GetString("data-dir")
105	ctx := cmd.Context()
106
107	if dataDir == "" {
108		cfg, err := config.Init("", "", false)
109		if err != nil {
110			return nil, nil, nil, fmt.Errorf("failed to initialize config: %w", err)
111		}
112		dataDir = cfg.Config().Options.DataDirectory
113	}
114
115	conn, err := db.Connect(ctx, dataDir)
116	if err != nil {
117		return nil, nil, nil, fmt.Errorf("failed to connect to database: %w", err)
118	}
119
120	queries := db.New(conn)
121	svc := &sessionServices{
122		sessions: session.NewService(queries, conn),
123		messages: message.NewService(queries),
124	}
125	return ctx, svc, func() { conn.Close() }, nil
126}
127
128func runSessionList(cmd *cobra.Command, _ []string) error {
129	ctx, svc, cleanup, err := sessionSetup(cmd)
130	if err != nil {
131		return err
132	}
133	defer cleanup()
134
135	list, err := svc.sessions.List(ctx)
136	if err != nil {
137		return fmt.Errorf("failed to list sessions: %w", err)
138	}
139
140	if sessionListJSON {
141		out := cmd.OutOrStdout()
142		output := make([]sessionJSON, len(list))
143		for i, s := range list {
144			output[i] = sessionJSON{
145				ID:       session.HashID(s.ID),
146				UUID:     s.ID,
147				Title:    s.Title,
148				Created:  time.Unix(s.CreatedAt, 0).Format(time.RFC3339),
149				Modified: time.Unix(s.UpdatedAt, 0).Format(time.RFC3339),
150			}
151		}
152		enc := json.NewEncoder(out)
153		enc.SetEscapeHTML(false)
154		return enc.Encode(output)
155	}
156
157	w, cleanup, usingPager := sessionWriter(ctx, len(list))
158	defer cleanup()
159
160	hashStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu)
161	dateStyle := lipgloss.NewStyle().Foreground(charmtone.Damson)
162
163	width := sessionOutputWidth
164	if tw, _, err := term.GetSize(os.Stdout.Fd()); err == nil && tw > 0 {
165		width = tw
166	}
167	// 7 (hash) + 1 (space) + 25 (RFC3339 date) + 1 (space) = 34 chars prefix.
168	titleWidth := width - 34
169	if titleWidth < 10 {
170		titleWidth = 10
171	}
172
173	var writeErr error
174	for _, s := range list {
175		hash := session.HashID(s.ID)[:7]
176		date := time.Unix(s.CreatedAt, 0).Format(time.RFC3339)
177		title := strings.ReplaceAll(s.Title, "\n", " ")
178		title = ansi.Truncate(title, titleWidth, "…")
179		_, writeErr = fmt.Fprintln(w, hashStyle.Render(hash), dateStyle.Render(date), title)
180		if writeErr != nil {
181			break
182		}
183	}
184	if writeErr != nil && usingPager && isBrokenPipe(writeErr) {
185		return nil
186	}
187	return writeErr
188}
189
// sessionJSON is one element of the JSON array printed by "session list --json".
type sessionJSON struct {
	// ID is the hash form of the session identifier.
	ID string `json:"id"`
	// UUID is the session's stored identifier.
	UUID string `json:"uuid"`
	// Title is the session title.
	Title string `json:"title"`
	// Created is the creation time, formatted as RFC3339.
	Created string `json:"created"`
	// Modified is the last-update time, formatted as RFC3339.
	Modified string `json:"modified"`
}
197
// sessionMutationResult is the JSON object printed by the delete and rename
// subcommands when --json is set.
type sessionMutationResult struct {
	// ID is the hash form of the session identifier.
	ID string `json:"id"`
	// UUID is the session's stored identifier.
	UUID string `json:"uuid"`
	// Title is the session title (the new title after a rename).
	Title string `json:"title"`
	// Deleted is set by the delete subcommand; omitted from JSON when false.
	Deleted bool `json:"deleted,omitempty"`
	// Renamed is set by the rename subcommand; omitted from JSON when false.
	Renamed bool `json:"renamed,omitempty"`
}
205
206// resolveSessionID resolves a session ID that can be a UUID, full hash, or hash prefix.
207// Returns an error if the prefix is ambiguous (matches multiple sessions).
208func resolveSessionID(ctx context.Context, svc session.Service, id string) (session.Session, error) {
209	// Try direct UUID lookup first
210	if s, err := svc.Get(ctx, id); err == nil {
211		return s, nil
212	}
213
214	// List all sessions and check for hash matches
215	sessions, err := svc.List(ctx)
216	if err != nil {
217		return session.Session{}, err
218	}
219
220	var matches []session.Session
221	for _, s := range sessions {
222		hash := session.HashID(s.ID)
223		if hash == id || strings.HasPrefix(hash, id) {
224			matches = append(matches, s)
225		}
226	}
227
228	if len(matches) == 0 {
229		return session.Session{}, fmt.Errorf("session not found: %s", id)
230	}
231
232	if len(matches) == 1 {
233		return matches[0], nil
234	}
235
236	// Ambiguous - show matches like Git does
237	var sb strings.Builder
238	fmt.Fprintf(&sb, "session ID '%s' is ambiguous. Matches:\n\n", id)
239	for _, m := range matches {
240		hash := session.HashID(m.ID)
241		created := time.Unix(m.CreatedAt, 0).Format("2006-01-02")
242		// Keep title on one line by replacing newlines with spaces, and truncate.
243		title := strings.ReplaceAll(m.Title, "\n", " ")
244		title = ansi.Truncate(title, 50, "…")
245		fmt.Fprintf(&sb, "  %s... %q (created %s)\n", hash[:12], title, created)
246	}
247	sb.WriteString("\nUse more characters or the full hash")
248	return session.Session{}, errors.New(sb.String())
249}
250
251func runSessionShow(cmd *cobra.Command, args []string) error {
252	ctx, svc, cleanup, err := sessionSetup(cmd)
253	if err != nil {
254		return err
255	}
256	defer cleanup()
257
258	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
259	if err != nil {
260		return err
261	}
262
263	msgs, err := svc.messages.List(ctx, sess.ID)
264	if err != nil {
265		return fmt.Errorf("failed to list messages: %w", err)
266	}
267
268	msgPtrs := messagePtrs(msgs)
269	if sessionShowJSON {
270		return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs)
271	}
272	return outputSessionHuman(ctx, sess, msgPtrs)
273}
274
275func runSessionDelete(cmd *cobra.Command, args []string) error {
276	ctx, svc, cleanup, err := sessionSetup(cmd)
277	if err != nil {
278		return err
279	}
280	defer cleanup()
281
282	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
283	if err != nil {
284		return err
285	}
286
287	if err := svc.sessions.Delete(ctx, sess.ID); err != nil {
288		return fmt.Errorf("failed to delete session: %w", err)
289	}
290
291	out := cmd.OutOrStdout()
292	if sessionDeleteJSON {
293		enc := json.NewEncoder(out)
294		enc.SetEscapeHTML(false)
295		return enc.Encode(sessionMutationResult{
296			ID:      session.HashID(sess.ID),
297			UUID:    sess.ID,
298			Title:   sess.Title,
299			Deleted: true,
300		})
301	}
302
303	fmt.Fprintf(out, "Deleted session %s\n", session.HashID(sess.ID)[:12])
304	return nil
305}
306
307func runSessionRename(cmd *cobra.Command, args []string) error {
308	ctx, svc, cleanup, err := sessionSetup(cmd)
309	if err != nil {
310		return err
311	}
312	defer cleanup()
313
314	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
315	if err != nil {
316		return err
317	}
318
319	newTitle := strings.Join(args[1:], " ")
320	if err := svc.sessions.Rename(ctx, sess.ID, newTitle); err != nil {
321		return fmt.Errorf("failed to rename session: %w", err)
322	}
323
324	out := cmd.OutOrStdout()
325	if sessionRenameJSON {
326		enc := json.NewEncoder(out)
327		enc.SetEscapeHTML(false)
328		return enc.Encode(sessionMutationResult{
329			ID:      session.HashID(sess.ID),
330			UUID:    sess.ID,
331			Title:   newTitle,
332			Renamed: true,
333		})
334	}
335
336	fmt.Fprintf(out, "Renamed session %s to %q\n", session.HashID(sess.ID)[:12], newTitle)
337	return nil
338}
339
340func runSessionLast(cmd *cobra.Command, _ []string) error {
341	ctx, svc, cleanup, err := sessionSetup(cmd)
342	if err != nil {
343		return err
344	}
345	defer cleanup()
346
347	list, err := svc.sessions.List(ctx)
348	if err != nil {
349		return fmt.Errorf("failed to list sessions: %w", err)
350	}
351
352	if len(list) == 0 {
353		return fmt.Errorf("no sessions found")
354	}
355
356	sess := list[0]
357
358	msgs, err := svc.messages.List(ctx, sess.ID)
359	if err != nil {
360		return fmt.Errorf("failed to list messages: %w", err)
361	}
362
363	msgPtrs := messagePtrs(msgs)
364	if sessionLastJSON {
365		return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs)
366	}
367	return outputSessionHuman(ctx, sess, msgPtrs)
368}
369
// sessionOutputWidth is the column count assumed when the terminal width
// cannot be determined.
const sessionOutputWidth = 80

// sessionMaxContentWidth is the upper bound on the width at which message
// items are rendered in human-readable output.
const sessionMaxContentWidth = 120
374
375func messagePtrs(msgs []message.Message) []*message.Message {
376	ptrs := make([]*message.Message, len(msgs))
377	for i := range msgs {
378		ptrs[i] = &msgs[i]
379	}
380	return ptrs
381}
382
383func outputSessionJSON(w io.Writer, sess session.Session, msgs []*message.Message) error {
384	output := sessionShowOutput{
385		Meta: sessionShowMeta{
386			ID:               session.HashID(sess.ID),
387			UUID:             sess.ID,
388			Title:            sess.Title,
389			Created:          time.Unix(sess.CreatedAt, 0).Format(time.RFC3339),
390			Modified:         time.Unix(sess.UpdatedAt, 0).Format(time.RFC3339),
391			Cost:             sess.Cost,
392			PromptTokens:     sess.PromptTokens,
393			CompletionTokens: sess.CompletionTokens,
394			TotalTokens:      sess.PromptTokens + sess.CompletionTokens,
395		},
396		Messages: make([]sessionShowMessage, len(msgs)),
397	}
398
399	for i, msg := range msgs {
400		output.Messages[i] = sessionShowMessage{
401			ID:       msg.ID,
402			Role:     string(msg.Role),
403			Created:  time.Unix(msg.CreatedAt, 0).Format(time.RFC3339),
404			Model:    msg.Model,
405			Provider: msg.Provider,
406			Parts:    convertParts(msg.Parts),
407		}
408	}
409
410	enc := json.NewEncoder(w)
411	enc.SetEscapeHTML(false)
412	return enc.Encode(output)
413}
414
415func outputSessionHuman(ctx context.Context, sess session.Session, msgs []*message.Message) error {
416	sty := styles.DefaultStyles()
417	toolResults := chat.BuildToolResultMap(msgs)
418
419	width := sessionOutputWidth
420	if w, _, err := term.GetSize(os.Stdout.Fd()); err == nil && w > 0 {
421		width = w
422	}
423	contentWidth := min(width, sessionMaxContentWidth)
424
425	keyStyle := lipgloss.NewStyle().Foreground(charmtone.Damson)
426	valStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu)
427
428	hash := session.HashID(sess.ID)[:12]
429	created := time.Unix(sess.CreatedAt, 0).Format("Mon Jan 2 15:04:05 2006 -0700")
430
431	// Render to buffer to determine actual height
432	var buf strings.Builder
433
434	fmt.Fprintln(&buf, keyStyle.Render("ID:    ")+valStyle.Render(hash))
435	fmt.Fprintln(&buf, keyStyle.Render("UUID:  ")+valStyle.Render(sess.ID))
436	fmt.Fprintln(&buf, keyStyle.Render("Title: ")+valStyle.Render(sess.Title))
437	fmt.Fprintln(&buf, keyStyle.Render("Date:  ")+valStyle.Render(created))
438	fmt.Fprintln(&buf)
439
440	first := true
441	for _, msg := range msgs {
442		items := chat.ExtractMessageItems(&sty, msg, toolResults)
443		for _, item := range items {
444			if !first {
445				fmt.Fprintln(&buf)
446			}
447			first = false
448			fmt.Fprintln(&buf, item.Render(contentWidth))
449		}
450	}
451	fmt.Fprintln(&buf)
452
453	contentHeight := strings.Count(buf.String(), "\n")
454	w, cleanup, usingPager := sessionWriter(ctx, contentHeight)
455	defer cleanup()
456
457	_, err := io.WriteString(w, buf.String())
458	// Ignore broken pipe errors when using a pager. This happens when the user
459	// exits the pager early (e.g., pressing 'q' in less), which closes the pipe
460	// and causes subsequent writes to fail. These errors are expected user behavior.
461	if err != nil && usingPager && isBrokenPipe(err) {
462		return nil
463	}
464	return err
465}
466
// isBrokenPipe reports whether err represents a broken pipe. It returns true
// when err wraps syscall.EPIPE, or, failing that, when the error text contains
// "broken pipe". A nil error is never a broken pipe.
func isBrokenPipe(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, syscall.EPIPE):
		return true
	default:
		// Fall back to the message for errors that do not wrap EPIPE.
		return strings.Contains(err.Error(), "broken pipe")
	}
}
478
479// sessionWriter returns a writer, cleanup function, and a bool indicating if a pager is used.
480// When the content fits within the terminal (or stdout is not a TTY), it returns
481// a colorprofile.Writer wrapping stdout. When content exceeds terminal height,
482// it starts a pager process (respecting $PAGER, defaulting to "less -R").
483func sessionWriter(ctx context.Context, contentHeight int) (io.Writer, func(), bool) {
484	// Use NewWriter which automatically detects TTY and strips ANSI when redirected
485	if runtime.GOOS == "windows" || !term.IsTerminal(os.Stdout.Fd()) {
486		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
487	}
488
489	_, termHeight, err := term.GetSize(os.Stdout.Fd())
490	if err != nil || contentHeight <= termHeight {
491		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
492	}
493
494	// Detect color profile from stderr since stdout is piped to the pager.
495	profile := colorprofile.Detect(os.Stderr, os.Environ())
496
497	pager := os.Getenv("PAGER")
498	if pager == "" {
499		pager = "less -R"
500	}
501
502	parts := strings.Fields(pager)
503	cmd := exec.CommandContext(ctx, parts[0], parts[1:]...) //nolint:gosec
504	cmd.Stdout = os.Stdout
505	cmd.Stderr = os.Stderr
506
507	pipe, err := cmd.StdinPipe()
508	if err != nil {
509		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
510	}
511
512	if err := cmd.Start(); err != nil {
513		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
514	}
515
516	return &colorprofile.Writer{
517			Forward: pipe,
518			Profile: profile,
519		}, func() {
520			pipe.Close()
521			_ = cmd.Wait()
522		}, true
523}
524
// sessionShowMeta is the "meta" object of the JSON document printed for a
// single session: identity, timestamps, cost and token usage.
type sessionShowMeta struct {
	// ID is the hash form of the session identifier.
	ID string `json:"id"`
	// UUID is the session's stored identifier.
	UUID string `json:"uuid"`
	// Title is the session title.
	Title string `json:"title"`
	// Created is the creation time, formatted as RFC3339.
	Created string `json:"created"`
	// Modified is the last-update time, formatted as RFC3339.
	Modified string `json:"modified"`
	// Cost is the session's recorded cost.
	Cost float64 `json:"cost"`
	// PromptTokens is the session's prompt token count.
	PromptTokens int64 `json:"prompt_tokens"`
	// CompletionTokens is the session's completion token count.
	CompletionTokens int64 `json:"completion_tokens"`
	// TotalTokens is PromptTokens plus CompletionTokens.
	TotalTokens int64 `json:"total_tokens"`
}
536
537type sessionShowMessage struct {
538	ID       string            `json:"id"`
539	Role     string            `json:"role"`
540	Created  string            `json:"created"`
541	Model    string            `json:"model,omitempty"`
542	Provider string            `json:"provider,omitempty"`
543	Parts    []sessionShowPart `json:"parts"`
544}
545
// sessionShowPart is the flattened JSON form of one message content part.
// Type names the kind of part. Every other field is optional and omitted from
// JSON when it holds its zero value, so each part only shows the fields that
// apply to it.
type sessionShowPart struct {
	// Type is always emitted: "text", "reasoning", "tool_call", "tool_result",
	// "binary", "image_url", "finish" or "unknown".
	Type string `json:"type"`

	// Text content.
	Text string `json:"text,omitempty"`

	// Reasoning.
	Thinking   string `json:"thinking,omitempty"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`

	// Tool call. ToolCallID and Name are also used by tool results.
	ToolCallID string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
	Input      string `json:"input,omitempty"`

	// Tool result. MIMEType is also used by binary parts.
	Content  string `json:"content,omitempty"`
	IsError  bool   `json:"is_error,omitempty"`
	MIMEType string `json:"mime_type,omitempty"`

	// Binary: only the byte length of the data is reported, not the data.
	Size int64 `json:"size,omitempty"`

	// Image URL.
	URL    string `json:"url,omitempty"`
	Detail string `json:"detail,omitempty"`

	// Finish.
	Reason string `json:"reason,omitempty"`
	Time   int64  `json:"time,omitempty"`
}
578
579func convertParts(parts []message.ContentPart) []sessionShowPart {
580	result := make([]sessionShowPart, 0, len(parts))
581	for _, part := range parts {
582		switch p := part.(type) {
583		case message.TextContent:
584			result = append(result, sessionShowPart{
585				Type: "text",
586				Text: p.Text,
587			})
588		case message.ReasoningContent:
589			result = append(result, sessionShowPart{
590				Type:       "reasoning",
591				Thinking:   p.Thinking,
592				StartedAt:  p.StartedAt,
593				FinishedAt: p.FinishedAt,
594			})
595		case message.ToolCall:
596			result = append(result, sessionShowPart{
597				Type:       "tool_call",
598				ToolCallID: p.ID,
599				Name:       p.Name,
600				Input:      p.Input,
601			})
602		case message.ToolResult:
603			result = append(result, sessionShowPart{
604				Type:       "tool_result",
605				ToolCallID: p.ToolCallID,
606				Name:       p.Name,
607				Content:    p.Content,
608				IsError:    p.IsError,
609				MIMEType:   p.MIMEType,
610			})
611		case message.BinaryContent:
612			result = append(result, sessionShowPart{
613				Type:     "binary",
614				MIMEType: p.MIMEType,
615				Size:     int64(len(p.Data)),
616			})
617		case message.ImageURLContent:
618			result = append(result, sessionShowPart{
619				Type:   "image_url",
620				URL:    p.URL,
621				Detail: p.Detail,
622			})
623		case message.Finish:
624			result = append(result, sessionShowPart{
625				Type:   "finish",
626				Reason: string(p.Reason),
627				Time:   p.Time,
628			})
629		default:
630			result = append(result, sessionShowPart{
631				Type: "unknown",
632			})
633		}
634	}
635	return result
636}
637
638type sessionShowOutput struct {
639	Meta     sessionShowMeta      `json:"meta"`
640	Messages []sessionShowMessage `json:"messages"`
641}