diff --git a/internal/message/message.go b/internal/message/message.go index 6da8827b72227602dc36c39b6a2254aba18d2b0d..02309fae70ed7f7a59d151797cc5198a0ee92550 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -28,6 +28,7 @@ type Service interface { List(ctx context.Context, sessionID string) ([]Message, error) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) ListAllUserMessages(ctx context.Context) ([]Message, error) + Copy(ctx context.Context, sessionID string, message Message) (Message, error) Delete(ctx context.Context, id string) error DeleteSessionMessages(ctx context.Context, sessionID string) error } @@ -95,6 +96,17 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes return message, nil } +func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) { + params := CreateMessageParams{ + Role: message.Role, + Parts: message.Parts, + Model: message.Model, + Provider: message.Provider, + IsSummaryMessage: message.IsSummaryMessage, + } + return s.Create(ctx, sessionID, params) +} + func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error { messages, err := s.List(ctx, sessionID) if err != nil { diff --git a/internal/session/session.go b/internal/session/session.go index 905ee1cf1417b148019d9688985c1f5200209d69..c7574aee4ce1c5dccf7f88dd7a09763e34af75f7 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) @@ -42,6 +43,11 @@ type Session struct { UpdatedAt int64 } +type ForkMessageService interface { + List(ctx context.Context, sessionID string) ([]message.Message, error) + Copy(ctx context.Context, sessionID string, message message.Message) (message.Message, error) +} + type Service interface { pubsub.Subscriber[Session] Create(ctx context.Context, title string) (Session, error) @@ -52,6 +58,7 @@ type Service interface { Save(ctx context.Context, session Session) (Session, error) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error Delete(ctx context.Context, id string) error + Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) // Agent tool session management CreateAgentToolSessionID(messageID, toolCallID string) string @@ -139,6 +146,44 @@ func (s *service) Delete(ctx context.Context, id string) error { return nil } +func (s *service) Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) { + messages, err := messageSvc.List(ctx, sourceSessionID) + if err != nil { + return Session{}, fmt.Errorf("listing messages: %w", err) + } + + targetIndex := -1 + for i, msg := range messages { + if msg.ID == upToMessageID { + targetIndex = i + break + } + } + if targetIndex == -1 { + return Session{}, fmt.Errorf("message not found: %s", upToMessageID) + } + + sourceSession, err := s.Get(ctx, sourceSessionID) + if err != nil { + return Session{}, fmt.Errorf("getting source session: %w", err) + } + + newSession, err := s.Create(ctx, "Forked: "+sourceSession.Title) + if err != nil { + return Session{}, fmt.Errorf("creating session: %w", err) + } + + for i := 0; i <= targetIndex; i++ { + _, err = messageSvc.Copy(ctx, newSession.ID, messages[i]) + if err != nil { + _ = s.Delete(ctx, newSession.ID) + return Session{}, fmt.Errorf("copying message: %w", err) + } + } + + return newSession, nil +} + func (s *service) Get(ctx context.Context, id string) (Session, error) { dbSession, err := s.q.GetSessionByID(ctx, id) if err != nil { diff --git a/internal/session/session_fork_test.go b/internal/session/session_fork_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c0710aa47a97020dd6f5f6d13b2d0080dc8abbef --- /dev/null +++ b/internal/session/session_fork_test.go @@ -0,0 +1,83 @@ +package session + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/require" +) + +func TestFork(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + defer conn.Close() + + q := db.New(conn) + svc := NewService(q, conn) + msgSvc := message.NewService(q) + + sourceSession, err := svc.Create(ctx, "Source Session") + require.NoError(t, err) + + for i := 0; i < 5; i++ { + _, err = msgSvc.Create(ctx, sourceSession.ID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Test message"}, + }, + }) + require.NoError(t, err) + } + + getMessages := func(sessionID string) []message.Message { + msgs, err := msgSvc.List(ctx, sessionID) + require.NoError(t, err) + return msgs + } + + sourceMessages := getMessages(sourceSession.ID) + require.Len(t, sourceMessages, 5) + + targetMessageID := sourceMessages[2].ID + newSession, err := svc.Fork(ctx, sourceSession.ID, targetMessageID, msgSvc) + require.NoError(t, err) + require.NotEmpty(t, newSession.ID) + require.NotEqual(t, sourceSession.ID, newSession.ID) + require.Contains(t, newSession.Title, "Forked:") + require.Contains(t, newSession.Title, sourceSession.Title) + + forkedMessages := getMessages(newSession.ID) + require.Len(t, forkedMessages, 3) + + for i, msg := range forkedMessages { + require.Equal(t, sourceMessages[i].Role, msg.Role) + require.Equal(t, sourceMessages[i].Parts[0], msg.Parts[0]) + } +} + +func TestForkInvalidMessageID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + defer conn.Close() + + q := db.New(conn) + svc := NewService(q, conn) + msgSvc := message.NewService(q) + + sourceSession, err := svc.Create(ctx, "Source Session") + require.NoError(t, err) + + _, err = svc.Fork(ctx, sourceSession.ID, "invalid-id", msgSvc) + require.Error(t, err) + require.Contains(t, err.Error(), "message not found") +} diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go index b5db01692437dbee4b11b77da47b68f258b090e9..e22641a780bb4c8ab8ca69ed6e75fe11d38e0576 100644 --- a/internal/ui/dialog/actions.go +++ b/internal/ui/dialog/actions.go @@ -44,6 +44,7 @@ type ActionSelectModel struct { // Messages for commands type ( ActionNewSession struct{} + ActionForkConversation struct{} ActionToggleHelp struct{} ActionToggleCompactMode struct{} ActionToggleThinking struct{} diff --git a/internal/ui/dialog/commands.go b/internal/ui/dialog/commands.go index 416f5a0131e2dc7cf36561f118daed248ceebd08..48ed977077f3f98a66017085066ece7d1d7786cc 100644 --- a/internal/ui/dialog/commands.go +++ b/internal/ui/dialog/commands.go @@ -65,18 +65,20 @@ type Commands struct { customCommands []commands.CustomCommand mcpPrompts []commands.MCPPrompt + canFork bool // whether a user message is selected for forking } var _ Dialog = (*Commands)(nil) // NewCommands creates a new commands dialog. -func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt) (*Commands, error) { +func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt, canFork bool) (*Commands, error) { c := &Commands{ com: com, selected: SystemCommands, sessionID: sessionID, customCommands: customCommands, mcpPrompts: mcpPrompts, + canFork: canFork, } help := help.New() @@ -386,9 +388,18 @@ func (c *Commands) setCommandItems(commandType CommandType) { func (c *Commands) defaultCommands() []*CommandItem { commands := []*CommandItem{ NewCommandItem(c.com.Styles, "new_session", "New Session", "ctrl+n", ActionNewSession{}), + } + + if c.canFork { + commands = append(commands, + NewCommandItem(c.com.Styles, "fork_conversation", "Fork Conversation", "", ActionForkConversation{}), + ) + } + + commands = append(commands, NewCommandItem(c.com.Styles, "switch_session", "Sessions", "ctrl+s", ActionOpenDialog{SessionsID}), NewCommandItem(c.com.Styles, "switch_model", "Switch Model", "ctrl+l", ActionOpenDialog{ModelsID}), - } + ) // Only show compact command if there's an active session if c.sessionID != "" { diff --git a/internal/ui/model/chat.go b/internal/ui/model/chat.go index 3a743edd9d1e87b643076f114b065b2eaa2b2ca5..b0972576d37bcea0e2fbff8992f884df7f3bd69a 100644 --- a/internal/ui/model/chat.go +++ b/internal/ui/model/chat.go @@ -246,6 +246,11 @@ func (m *Chat) SelectedItemInView() bool { return m.list.SelectedItemInView() } +// SelectedItem returns the currently selected item in the chat list. +func (m *Chat) SelectedItem() list.Item { + return m.list.SelectedItem() +} + func (m *Chat) isSelectable(index int) bool { item := m.list.ItemAt(index) if item == nil { diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 9eb7f01f881e70ad82597820dac8e3161f4cd684..0c8276bc03afebd5446049cb6c6b4c67d846b803 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -41,6 +41,7 @@ import ( "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/completions" "github.com/charmbracelet/crush/internal/ui/dialog" + "github.com/charmbracelet/crush/internal/ui/list" "github.com/charmbracelet/crush/internal/ui/logo" "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/uiutil" @@ -1158,6 +1159,41 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { cmds = append(cmds, m.initializeProject()) m.dialog.CloseDialog(dialog.CommandsID) + case dialog.ActionForkConversation: + if m.isAgentBusy() { + cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait...")) + break + } + if m.session == nil { + cmds = append(cmds, uiutil.ReportWarn("No session to fork...")) + break + } + selectedItem := m.chat.SelectedItem() + if selectedItem == nil { + cmds = append(cmds, uiutil.ReportWarn("No message selected...")) + break + } + if _, ok := selectedItem.(*chat.UserMessageItem); !ok { + cmds = append(cmds, uiutil.ReportWarn("Can only fork from user messages...")) + break + } + messageID, ok := m.getMessageIDFromItem(selectedItem) + if !ok { + cmds = append(cmds, uiutil.ReportWarn("Cannot fork from selected item...")) + break + } + cmds = append(cmds, func() tea.Msg { + newSession, err := m.com.App.Sessions.Fork(context.Background(), m.session.ID, messageID, m.com.App.Messages) + if err != nil { + return uiutil.ReportError(err)() + } + return loadSessionMsg{ + session: &newSession, + files: []SessionFile{}, + } + }) + m.dialog.CloseDialog(dialog.CommandsID) + case dialog.ActionSelectModel: if m.isAgentBusy() { cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait...")) @@ -2489,6 +2525,22 @@ func (m *UI) hasSession() bool { return m.session != nil && m.session.ID != "" } +// getMessageIDFromItem gets the message ID from a selected chat item. +func (m *UI) getMessageIDFromItem(item list.Item) (string, bool) { + if toolMsg, ok := item.(chat.ToolMessageItem); ok { + return toolMsg.MessageID(), true + } + if msgItem, ok := item.(chat.MessageItem); ok { + itemID := msgItem.ID() + if strings.HasSuffix(itemID, ":assistant-info") { + baseID := strings.TrimSuffix(itemID, ":assistant-info") + return baseID, true + } + return itemID, true + } + return "", false +} + // mimeOf detects the MIME type of the given content. func mimeOf(content []byte) string { mimeBufferSize := min(512, len(content)) @@ -2699,10 +2751,17 @@ func (m *UI) openModelsDialog() tea.Cmd { // openCommandsDialog opens the commands dialog. func (m *UI) openCommandsDialog() tea.Cmd { + canFork := false + selectedItem := m.chat.SelectedItem() + if selectedItem != nil { + if _, ok := selectedItem.(*chat.UserMessageItem); ok { + canFork = true + } + } + if m.dialog.ContainsDialog(dialog.CommandsID) { - // Bring to front - m.dialog.BringToFront(dialog.CommandsID) - return nil + // Close and reopen to refresh state + m.dialog.CloseDialog(dialog.CommandsID) } sessionID := "" @@ -2710,7 +2769,7 @@ func (m *UI) openCommandsDialog() tea.Cmd { sessionID = m.session.ID } - commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts) + commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts, canFork) if err != nil { return uiutil.ReportError(err) }