Taskfile.yaml 🔗
@@ -184,3 +184,8 @@ tasks:
- go get charm.land/fantasy
- go get charm.land/catwalk
- go mod tidy
+
+ sqlc:
+ desc: Generate code using SQLC
+ cmds:
+ - sqlc generate
Daniil Sivak and Andrey Nering created
Co-authored-by: Andrey Nering <andreynering@users.noreply.github.com>
Taskfile.yaml | 5
internal/app/app.go | 42 ++++++
internal/app/resolve_session_test.go | 174 ++++++++++++++++++++++++++++++
internal/cmd/run.go | 39 +++++
internal/db/db.go | 10 +
internal/db/querier.go | 1
internal/db/sessions.sql.go | 26 ++++
internal/db/sql/sessions.sql | 6 +
internal/event/event.go | 16 ++
internal/session/session.go | 9 +
10 files changed, 317 insertions(+), 11 deletions(-)
@@ -184,3 +184,8 @@ tasks:
- go get charm.land/fantasy
- go get charm.land/catwalk
- go mod tidy
+
+ sqlc:
+ desc: Generate code using SQLC
+ cmds:
+ - sqlc generate
@@ -157,9 +157,40 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] {
return app.agentNotifications
}
+// resolveSession resolves which session to use for a non-interactive run
+// If continueSessionID is set, it looks up that session by ID
+// If useLast is set, it returns the most recently updated top-level session
+// Otherwise, it creates a new session
+func (app *App) resolveSession(ctx context.Context, continueSessionID string, useLast bool) (session.Session, error) {
+ switch {
+ case continueSessionID != "":
+ if app.Sessions.IsAgentToolSession(continueSessionID) {
+ return session.Session{}, fmt.Errorf("cannot continue an agent tool session: %s", continueSessionID)
+ }
+ sess, err := app.Sessions.Get(ctx, continueSessionID)
+ if err != nil {
+ return session.Session{}, fmt.Errorf("session not found: %s", continueSessionID)
+ }
+ if sess.ParentSessionID != "" {
+ return session.Session{}, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
+ }
+ return sess, nil
+
+ case useLast:
+ sess, err := app.Sessions.GetLast(ctx)
+ if err != nil {
+ return session.Session{}, fmt.Errorf("no sessions found to continue")
+ }
+ return sess, nil
+
+ default:
+ return app.Sessions.Create(ctx, agent.DefaultSessionName)
+ }
+}
+
// RunNonInteractive runs the application in non-interactive mode with the
// given prompt, printing to stdout.
-func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool) error {
+func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool, continueSessionID string, useLast bool) error {
slog.Info("Running in non-interactive mode")
ctx, cancel := context.WithCancel(ctx)
@@ -227,11 +258,16 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt,
defer stopSpinner()
- sess, err := app.Sessions.Create(ctx, agent.DefaultSessionName)
+ sess, err := app.resolveSession(ctx, continueSessionID, useLast)
if err != nil {
return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
}
- slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+
+ if continueSessionID != "" || useLast {
+ slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
+ } else {
+ slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+ }
// Automatically approve all permission requests for this non-interactive
// session.
@@ -0,0 +1,174 @@
+package app
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/stretchr/testify/require"
+)
+
+// mockSessionService is a minimal mock of session.Service for testing resolveSession.
+type mockSessionService struct {
+ sessions []session.Session
+ created []session.Session
+}
+
+func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
+ return make(chan pubsub.Event[session.Session])
+}
+
+func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
+ s := session.Session{ID: "new-session-id", Title: title}
+ m.created = append(m.created, s)
+ return s, nil
+}
+
+func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
+ return session.Session{}, nil
+}
+
+func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
+ return session.Session{}, nil
+}
+
+func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
+ for _, s := range m.sessions {
+ if s.ID == id {
+ return s, nil
+ }
+ }
+ return session.Session{}, sql.ErrNoRows
+}
+
+func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
+ if len(m.sessions) > 0 {
+ return m.sessions[0], nil
+ }
+ return session.Session{}, sql.ErrNoRows
+}
+
+func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
+ return m.sessions, nil
+}
+
+func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
+ return s, nil
+}
+
+func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
+ return nil
+}
+
+func (m *mockSessionService) Rename(context.Context, string, string) error {
+ return nil
+}
+
+func (m *mockSessionService) Delete(context.Context, string) error {
+ return nil
+}
+
+func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
+ return fmt.Sprintf("%s$$%s", messageID, toolCallID)
+}
+
+func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
+ parts := strings.Split(sessionID, "$$")
+ if len(parts) != 2 {
+ return "", "", false
+ }
+ return parts[0], parts[1], true
+}
+
+func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
+ _, _, ok := m.ParseAgentToolSessionID(sessionID)
+ return ok
+}
+
+func newTestApp(sessions session.Service) *App {
+ return &App{Sessions: sessions}
+}
+
+func TestResolveSession_NewSession(t *testing.T) {
+ mock := &mockSessionService{}
+ app := newTestApp(mock)
+
+ sess, err := app.resolveSession(t.Context(), "", false)
+ require.NoError(t, err)
+ require.Equal(t, "new-session-id", sess.ID)
+ require.Len(t, mock.created, 1)
+}
+
+func TestResolveSession_ContinueByID(t *testing.T) {
+ mock := &mockSessionService{
+ sessions: []session.Session{
+ {ID: "existing-id", Title: "Old session"},
+ },
+ }
+ app := newTestApp(mock)
+
+ sess, err := app.resolveSession(t.Context(), "existing-id", false)
+ require.NoError(t, err)
+ require.Equal(t, "existing-id", sess.ID)
+ require.Equal(t, "Old session", sess.Title)
+ require.Empty(t, mock.created)
+}
+
+func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
+ mock := &mockSessionService{}
+ app := newTestApp(mock)
+
+ _, err := app.resolveSession(t.Context(), "nonexistent", false)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "session not found")
+}
+
+func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
+ mock := &mockSessionService{
+ sessions: []session.Session{
+ {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
+ },
+ }
+ app := newTestApp(mock)
+
+ _, err := app.resolveSession(t.Context(), "child-id", false)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "cannot continue a child session")
+}
+
+func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
+ mock := &mockSessionService{}
+ app := newTestApp(mock)
+
+ _, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "cannot continue an agent tool session")
+}
+
+func TestResolveSession_Last(t *testing.T) {
+ mock := &mockSessionService{
+ sessions: []session.Session{
+ {ID: "most-recent", Title: "Latest session"},
+ {ID: "older", Title: "Older session"},
+ },
+ }
+ app := newTestApp(mock)
+
+ sess, err := app.resolveSession(t.Context(), "", true)
+ require.NoError(t, err)
+ require.Equal(t, "most-recent", sess.ID)
+ require.Empty(t, mock.created)
+}
+
+func TestResolveSession_Last_NoSessions(t *testing.T) {
+ mock := &mockSessionService{}
+ app := newTestApp(mock)
+
+ _, err := app.resolveSession(t.Context(), "", true)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "no sessions found")
+}
@@ -36,12 +36,23 @@ crush run --quiet "Generate a README for this project"
# Run in verbose mode (show logs)
crush run --verbose "Generate a README for this project"
+
+# Continue a previous session
+crush run --session {session-id} "Follow up on your last response"
+
+# Continue the most recent session
+crush run --continue "Follow up on your last response"
+
`,
RunE: func(cmd *cobra.Command, args []string) error {
- quiet, _ := cmd.Flags().GetBool("quiet")
- verbose, _ := cmd.Flags().GetBool("verbose")
- largeModel, _ := cmd.Flags().GetString("model")
- smallModel, _ := cmd.Flags().GetString("small-model")
+ var (
+ quiet, _ = cmd.Flags().GetBool("quiet")
+ verbose, _ = cmd.Flags().GetBool("verbose")
+ largeModel, _ = cmd.Flags().GetString("model")
+ smallModel, _ = cmd.Flags().GetString("small-model")
+ sessionID, _ = cmd.Flags().GetString("session")
+ useLast, _ = cmd.Flags().GetBool("continue")
+ )
// Cancel on SIGINT or SIGTERM.
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
@@ -53,6 +64,14 @@ crush run --verbose "Generate a README for this project"
}
defer app.Shutdown()
+ if sessionID != "" {
+ sess, err := resolveSessionID(ctx, app.Sessions, sessionID)
+ if err != nil {
+ return err
+ }
+ sessionID = sess.ID
+ }
+
if !app.Config().IsConfigured() {
return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
}
@@ -76,7 +95,14 @@ crush run --verbose "Generate a README for this project"
event.SetNonInteractive(true)
event.AppInitialized()
- return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose)
+ switch {
+ case sessionID != "":
+ event.SetContinueBySessionID(true)
+ case useLast:
+ event.SetContinueLastSession(true)
+ }
+
+ return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
},
}
@@ -85,4 +111,7 @@ func init() {
runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
+ runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
+ runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
+ runCmd.MarkFlagsMutuallyExclusive("session", "continue")
}
@@ -63,6 +63,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getHourDayHeatmapStmt, err = db.PrepareContext(ctx, getHourDayHeatmap); err != nil {
return nil, fmt.Errorf("error preparing query GetHourDayHeatmap: %w", err)
}
+ if q.getLastSessionStmt, err = db.PrepareContext(ctx, getLastSession); err != nil {
+ return nil, fmt.Errorf("error preparing query GetLastSession: %w", err)
+ }
if q.getMessageStmt, err = db.PrepareContext(ctx, getMessage); err != nil {
return nil, fmt.Errorf("error preparing query GetMessage: %w", err)
}
@@ -202,6 +205,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getHourDayHeatmapStmt: %w", cerr)
}
}
+ if q.getLastSessionStmt != nil {
+ if cerr := q.getLastSessionStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getLastSessionStmt: %w", cerr)
+ }
+ }
if q.getMessageStmt != nil {
if cerr := q.getMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getMessageStmt: %w", cerr)
@@ -369,6 +377,7 @@ type Queries struct {
getFileByPathAndSessionStmt *sql.Stmt
getFileReadStmt *sql.Stmt
getHourDayHeatmapStmt *sql.Stmt
+ getLastSessionStmt *sql.Stmt
getMessageStmt *sql.Stmt
getRecentActivityStmt *sql.Stmt
getSessionByIDStmt *sql.Stmt
@@ -411,6 +420,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
getFileReadStmt: q.getFileReadStmt,
getHourDayHeatmapStmt: q.getHourDayHeatmapStmt,
+ getLastSessionStmt: q.getLastSessionStmt,
getMessageStmt: q.getMessageStmt,
getRecentActivityStmt: q.getRecentActivityStmt,
getSessionByIDStmt: q.getSessionByIDStmt,
@@ -22,6 +22,7 @@ type Querier interface {
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error)
GetHourDayHeatmap(ctx context.Context) ([]GetHourDayHeatmapRow, error)
+ GetLastSession(ctx context.Context) (Session, error)
GetMessage(ctx context.Context, id string) (Message, error)
GetRecentActivity(ctx context.Context) ([]GetRecentActivityRow, error)
GetSessionByID(ctx context.Context, id string) (Session, error)
@@ -83,6 +83,32 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
return err
}
+const getLastSession = `-- name: GetLastSession :one
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+FROM sessions
+ORDER BY updated_at DESC
+LIMIT 1
+`
+
+func (q *Queries) GetLastSession(ctx context.Context) (Session, error) {
+ row := q.queryRow(ctx, q.getLastSessionStmt, getLastSession)
+ var i Session
+ err := row.Scan(
+ &i.ID,
+ &i.ParentSessionID,
+ &i.Title,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.Cost,
+ &i.UpdatedAt,
+ &i.CreatedAt,
+ &i.SummaryMessageID,
+ &i.Todos,
+ )
+ return i, err
+}
+
const getSessionByID = `-- name: GetSessionByID :one
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
FROM sessions
@@ -28,6 +28,12 @@ SELECT *
FROM sessions
WHERE id = ? LIMIT 1;
+-- name: GetLastSession :one
+SELECT *
+FROM sessions
+ORDER BY updated_at DESC
+LIMIT 1;
+
-- name: ListSessions :many
SELECT *
FROM sessions
@@ -17,7 +17,9 @@ const (
endpoint = "https://data.charm.land"
key = "phc_4zt4VgDWLqbYnJYEwLRxFoaTL2noNrQij0C6E8k3I0V"
- nonInteractiveEventName = "NonInteractive"
+ nonInteractiveAttrName = "NonInteractive"
+ continueSessionByIDAttrName = "ContinueSessionByID"
+ continueLastSessionAttrName = "ContinueLastSession"
)
var (
@@ -30,11 +32,19 @@ var (
Set("SHELL", filepath.Base(os.Getenv("SHELL"))).
Set("Version", version.Version).
Set("GoVersion", runtime.Version()).
- Set(nonInteractiveEventName, false)
+ Set(nonInteractiveAttrName, false)
)
func SetNonInteractive(nonInteractive bool) {
- baseProps = baseProps.Set(nonInteractiveEventName, nonInteractive)
+ baseProps = baseProps.Set(nonInteractiveAttrName, nonInteractive)
+}
+
+func SetContinueBySessionID(continueBySessionID bool) {
+ baseProps = baseProps.Set(continueSessionByIDAttrName, continueBySessionID)
+}
+
+func SetContinueLastSession(continueLastSession bool) {
+ baseProps = baseProps.Set(continueLastSessionAttrName, continueLastSession)
}
func Init() {
@@ -66,6 +66,7 @@ type Service interface {
CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
Get(ctx context.Context, id string) (Session, error)
+ GetLast(ctx context.Context) (Session, error)
List(ctx context.Context) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
@@ -166,6 +167,14 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) {
return s.fromDBItem(dbSession), nil
}
+func (s *service) GetLast(ctx context.Context) (Session, error) {
+ dbSession, err := s.q.GetLastSession(ctx)
+ if err != nil {
+ return Session{}, err
+ }
+ return s.fromDBItem(dbSession), nil
+}
+
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
todosJSON, err := marshalTodos(session.Todos)
if err != nil {