session.go

  1package cmd
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"runtime"
 12	"strings"
 13	"syscall"
 14	"time"
 15
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/db"
 20	"github.com/charmbracelet/crush/internal/event"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/session"
 23	"github.com/charmbracelet/crush/internal/ui/chat"
 24	"github.com/charmbracelet/crush/internal/ui/styles"
 25	"github.com/charmbracelet/x/ansi"
 26	"github.com/charmbracelet/x/exp/charmtone"
 27	"github.com/charmbracelet/x/term"
 28	"github.com/spf13/cobra"
 29)
 30
 31var sessionCmd = &cobra.Command{
 32	Use:     "session",
 33	Aliases: []string{"sessions"},
 34	Short:   "Manage sessions",
 35	Long:    "Manage Crush sessions. Agents can use --json for machine-readable output.",
 36}
 37
 38var (
 39	sessionListJSON   bool
 40	sessionShowJSON   bool
 41	sessionLastJSON   bool
 42	sessionDeleteJSON bool
 43	sessionRenameJSON bool
 44)
 45
 46var sessionListCmd = &cobra.Command{
 47	Use:     "list",
 48	Aliases: []string{"ls"},
 49	Short:   "List all sessions",
 50	Long:    "List all sessions. Use --json for machine-readable output.",
 51	RunE:    runSessionList,
 52}
 53
 54var sessionShowCmd = &cobra.Command{
 55	Use:   "show <id>",
 56	Short: "Show session details",
 57	Long:  "Show session details. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 58	Args:  cobra.ExactArgs(1),
 59	RunE:  runSessionShow,
 60}
 61
 62var sessionLastCmd = &cobra.Command{
 63	Use:   "last",
 64	Short: "Show most recent session",
 65	Long:  "Show the last updated session. Use --json for machine-readable output.",
 66	RunE:  runSessionLast,
 67}
 68
 69var sessionDeleteCmd = &cobra.Command{
 70	Use:     "delete <id>",
 71	Aliases: []string{"rm"},
 72	Short:   "Delete a session",
 73	Long:    "Delete a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 74	Args:    cobra.ExactArgs(1),
 75	RunE:    runSessionDelete,
 76}
 77
 78var sessionRenameCmd = &cobra.Command{
 79	Use:   "rename <id> <title>",
 80	Short: "Rename a session",
 81	Long:  "Rename a session by ID. Use --json for machine-readable output. ID can be a UUID, full hash, or hash prefix.",
 82	Args:  cobra.MinimumNArgs(2),
 83	RunE:  runSessionRename,
 84}
 85
 86func init() {
 87	sessionListCmd.Flags().BoolVar(&sessionListJSON, "json", false, "output in JSON format")
 88	sessionShowCmd.Flags().BoolVar(&sessionShowJSON, "json", false, "output in JSON format")
 89	sessionLastCmd.Flags().BoolVar(&sessionLastJSON, "json", false, "output in JSON format")
 90	sessionDeleteCmd.Flags().BoolVar(&sessionDeleteJSON, "json", false, "output in JSON format")
 91	sessionRenameCmd.Flags().BoolVar(&sessionRenameJSON, "json", false, "output in JSON format")
 92	sessionCmd.AddCommand(sessionListCmd)
 93	sessionCmd.AddCommand(sessionShowCmd)
 94	sessionCmd.AddCommand(sessionLastCmd)
 95	sessionCmd.AddCommand(sessionDeleteCmd)
 96	sessionCmd.AddCommand(sessionRenameCmd)
 97}
 98
 99type sessionServices struct {
100	sessions session.Service
101	messages message.Service
102}
103
104func sessionSetup(cmd *cobra.Command) (context.Context, *sessionServices, func(), error) {
105	dataDir, _ := cmd.Flags().GetString("data-dir")
106	ctx := cmd.Context()
107
108	if dataDir == "" {
109		cfg, err := config.Init("", "", false)
110		if err != nil {
111			return nil, nil, nil, fmt.Errorf("failed to initialize config: %w", err)
112		}
113		dataDir = cfg.Config().Options.DataDirectory
114	}
115
116	conn, err := db.Connect(ctx, dataDir)
117	if err != nil {
118		return nil, nil, nil, fmt.Errorf("failed to connect to database: %w", err)
119	}
120
121	queries := db.New(conn)
122	svc := &sessionServices{
123		sessions: session.NewService(queries, conn),
124		messages: message.NewService(queries),
125	}
126	return ctx, svc, func() { conn.Close() }, nil
127}
128
129func runSessionList(cmd *cobra.Command, _ []string) error {
130	event.SetNonInteractive(true)
131	event.SessionListed(sessionListJSON)
132
133	ctx, svc, cleanup, err := sessionSetup(cmd)
134	if err != nil {
135		return err
136	}
137	defer cleanup()
138
139	list, err := svc.sessions.List(ctx)
140	if err != nil {
141		return fmt.Errorf("failed to list sessions: %w", err)
142	}
143
144	if sessionListJSON {
145		out := cmd.OutOrStdout()
146		output := make([]sessionJSON, len(list))
147		for i, s := range list {
148			output[i] = sessionJSON{
149				ID:       session.HashID(s.ID),
150				UUID:     s.ID,
151				Title:    s.Title,
152				Created:  time.Unix(s.CreatedAt, 0).Format(time.RFC3339),
153				Modified: time.Unix(s.UpdatedAt, 0).Format(time.RFC3339),
154			}
155		}
156		enc := json.NewEncoder(out)
157		enc.SetEscapeHTML(false)
158		return enc.Encode(output)
159	}
160
161	w, cleanup, usingPager := sessionWriter(ctx, len(list))
162	defer cleanup()
163
164	hashStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu)
165	dateStyle := lipgloss.NewStyle().Foreground(charmtone.Damson)
166
167	width := sessionOutputWidth
168	if tw, _, err := term.GetSize(os.Stdout.Fd()); err == nil && tw > 0 {
169		width = tw
170	}
171	// 7 (hash) + 1 (space) + 25 (RFC3339 date) + 1 (space) = 34 chars prefix.
172	titleWidth := width - 34
173	if titleWidth < 10 {
174		titleWidth = 10
175	}
176
177	var writeErr error
178	for _, s := range list {
179		hash := session.HashID(s.ID)[:7]
180		date := time.Unix(s.CreatedAt, 0).Format(time.RFC3339)
181		title := strings.ReplaceAll(s.Title, "\n", " ")
182		title = ansi.Truncate(title, titleWidth, "…")
183		_, writeErr = fmt.Fprintln(w, hashStyle.Render(hash), dateStyle.Render(date), title)
184		if writeErr != nil {
185			break
186		}
187	}
188	if writeErr != nil && usingPager && isBrokenPipe(writeErr) {
189		return nil
190	}
191	return writeErr
192}
193
// sessionJSON is one element of the array printed by `session list --json`.
type sessionJSON struct {
	ID       string `json:"id"`       // session.HashID of the UUID
	UUID     string `json:"uuid"`     // raw session ID
	Title    string `json:"title"`    // title as stored (not flattened or truncated)
	Created  string `json:"created"`  // creation time, RFC 3339
	Modified string `json:"modified"` // last update time, RFC 3339
}
201
// sessionMutationResult is the --json output of `session delete` and
// `session rename`. Delete sets Deleted and reports the session's title at the
// time of deletion. Rename sets Renamed and reports the new title. The flag
// that is not set is omitted from the JSON.
type sessionMutationResult struct {
	ID      string `json:"id"`   // session.HashID of the UUID
	UUID    string `json:"uuid"` // raw session ID
	Title   string `json:"title"`
	Deleted bool   `json:"deleted,omitempty"`
	Renamed bool   `json:"renamed,omitempty"`
}
209
210// resolveSessionID resolves a session ID that can be a UUID, full hash, or hash prefix.
211// Returns an error if the prefix is ambiguous (matches multiple sessions).
212func resolveSessionID(ctx context.Context, svc session.Service, id string) (session.Session, error) {
213	// Try direct UUID lookup first
214	if s, err := svc.Get(ctx, id); err == nil {
215		return s, nil
216	}
217
218	// List all sessions and check for hash matches
219	sessions, err := svc.List(ctx)
220	if err != nil {
221		return session.Session{}, err
222	}
223
224	var matches []session.Session
225	for _, s := range sessions {
226		hash := session.HashID(s.ID)
227		if hash == id || strings.HasPrefix(hash, id) {
228			matches = append(matches, s)
229		}
230	}
231
232	if len(matches) == 0 {
233		return session.Session{}, fmt.Errorf("session not found: %s", id)
234	}
235
236	if len(matches) == 1 {
237		return matches[0], nil
238	}
239
240	// Ambiguous - show matches like Git does
241	var sb strings.Builder
242	fmt.Fprintf(&sb, "session ID '%s' is ambiguous. Matches:\n\n", id)
243	for _, m := range matches {
244		hash := session.HashID(m.ID)
245		created := time.Unix(m.CreatedAt, 0).Format("2006-01-02")
246		// Keep title on one line by replacing newlines with spaces, and truncate.
247		title := strings.ReplaceAll(m.Title, "\n", " ")
248		title = ansi.Truncate(title, 50, "…")
249		fmt.Fprintf(&sb, "  %s... %q (created %s)\n", hash[:12], title, created)
250	}
251	sb.WriteString("\nUse more characters or the full hash")
252	return session.Session{}, errors.New(sb.String())
253}
254
255func runSessionShow(cmd *cobra.Command, args []string) error {
256	event.SetNonInteractive(true)
257	event.SessionShown(sessionShowJSON)
258
259	ctx, svc, cleanup, err := sessionSetup(cmd)
260	if err != nil {
261		return err
262	}
263	defer cleanup()
264
265	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
266	if err != nil {
267		return err
268	}
269
270	msgs, err := svc.messages.List(ctx, sess.ID)
271	if err != nil {
272		return fmt.Errorf("failed to list messages: %w", err)
273	}
274
275	msgPtrs := messagePtrs(msgs)
276	if sessionShowJSON {
277		return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs)
278	}
279	return outputSessionHuman(ctx, sess, msgPtrs)
280}
281
282func runSessionDelete(cmd *cobra.Command, args []string) error {
283	event.SetNonInteractive(true)
284	event.SessionDeletedCommand(sessionDeleteJSON)
285
286	ctx, svc, cleanup, err := sessionSetup(cmd)
287	if err != nil {
288		return err
289	}
290	defer cleanup()
291
292	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
293	if err != nil {
294		return err
295	}
296
297	if err := svc.sessions.Delete(ctx, sess.ID); err != nil {
298		return fmt.Errorf("failed to delete session: %w", err)
299	}
300
301	out := cmd.OutOrStdout()
302	if sessionDeleteJSON {
303		enc := json.NewEncoder(out)
304		enc.SetEscapeHTML(false)
305		return enc.Encode(sessionMutationResult{
306			ID:      session.HashID(sess.ID),
307			UUID:    sess.ID,
308			Title:   sess.Title,
309			Deleted: true,
310		})
311	}
312
313	fmt.Fprintf(out, "Deleted session %s\n", session.HashID(sess.ID)[:12])
314	return nil
315}
316
317func runSessionRename(cmd *cobra.Command, args []string) error {
318	event.SetNonInteractive(true)
319	event.SessionRenamed(sessionRenameJSON)
320
321	ctx, svc, cleanup, err := sessionSetup(cmd)
322	if err != nil {
323		return err
324	}
325	defer cleanup()
326
327	sess, err := resolveSessionID(ctx, svc.sessions, args[0])
328	if err != nil {
329		return err
330	}
331
332	newTitle := strings.Join(args[1:], " ")
333	if err := svc.sessions.Rename(ctx, sess.ID, newTitle); err != nil {
334		return fmt.Errorf("failed to rename session: %w", err)
335	}
336
337	out := cmd.OutOrStdout()
338	if sessionRenameJSON {
339		enc := json.NewEncoder(out)
340		enc.SetEscapeHTML(false)
341		return enc.Encode(sessionMutationResult{
342			ID:      session.HashID(sess.ID),
343			UUID:    sess.ID,
344			Title:   newTitle,
345			Renamed: true,
346		})
347	}
348
349	fmt.Fprintf(out, "Renamed session %s to %q\n", session.HashID(sess.ID)[:12], newTitle)
350	return nil
351}
352
353func runSessionLast(cmd *cobra.Command, _ []string) error {
354	event.SetNonInteractive(true)
355	event.SessionLastShown(sessionLastJSON)
356
357	ctx, svc, cleanup, err := sessionSetup(cmd)
358	if err != nil {
359		return err
360	}
361	defer cleanup()
362
363	list, err := svc.sessions.List(ctx)
364	if err != nil {
365		return fmt.Errorf("failed to list sessions: %w", err)
366	}
367
368	if len(list) == 0 {
369		return fmt.Errorf("no sessions found")
370	}
371
372	sess := list[0]
373
374	msgs, err := svc.messages.List(ctx, sess.ID)
375	if err != nil {
376		return fmt.Errorf("failed to list messages: %w", err)
377	}
378
379	msgPtrs := messagePtrs(msgs)
380	if sessionLastJSON {
381		return outputSessionJSON(cmd.OutOrStdout(), sess, msgPtrs)
382	}
383	return outputSessionHuman(ctx, sess, msgPtrs)
384}
385
// Layout limits for human-readable session output.
const (
	// sessionOutputWidth is the fallback width used when the terminal size
	// cannot be determined.
	sessionOutputWidth = 80
	// sessionMaxContentWidth caps the width at which message items are
	// rendered, even on wider terminals.
	sessionMaxContentWidth = 120
)
390
391func messagePtrs(msgs []message.Message) []*message.Message {
392	ptrs := make([]*message.Message, len(msgs))
393	for i := range msgs {
394		ptrs[i] = &msgs[i]
395	}
396	return ptrs
397}
398
399func outputSessionJSON(w io.Writer, sess session.Session, msgs []*message.Message) error {
400	output := sessionShowOutput{
401		Meta: sessionShowMeta{
402			ID:               session.HashID(sess.ID),
403			UUID:             sess.ID,
404			Title:            sess.Title,
405			Created:          time.Unix(sess.CreatedAt, 0).Format(time.RFC3339),
406			Modified:         time.Unix(sess.UpdatedAt, 0).Format(time.RFC3339),
407			Cost:             sess.Cost,
408			PromptTokens:     sess.PromptTokens,
409			CompletionTokens: sess.CompletionTokens,
410			TotalTokens:      sess.PromptTokens + sess.CompletionTokens,
411		},
412		Messages: make([]sessionShowMessage, len(msgs)),
413	}
414
415	for i, msg := range msgs {
416		output.Messages[i] = sessionShowMessage{
417			ID:       msg.ID,
418			Role:     string(msg.Role),
419			Created:  time.Unix(msg.CreatedAt, 0).Format(time.RFC3339),
420			Model:    msg.Model,
421			Provider: msg.Provider,
422			Parts:    convertParts(msg.Parts),
423		}
424	}
425
426	enc := json.NewEncoder(w)
427	enc.SetEscapeHTML(false)
428	return enc.Encode(output)
429}
430
431func outputSessionHuman(ctx context.Context, sess session.Session, msgs []*message.Message) error {
432	sty := styles.DefaultStyles()
433	toolResults := chat.BuildToolResultMap(msgs)
434
435	width := sessionOutputWidth
436	if w, _, err := term.GetSize(os.Stdout.Fd()); err == nil && w > 0 {
437		width = w
438	}
439	contentWidth := min(width, sessionMaxContentWidth)
440
441	keyStyle := lipgloss.NewStyle().Foreground(charmtone.Damson)
442	valStyle := lipgloss.NewStyle().Foreground(charmtone.Malibu)
443
444	hash := session.HashID(sess.ID)[:12]
445	created := time.Unix(sess.CreatedAt, 0).Format("Mon Jan 2 15:04:05 2006 -0700")
446
447	// Render to buffer to determine actual height
448	var buf strings.Builder
449
450	fmt.Fprintln(&buf, keyStyle.Render("ID:    ")+valStyle.Render(hash))
451	fmt.Fprintln(&buf, keyStyle.Render("UUID:  ")+valStyle.Render(sess.ID))
452	fmt.Fprintln(&buf, keyStyle.Render("Title: ")+valStyle.Render(sess.Title))
453	fmt.Fprintln(&buf, keyStyle.Render("Date:  ")+valStyle.Render(created))
454	fmt.Fprintln(&buf)
455
456	first := true
457	for _, msg := range msgs {
458		items := chat.ExtractMessageItems(&sty, msg, toolResults)
459		for _, item := range items {
460			if !first {
461				fmt.Fprintln(&buf)
462			}
463			first = false
464			fmt.Fprintln(&buf, item.Render(contentWidth))
465		}
466	}
467	fmt.Fprintln(&buf)
468
469	contentHeight := strings.Count(buf.String(), "\n")
470	w, cleanup, usingPager := sessionWriter(ctx, contentHeight)
471	defer cleanup()
472
473	_, err := io.WriteString(w, buf.String())
474	// Ignore broken pipe errors when using a pager. This happens when the user
475	// exits the pager early (e.g., pressing 'q' in less), which closes the pipe
476	// and causes subsequent writes to fail. These errors are expected user behavior.
477	if err != nil && usingPager && isBrokenPipe(err) {
478		return nil
479	}
480	return err
481}
482
// isBrokenPipe reports whether err is a broken-pipe write error. An error
// qualifies if it wraps syscall.EPIPE or its message contains "broken pipe".
// A nil error is never a broken pipe.
func isBrokenPipe(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, syscall.EPIPE):
		return true
	default:
		// Fallback for errors that describe EPIPE without wrapping it.
		return strings.Contains(err.Error(), "broken pipe")
	}
}
494
495// sessionWriter returns a writer, cleanup function, and a bool indicating if a pager is used.
496// When the content fits within the terminal (or stdout is not a TTY), it returns
497// a colorprofile.Writer wrapping stdout. When content exceeds terminal height,
498// it starts a pager process (respecting $PAGER, defaulting to "less -R").
499func sessionWriter(ctx context.Context, contentHeight int) (io.Writer, func(), bool) {
500	// Use NewWriter which automatically detects TTY and strips ANSI when redirected
501	if runtime.GOOS == "windows" || !term.IsTerminal(os.Stdout.Fd()) {
502		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
503	}
504
505	_, termHeight, err := term.GetSize(os.Stdout.Fd())
506	if err != nil || contentHeight <= termHeight {
507		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
508	}
509
510	// Detect color profile from stderr since stdout is piped to the pager.
511	profile := colorprofile.Detect(os.Stderr, os.Environ())
512
513	pager := os.Getenv("PAGER")
514	if pager == "" {
515		pager = "less -R"
516	}
517
518	parts := strings.Fields(pager)
519	cmd := exec.CommandContext(ctx, parts[0], parts[1:]...) //nolint:gosec
520	cmd.Stdout = os.Stdout
521	cmd.Stderr = os.Stderr
522
523	pipe, err := cmd.StdinPipe()
524	if err != nil {
525		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
526	}
527
528	if err := cmd.Start(); err != nil {
529		return colorprofile.NewWriter(os.Stdout, os.Environ()), func() {}, false
530	}
531
532	return &colorprofile.Writer{
533			Forward: pipe,
534			Profile: profile,
535		}, func() {
536			pipe.Close()
537			_ = cmd.Wait()
538		}, true
539}
540
// sessionShowMeta is the "meta" object in the `session show --json` and
// `session last --json` output.
type sessionShowMeta struct {
	ID               string  `json:"id"`                // session.HashID of the UUID
	UUID             string  `json:"uuid"`              // raw session ID
	Title            string  `json:"title"`
	Created          string  `json:"created"`           // creation time, RFC 3339
	Modified         string  `json:"modified"`          // last update time, RFC 3339
	Cost             float64 `json:"cost"`              // copied from session.Session.Cost; units not defined here
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	TotalTokens      int64   `json:"total_tokens"`      // PromptTokens + CompletionTokens
}
552
// sessionShowMessage is one entry of the "messages" array in the
// `session show --json` output.
type sessionShowMessage struct {
	ID       string            `json:"id"`
	Role     string            `json:"role"`               // message.Role as a string
	Created  string            `json:"created"`            // creation time, RFC 3339
	Model    string            `json:"model,omitempty"`    // omitted when empty
	Provider string            `json:"provider,omitempty"` // omitted when empty
	Parts    []sessionShowPart `json:"parts"`              // built by convertParts; [] when none
}
561
// sessionShowPart is the JSON form of one message.ContentPart. Type is one of
// "text", "reasoning", "tool_call", "tool_result", "binary", "image_url",
// "finish", or "unknown". Only the fields relevant to that type are set; the
// others are omitted because of omitempty.
type sessionShowPart struct {
	Type string `json:"type"`

	// Text content ("text")
	Text string `json:"text,omitempty"`

	// Reasoning ("reasoning"); timestamps copied as-is from
	// message.ReasoningContent (unit not defined here)
	Thinking   string `json:"thinking,omitempty"`
	StartedAt  int64  `json:"started_at,omitempty"`
	FinishedAt int64  `json:"finished_at,omitempty"`

	// Tool call ("tool_call"); ToolCallID and Name are also set for "tool_result"
	ToolCallID string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
	Input      string `json:"input,omitempty"`

	// Tool result ("tool_result"); MIMEType is also set for "binary"
	Content  string `json:"content,omitempty"`
	IsError  bool   `json:"is_error,omitempty"`
	MIMEType string `json:"mime_type,omitempty"`

	// Binary ("binary"): only len(Data) is reported, never the data itself
	Size int64 `json:"size,omitempty"`

	// Image URL ("image_url")
	URL    string `json:"url,omitempty"`
	Detail string `json:"detail,omitempty"`

	// Finish ("finish"); Time copied as-is from message.Finish
	Reason string `json:"reason,omitempty"`
	Time   int64  `json:"time,omitempty"`
}
594
595func convertParts(parts []message.ContentPart) []sessionShowPart {
596	result := make([]sessionShowPart, 0, len(parts))
597	for _, part := range parts {
598		switch p := part.(type) {
599		case message.TextContent:
600			result = append(result, sessionShowPart{
601				Type: "text",
602				Text: p.Text,
603			})
604		case message.ReasoningContent:
605			result = append(result, sessionShowPart{
606				Type:       "reasoning",
607				Thinking:   p.Thinking,
608				StartedAt:  p.StartedAt,
609				FinishedAt: p.FinishedAt,
610			})
611		case message.ToolCall:
612			result = append(result, sessionShowPart{
613				Type:       "tool_call",
614				ToolCallID: p.ID,
615				Name:       p.Name,
616				Input:      p.Input,
617			})
618		case message.ToolResult:
619			result = append(result, sessionShowPart{
620				Type:       "tool_result",
621				ToolCallID: p.ToolCallID,
622				Name:       p.Name,
623				Content:    p.Content,
624				IsError:    p.IsError,
625				MIMEType:   p.MIMEType,
626			})
627		case message.BinaryContent:
628			result = append(result, sessionShowPart{
629				Type:     "binary",
630				MIMEType: p.MIMEType,
631				Size:     int64(len(p.Data)),
632			})
633		case message.ImageURLContent:
634			result = append(result, sessionShowPart{
635				Type:   "image_url",
636				URL:    p.URL,
637				Detail: p.Detail,
638			})
639		case message.Finish:
640			result = append(result, sessionShowPart{
641				Type:   "finish",
642				Reason: string(p.Reason),
643				Time:   p.Time,
644			})
645		default:
646			result = append(result, sessionShowPart{
647				Type: "unknown",
648			})
649		}
650	}
651	return result
652}
653
// sessionShowOutput is the top-level JSON document written by outputSessionJSON
// for `session show --json` and `session last --json`.
type sessionShowOutput struct {
	Meta     sessionShowMeta      `json:"meta"`
	Messages []sessionShowMessage `json:"messages"` // in the order returned by the message service
}