feat: be able to continue non-interactive sessions (#2401)

Daniil Sivak and Andrey Nering created

Co-authored-by: Andrey Nering <andreynering@users.noreply.github.com>

Change summary

Taskfile.yaml                        |   5 
internal/app/app.go                  |  42 ++++++
internal/app/resolve_session_test.go | 174 ++++++++++++++++++++++++++++++
internal/cmd/run.go                  |  39 +++++
internal/db/db.go                    |  10 +
internal/db/querier.go               |   1 
internal/db/sessions.sql.go          |  26 ++++
internal/db/sql/sessions.sql         |   6 +
internal/event/event.go              |  16 ++
internal/session/session.go          |   9 +
10 files changed, 317 insertions(+), 11 deletions(-)

Detailed changes

Taskfile.yaml 🔗

@@ -184,3 +184,8 @@ tasks:
       - go get charm.land/fantasy
       - go get charm.land/catwalk
       - go mod tidy
+
+  sqlc:
+    desc: Generate code using SQLC
+    cmds:
+      - sqlc generate

internal/app/app.go 🔗

@@ -157,9 +157,40 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] {
 	return app.agentNotifications
 }
 
+// resolveSession resolves which session to use for a non-interactive run
+// If continueSessionID is set, it looks up that session by ID
+// If useLast is set, it returns the most recently updated top-level session
+// Otherwise, it creates a new session
+func (app *App) resolveSession(ctx context.Context, continueSessionID string, useLast bool) (session.Session, error) {
+	switch {
+	case continueSessionID != "":
+		if app.Sessions.IsAgentToolSession(continueSessionID) {
+			return session.Session{}, fmt.Errorf("cannot continue an agent tool session: %s", continueSessionID)
+		}
+		sess, err := app.Sessions.Get(ctx, continueSessionID)
+		if err != nil {
+			return session.Session{}, fmt.Errorf("session not found: %s", continueSessionID)
+		}
+		if sess.ParentSessionID != "" {
+			return session.Session{}, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
+		}
+		return sess, nil
+
+	case useLast:
+		sess, err := app.Sessions.GetLast(ctx)
+		if err != nil {
+			return session.Session{}, fmt.Errorf("no sessions found to continue")
+		}
+		return sess, nil
+
+	default:
+		return app.Sessions.Create(ctx, agent.DefaultSessionName)
+	}
+}
+
 // RunNonInteractive runs the application in non-interactive mode with the
 // given prompt, printing to stdout.
-func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool) error {
+func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool, continueSessionID string, useLast bool) error {
 	slog.Info("Running in non-interactive mode")
 
 	ctx, cancel := context.WithCancel(ctx)
@@ -227,11 +258,16 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt,
 
 	defer stopSpinner()
 
-	sess, err := app.Sessions.Create(ctx, agent.DefaultSessionName)
+	sess, err := app.resolveSession(ctx, continueSessionID, useLast)
 	if err != nil {
 		return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
 	}
-	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+
+	if continueSessionID != "" || useLast {
+		slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
+	} else {
+		slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+	}
 
 	// Automatically approve all permission requests for this non-interactive
 	// session.

internal/app/resolve_session_test.go 🔗

@@ -0,0 +1,174 @@
+package app
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/pubsub"
+	"github.com/charmbracelet/crush/internal/session"
+	"github.com/stretchr/testify/require"
+)
+
+// mockSessionService is a minimal mock of session.Service for testing resolveSession.
+type mockSessionService struct {
+	sessions []session.Session
+	created  []session.Session
+}
+
+func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
+	return make(chan pubsub.Event[session.Session])
+}
+
+func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
+	s := session.Session{ID: "new-session-id", Title: title}
+	m.created = append(m.created, s)
+	return s, nil
+}
+
+func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
+	return session.Session{}, nil
+}
+
+func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
+	return session.Session{}, nil
+}
+
+func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
+	for _, s := range m.sessions {
+		if s.ID == id {
+			return s, nil
+		}
+	}
+	return session.Session{}, sql.ErrNoRows
+}
+
+func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
+	if len(m.sessions) > 0 {
+		return m.sessions[0], nil
+	}
+	return session.Session{}, sql.ErrNoRows
+}
+
+func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
+	return m.sessions, nil
+}
+
+func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
+	return s, nil
+}
+
+func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
+	return nil
+}
+
+func (m *mockSessionService) Rename(context.Context, string, string) error {
+	return nil
+}
+
+func (m *mockSessionService) Delete(context.Context, string) error {
+	return nil
+}
+
+func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
+	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
+}
+
+func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
+	parts := strings.Split(sessionID, "$$")
+	if len(parts) != 2 {
+		return "", "", false
+	}
+	return parts[0], parts[1], true
+}
+
+func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
+	_, _, ok := m.ParseAgentToolSessionID(sessionID)
+	return ok
+}
+
+func newTestApp(sessions session.Service) *App {
+	return &App{Sessions: sessions}
+}
+
+func TestResolveSession_NewSession(t *testing.T) {
+	mock := &mockSessionService{}
+	app := newTestApp(mock)
+
+	sess, err := app.resolveSession(t.Context(), "", false)
+	require.NoError(t, err)
+	require.Equal(t, "new-session-id", sess.ID)
+	require.Len(t, mock.created, 1)
+}
+
+func TestResolveSession_ContinueByID(t *testing.T) {
+	mock := &mockSessionService{
+		sessions: []session.Session{
+			{ID: "existing-id", Title: "Old session"},
+		},
+	}
+	app := newTestApp(mock)
+
+	sess, err := app.resolveSession(t.Context(), "existing-id", false)
+	require.NoError(t, err)
+	require.Equal(t, "existing-id", sess.ID)
+	require.Equal(t, "Old session", sess.Title)
+	require.Empty(t, mock.created)
+}
+
+func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
+	mock := &mockSessionService{}
+	app := newTestApp(mock)
+
+	_, err := app.resolveSession(t.Context(), "nonexistent", false)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "session not found")
+}
+
+func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
+	mock := &mockSessionService{
+		sessions: []session.Session{
+			{ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
+		},
+	}
+	app := newTestApp(mock)
+
+	_, err := app.resolveSession(t.Context(), "child-id", false)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "cannot continue a child session")
+}
+
+func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
+	mock := &mockSessionService{}
+	app := newTestApp(mock)
+
+	_, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "cannot continue an agent tool session")
+}
+
+func TestResolveSession_Last(t *testing.T) {
+	mock := &mockSessionService{
+		sessions: []session.Session{
+			{ID: "most-recent", Title: "Latest session"},
+			{ID: "older", Title: "Older session"},
+		},
+	}
+	app := newTestApp(mock)
+
+	sess, err := app.resolveSession(t.Context(), "", true)
+	require.NoError(t, err)
+	require.Equal(t, "most-recent", sess.ID)
+	require.Empty(t, mock.created)
+}
+
+func TestResolveSession_Last_NoSessions(t *testing.T) {
+	mock := &mockSessionService{}
+	app := newTestApp(mock)
+
+	_, err := app.resolveSession(t.Context(), "", true)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "no sessions found")
+}

internal/cmd/run.go 🔗

@@ -36,12 +36,23 @@ crush run --quiet "Generate a README for this project"
 
 # Run in verbose mode (show logs)
 crush run --verbose "Generate a README for this project"
+
+# Continue a previous session
+crush run --session {session-id} "Follow up on your last response"
+
+# Continue the most recent session
+crush run --continue "Follow up on your last response"
+
   `,
 	RunE: func(cmd *cobra.Command, args []string) error {
-		quiet, _ := cmd.Flags().GetBool("quiet")
-		verbose, _ := cmd.Flags().GetBool("verbose")
-		largeModel, _ := cmd.Flags().GetString("model")
-		smallModel, _ := cmd.Flags().GetString("small-model")
+		var (
+			quiet, _      = cmd.Flags().GetBool("quiet")
+			verbose, _    = cmd.Flags().GetBool("verbose")
+			largeModel, _ = cmd.Flags().GetString("model")
+			smallModel, _ = cmd.Flags().GetString("small-model")
+			sessionID, _  = cmd.Flags().GetString("session")
+			useLast, _    = cmd.Flags().GetBool("continue")
+		)
 
 		// Cancel on SIGINT or SIGTERM.
 		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
@@ -53,6 +64,14 @@ crush run --verbose "Generate a README for this project"
 		}
 		defer app.Shutdown()
 
+		if sessionID != "" {
+			sess, err := resolveSessionID(ctx, app.Sessions, sessionID)
+			if err != nil {
+				return err
+			}
+			sessionID = sess.ID
+		}
+
 		if !app.Config().IsConfigured() {
 			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
 		}
@@ -76,7 +95,14 @@ crush run --verbose "Generate a README for this project"
 		event.SetNonInteractive(true)
 		event.AppInitialized()
 
-		return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose)
+		switch {
+		case sessionID != "":
+			event.SetContinueBySessionID(true)
+		case useLast:
+			event.SetContinueLastSession(true)
+		}
+
+		return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
 	},
 }
 
@@ -85,4 +111,7 @@ func init() {
 	runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
 	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
 	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
+	runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
+	runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
+	runCmd.MarkFlagsMutuallyExclusive("session", "continue")
 }

internal/db/db.go 🔗

@@ -63,6 +63,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
 	if q.getHourDayHeatmapStmt, err = db.PrepareContext(ctx, getHourDayHeatmap); err != nil {
 		return nil, fmt.Errorf("error preparing query GetHourDayHeatmap: %w", err)
 	}
+	if q.getLastSessionStmt, err = db.PrepareContext(ctx, getLastSession); err != nil {
+		return nil, fmt.Errorf("error preparing query GetLastSession: %w", err)
+	}
 	if q.getMessageStmt, err = db.PrepareContext(ctx, getMessage); err != nil {
 		return nil, fmt.Errorf("error preparing query GetMessage: %w", err)
 	}
@@ -202,6 +205,11 @@ func (q *Queries) Close() error {
 			err = fmt.Errorf("error closing getHourDayHeatmapStmt: %w", cerr)
 		}
 	}
+	if q.getLastSessionStmt != nil {
+		if cerr := q.getLastSessionStmt.Close(); cerr != nil {
+			err = fmt.Errorf("error closing getLastSessionStmt: %w", cerr)
+		}
+	}
 	if q.getMessageStmt != nil {
 		if cerr := q.getMessageStmt.Close(); cerr != nil {
 			err = fmt.Errorf("error closing getMessageStmt: %w", cerr)
@@ -369,6 +377,7 @@ type Queries struct {
 	getFileByPathAndSessionStmt    *sql.Stmt
 	getFileReadStmt                *sql.Stmt
 	getHourDayHeatmapStmt          *sql.Stmt
+	getLastSessionStmt             *sql.Stmt
 	getMessageStmt                 *sql.Stmt
 	getRecentActivityStmt          *sql.Stmt
 	getSessionByIDStmt             *sql.Stmt
@@ -411,6 +420,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
 		getFileByPathAndSessionStmt:    q.getFileByPathAndSessionStmt,
 		getFileReadStmt:                q.getFileReadStmt,
 		getHourDayHeatmapStmt:          q.getHourDayHeatmapStmt,
+		getLastSessionStmt:             q.getLastSessionStmt,
 		getMessageStmt:                 q.getMessageStmt,
 		getRecentActivityStmt:          q.getRecentActivityStmt,
 		getSessionByIDStmt:             q.getSessionByIDStmt,

internal/db/querier.go 🔗

@@ -22,6 +22,7 @@ type Querier interface {
 	GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
 	GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error)
 	GetHourDayHeatmap(ctx context.Context) ([]GetHourDayHeatmapRow, error)
+	GetLastSession(ctx context.Context) (Session, error)
 	GetMessage(ctx context.Context, id string) (Message, error)
 	GetRecentActivity(ctx context.Context) ([]GetRecentActivityRow, error)
 	GetSessionByID(ctx context.Context, id string) (Session, error)

internal/db/sessions.sql.go 🔗

@@ -83,6 +83,32 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
 	return err
 }
 
+const getLastSession = `-- name: GetLastSession :one
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+FROM sessions
+ORDER BY updated_at DESC
+LIMIT 1
+`
+
+func (q *Queries) GetLastSession(ctx context.Context) (Session, error) {
+	row := q.queryRow(ctx, q.getLastSessionStmt, getLastSession)
+	var i Session
+	err := row.Scan(
+		&i.ID,
+		&i.ParentSessionID,
+		&i.Title,
+		&i.MessageCount,
+		&i.PromptTokens,
+		&i.CompletionTokens,
+		&i.Cost,
+		&i.UpdatedAt,
+		&i.CreatedAt,
+		&i.SummaryMessageID,
+		&i.Todos,
+	)
+	return i, err
+}
+
 const getSessionByID = `-- name: GetSessionByID :one
 SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
 FROM sessions

internal/db/sql/sessions.sql 🔗

@@ -28,6 +28,12 @@ SELECT *
 FROM sessions
 WHERE id = ? LIMIT 1;
 
+-- name: GetLastSession :one
+SELECT *
+FROM sessions
+ORDER BY updated_at DESC
+LIMIT 1;
+
 -- name: ListSessions :many
 SELECT *
 FROM sessions

internal/event/event.go 🔗

@@ -17,7 +17,9 @@ const (
 	endpoint = "https://data.charm.land"
 	key      = "phc_4zt4VgDWLqbYnJYEwLRxFoaTL2noNrQij0C6E8k3I0V"
 
-	nonInteractiveEventName = "NonInteractive"
+	nonInteractiveAttrName      = "NonInteractive"
+	continueSessionByIDAttrName = "ContinueSessionByID"
+	continueLastSessionAttrName = "ContinueLastSession"
 )
 
 var (
@@ -30,11 +32,19 @@ var (
 			Set("SHELL", filepath.Base(os.Getenv("SHELL"))).
 			Set("Version", version.Version).
 			Set("GoVersion", runtime.Version()).
-			Set(nonInteractiveEventName, false)
+			Set(nonInteractiveAttrName, false)
 )
 
 func SetNonInteractive(nonInteractive bool) {
-	baseProps = baseProps.Set(nonInteractiveEventName, nonInteractive)
+	baseProps = baseProps.Set(nonInteractiveAttrName, nonInteractive)
+}
+
+func SetContinueBySessionID(continueBySessionID bool) {
+	baseProps = baseProps.Set(continueSessionByIDAttrName, continueBySessionID)
+}
+
+func SetContinueLastSession(continueLastSession bool) {
+	baseProps = baseProps.Set(continueLastSessionAttrName, continueLastSession)
 }
 
 func Init() {

internal/session/session.go 🔗

@@ -66,6 +66,7 @@ type Service interface {
 	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 	Get(ctx context.Context, id string) (Session, error)
+	GetLast(ctx context.Context) (Session, error)
 	List(ctx context.Context) ([]Session, error)
 	Save(ctx context.Context, session Session) (Session, error)
 	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
@@ -166,6 +167,14 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) {
 	return s.fromDBItem(dbSession), nil
 }
 
+func (s *service) GetLast(ctx context.Context) (Session, error) {
+	dbSession, err := s.q.GetLastSession(ctx)
+	if err != nil {
+		return Session{}, err
+	}
+	return s.fromDBItem(dbSession), nil
+}
+
 func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 	todosJSON, err := marshalTodos(session.Todos)
 	if err != nil {