diff --git a/Taskfile.yaml b/Taskfile.yaml index 476626fde4f0ed33d26fa20c2dc8b00ecd557af6..38e8a16313d17b9b1826ce4b6f055d39537916ec 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -184,3 +184,8 @@ tasks: - go get charm.land/fantasy - go get charm.land/catwalk - go mod tidy + + sqlc: + desc: Generate code using SQLC + cmds: + - sqlc generate diff --git a/internal/app/app.go b/internal/app/app.go index 8ed3e2e41cb2b235771eba24c3b59945f73cdfda..a3828891978c1b83429036799ab588d20f672852 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -157,9 +157,40 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] { return app.agentNotifications } +// resolveSession resolves which session to use for a non-interactive run +// If continueSessionID is set, it looks up that session by ID +// If useLast is set, it returns the most recently updated top-level session +// Otherwise, it creates a new session +func (app *App) resolveSession(ctx context.Context, continueSessionID string, useLast bool) (session.Session, error) { + switch { + case continueSessionID != "": + if app.Sessions.IsAgentToolSession(continueSessionID) { + return session.Session{}, fmt.Errorf("cannot continue an agent tool session: %s", continueSessionID) + } + sess, err := app.Sessions.Get(ctx, continueSessionID) + if err != nil { + return session.Session{}, fmt.Errorf("session not found: %s", continueSessionID) + } + if sess.ParentSessionID != "" { + return session.Session{}, fmt.Errorf("cannot continue a child session: %s", continueSessionID) + } + return sess, nil + + case useLast: + sess, err := app.Sessions.GetLast(ctx) + if err != nil { + return session.Session{}, fmt.Errorf("no sessions found to continue") + } + return sess, nil + + default: + return app.Sessions.Create(ctx, agent.DefaultSessionName) + } +} + // RunNonInteractive runs the application in non-interactive mode with the // given prompt, printing to stdout. -func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool) error { +func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool, continueSessionID string, useLast bool) error { slog.Info("Running in non-interactive mode") ctx, cancel := context.WithCancel(ctx) @@ -227,11 +258,16 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, defer stopSpinner() - sess, err := app.Sessions.Create(ctx, agent.DefaultSessionName) + sess, err := app.resolveSession(ctx, continueSessionID, useLast) if err != nil { return fmt.Errorf("failed to create session for non-interactive mode: %w", err) } - slog.Info("Created session for non-interactive run", "session_id", sess.ID) + + if continueSessionID != "" || useLast { + slog.Info("Continuing session for non-interactive run", "session_id", sess.ID) + } else { + slog.Info("Created session for non-interactive run", "session_id", sess.ID) + } // Automatically approve all permission requests for this non-interactive // session. diff --git a/internal/app/resolve_session_test.go b/internal/app/resolve_session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9b0c7af736fa9637c095c7851da3460bacf737a2 --- /dev/null +++ b/internal/app/resolve_session_test.go @@ -0,0 +1,174 @@ +package app + +import ( + "context" + "database/sql" + "fmt" + "strings" + "testing" + + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +// mockSessionService is a minimal mock of session.Service for testing resolveSession. +type mockSessionService struct { + sessions []session.Session + created []session.Session +} + +func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] { + return make(chan pubsub.Event[session.Session]) +} + +func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) { + s := session.Session{ID: "new-session-id", Title: title} + m.created = append(m.created, s) + return s, nil +} + +func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) { + return session.Session{}, nil +} + +func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) { + return session.Session{}, nil +} + +func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) { + for _, s := range m.sessions { + if s.ID == id { + return s, nil + } + } + return session.Session{}, sql.ErrNoRows +} + +func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) { + if len(m.sessions) > 0 { + return m.sessions[0], nil + } + return session.Session{}, sql.ErrNoRows +} + +func (m *mockSessionService) List(context.Context) ([]session.Session, error) { + return m.sessions, nil +} + +func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) { + return s, nil +} + +func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error { + return nil +} + +func (m *mockSessionService) Rename(context.Context, string, string) error { + return nil +} + +func (m *mockSessionService) Delete(context.Context, string) error { + return nil +} + +func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string { + return fmt.Sprintf("%s$$%s", messageID, toolCallID) +} + +func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) { + parts := strings.Split(sessionID, "$$") + if len(parts) != 2 { + return "", "", false + } + return parts[0], parts[1], true +} + +func (m *mockSessionService) IsAgentToolSession(sessionID string) bool { + _, _, ok := m.ParseAgentToolSessionID(sessionID) + return ok +} + +func newTestApp(sessions session.Service) *App { + return &App{Sessions: sessions} +} + +func TestResolveSession_NewSession(t *testing.T) { + mock := &mockSessionService{} + app := newTestApp(mock) + + sess, err := app.resolveSession(t.Context(), "", false) + require.NoError(t, err) + require.Equal(t, "new-session-id", sess.ID) + require.Len(t, mock.created, 1) +} + +func TestResolveSession_ContinueByID(t *testing.T) { + mock := &mockSessionService{ + sessions: []session.Session{ + {ID: "existing-id", Title: "Old session"}, + }, + } + app := newTestApp(mock) + + sess, err := app.resolveSession(t.Context(), "existing-id", false) + require.NoError(t, err) + require.Equal(t, "existing-id", sess.ID) + require.Equal(t, "Old session", sess.Title) + require.Empty(t, mock.created) +} + +func TestResolveSession_ContinueByID_NotFound(t *testing.T) { + mock := &mockSessionService{} + app := newTestApp(mock) + + _, err := app.resolveSession(t.Context(), "nonexistent", false) + require.Error(t, err) + require.Contains(t, err.Error(), "session not found") +} + +func TestResolveSession_ContinueByID_ChildSession(t *testing.T) { + mock := &mockSessionService{ + sessions: []session.Session{ + {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"}, + }, + } + app := newTestApp(mock) + + _, err := app.resolveSession(t.Context(), "child-id", false) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot continue a child session") +} + +func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) { + mock := &mockSessionService{} + app := newTestApp(mock) + + _, err := app.resolveSession(t.Context(), "msg123$$tool456", false) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot continue an agent tool session") +} + +func TestResolveSession_Last(t *testing.T) { + mock := &mockSessionService{ + sessions: []session.Session{ + {ID: "most-recent", Title: "Latest session"}, + {ID: "older", Title: "Older session"}, + }, + } + app := newTestApp(mock) + + sess, err := app.resolveSession(t.Context(), "", true) + require.NoError(t, err) + require.Equal(t, "most-recent", sess.ID) + require.Empty(t, mock.created) +} + +func TestResolveSession_Last_NoSessions(t *testing.T) { + mock := &mockSessionService{} + app := newTestApp(mock) + + _, err := app.resolveSession(t.Context(), "", true) + require.Error(t, err) + require.Contains(t, err.Error(), "no sessions found") +} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 1119a83f993d08b3a2206104912ec993ec37a9e3..8215438ddef6d06d1a7a3bbb863fc24935835297 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -36,12 +36,23 @@ crush run --quiet "Generate a README for this project" # Run in verbose mode (show logs) crush run --verbose "Generate a README for this project" + +# Continue a previous session +crush run --session {session-id} "Follow up on your last response" + +# Continue the most recent session +crush run --continue "Follow up on your last response" + `, RunE: func(cmd *cobra.Command, args []string) error { - quiet, _ := cmd.Flags().GetBool("quiet") - verbose, _ := cmd.Flags().GetBool("verbose") - largeModel, _ := cmd.Flags().GetString("model") - smallModel, _ := cmd.Flags().GetString("small-model") + var ( + quiet, _ = cmd.Flags().GetBool("quiet") + verbose, _ = cmd.Flags().GetBool("verbose") + largeModel, _ = cmd.Flags().GetString("model") + smallModel, _ = cmd.Flags().GetString("small-model") + sessionID, _ = cmd.Flags().GetString("session") + useLast, _ = cmd.Flags().GetBool("continue") + ) // Cancel on SIGINT or SIGTERM. ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) @@ -53,6 +64,14 @@ crush run --verbose "Generate a README for this project" } defer app.Shutdown() + if sessionID != "" { + sess, err := resolveSessionID(ctx, app.Sessions, sessionID) + if err != nil { + return err + } + sessionID = sess.ID + } + if !app.Config().IsConfigured() { return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") } @@ -76,7 +95,14 @@ crush run --verbose "Generate a README for this project" event.SetNonInteractive(true) event.AppInitialized() - return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose) + switch { + case sessionID != "": + event.SetContinueBySessionID(true) + case useLast: + event.SetContinueLastSession(true) + } + + return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast) }, } @@ -85,4 +111,7 @@ func init() { runCmd.Flags().BoolP("verbose", "v", false, "Show logs") runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers") runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider") + runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID") + runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session") + runCmd.MarkFlagsMutuallyExclusive("session", "continue") } diff --git a/internal/db/db.go b/internal/db/db.go index dbde2e493eea4c262aef55ef7dcadd904a1b9d65..fa3c5ac5aad27ab1929e306cd50fdb7dba493ea0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -63,6 +63,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getHourDayHeatmapStmt, err = db.PrepareContext(ctx, getHourDayHeatmap); err != nil { return nil, fmt.Errorf("error preparing query GetHourDayHeatmap: %w", err) } + if q.getLastSessionStmt, err = db.PrepareContext(ctx, getLastSession); err != nil { + return nil, fmt.Errorf("error preparing query GetLastSession: %w", err) + } if q.getMessageStmt, err = db.PrepareContext(ctx, getMessage); err != nil { return nil, fmt.Errorf("error preparing query GetMessage: %w", err) } @@ -202,6 +205,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getHourDayHeatmapStmt: %w", cerr) } } + if q.getLastSessionStmt != nil { + if cerr := q.getLastSessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getLastSessionStmt: %w", cerr) + } + } if q.getMessageStmt != nil { if cerr := q.getMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getMessageStmt: %w", cerr) @@ -369,6 +377,7 @@ type Queries struct { getFileByPathAndSessionStmt *sql.Stmt getFileReadStmt *sql.Stmt getHourDayHeatmapStmt *sql.Stmt + getLastSessionStmt *sql.Stmt getMessageStmt *sql.Stmt getRecentActivityStmt *sql.Stmt getSessionByIDStmt *sql.Stmt @@ -411,6 +420,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, getFileReadStmt: q.getFileReadStmt, getHourDayHeatmapStmt: q.getHourDayHeatmapStmt, + getLastSessionStmt: q.getLastSessionStmt, getMessageStmt: q.getMessageStmt, getRecentActivityStmt: q.getRecentActivityStmt, getSessionByIDStmt: q.getSessionByIDStmt, diff --git a/internal/db/querier.go b/internal/db/querier.go index ae91927aedf797f84f347e7e14a93327120a847e..9031505a3db825f2c21d83e005046323bde3a6c2 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -22,6 +22,7 @@ type Querier interface { GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error) GetHourDayHeatmap(ctx context.Context) ([]GetHourDayHeatmapRow, error) + GetLastSession(ctx context.Context) (Session, error) GetMessage(ctx context.Context, id string) (Message, error) GetRecentActivity(ctx context.Context) ([]GetRecentActivityRow, error) GetSessionByID(ctx context.Context, id string) (Session, error) diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index bdcddd01d9bdb95034a9a669e6881eed661dee10..685948e60e84ec4df66e4d5d1c9645a9ff1fb43f 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -83,6 +83,32 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error { return err } +const getLastSession = `-- name: GetLastSession :one +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +FROM sessions +ORDER BY updated_at DESC +LIMIT 1 +` + +func (q *Queries) GetLastSession(ctx context.Context) (Session, error) { + row := q.queryRow(ctx, q.getLastSessionStmt, getLastSession) + var i Session + err := row.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + &i.Todos, + ) + return i, err +} + const getSessionByID = `-- name: GetSessionByID :one SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos FROM sessions diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index 0e170fdeb270041c035c7f2ea24aaa4b571b4387..44c1609ecfbc3867bea827088fcbcff6e718427b 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -28,6 +28,12 @@ SELECT * FROM sessions WHERE id = ? LIMIT 1; +-- name: GetLastSession :one +SELECT * +FROM sessions +ORDER BY updated_at DESC +LIMIT 1; + -- name: ListSessions :many SELECT * FROM sessions diff --git a/internal/event/event.go b/internal/event/event.go index 516df804fbbca1bc03212b2c3cf26a38efab6979..a10b5d82b7fdf13ccf02fd0967cab2ce65e20661 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -17,7 +17,9 @@ const ( endpoint = "https://data.charm.land" key = "phc_4zt4VgDWLqbYnJYEwLRxFoaTL2noNrQij0C6E8k3I0V" - nonInteractiveEventName = "NonInteractive" + nonInteractiveAttrName = "NonInteractive" + continueSessionByIDAttrName = "ContinueSessionByID" + continueLastSessionAttrName = "ContinueLastSession" ) var ( @@ -30,11 +32,19 @@ var ( Set("SHELL", filepath.Base(os.Getenv("SHELL"))). Set("Version", version.Version). Set("GoVersion", runtime.Version()). - Set(nonInteractiveEventName, false) + Set(nonInteractiveAttrName, false) ) func SetNonInteractive(nonInteractive bool) { - baseProps = baseProps.Set(nonInteractiveEventName, nonInteractive) + baseProps = baseProps.Set(nonInteractiveAttrName, nonInteractive) +} + +func SetContinueBySessionID(continueBySessionID bool) { + baseProps = baseProps.Set(continueSessionByIDAttrName, continueBySessionID) +} + +func SetContinueLastSession(continueLastSession bool) { + baseProps = baseProps.Set(continueLastSessionAttrName, continueLastSession) } func Init() { diff --git a/internal/session/session.go b/internal/session/session.go index 834243b62aae6266290147ca0d0270a6069e34b3..66bd9f4c9a12916d02c6d22ed7d51f81d74efdfd 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -66,6 +66,7 @@ type Service interface { CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) Get(ctx context.Context, id string) (Session, error) + GetLast(ctx context.Context) (Session, error) List(ctx context.Context) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error @@ -166,6 +167,14 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) { return s.fromDBItem(dbSession), nil } +func (s *service) GetLast(ctx context.Context) (Session, error) { + dbSession, err := s.q.GetLastSession(ctx) + if err != nil { + return Session{}, err + } + return s.fromDBItem(dbSession), nil +} + func (s *service) Save(ctx context.Context, session Session) (Session, error) { todosJSON, err := marshalTodos(session.Todos) if err != nil {