Detailed changes
@@ -28,6 +28,7 @@ type Service interface {
List(ctx context.Context, sessionID string) ([]Message, error)
ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
ListAllUserMessages(ctx context.Context) ([]Message, error)
+ Copy(ctx context.Context, sessionID string, message Message) (Message, error)
Delete(ctx context.Context, id string) error
DeleteSessionMessages(ctx context.Context, sessionID string) error
}
@@ -95,6 +96,17 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
return message, nil
}
+func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) {
+ params := CreateMessageParams{
+ Role: message.Role,
+ Parts: message.Parts,
+ Model: message.Model,
+ Provider: message.Provider,
+ IsSummaryMessage: message.IsSummaryMessage,
+ }
+ return s.Create(ctx, sessionID, params)
+}
+
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
messages, err := s.List(ctx, sessionID)
if err != nil {
@@ -10,6 +10,7 @@ import (
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
@@ -42,6 +43,11 @@ type Session struct {
UpdatedAt int64
}
+type ForkMessageService interface {
+ List(ctx context.Context, sessionID string) ([]message.Message, error)
+ Copy(ctx context.Context, sessionID string, message message.Message) (message.Message, error)
+}
+
type Service interface {
pubsub.Subscriber[Session]
Create(ctx context.Context, title string) (Session, error)
@@ -52,6 +58,7 @@ type Service interface {
Save(ctx context.Context, session Session) (Session, error)
UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
Delete(ctx context.Context, id string) error
+ Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error)
// Agent tool session management
CreateAgentToolSessionID(messageID, toolCallID string) string
@@ -139,6 +146,44 @@ func (s *service) Delete(ctx context.Context, id string) error {
return nil
}
+func (s *service) Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) {
+ messages, err := messageSvc.List(ctx, sourceSessionID)
+ if err != nil {
+ return Session{}, fmt.Errorf("listing messages: %w", err)
+ }
+
+ targetIndex := -1
+ for i, msg := range messages {
+ if msg.ID == upToMessageID {
+ targetIndex = i
+ break
+ }
+ }
+ if targetIndex == -1 {
+ return Session{}, fmt.Errorf("message not found: %s", upToMessageID)
+ }
+
+ sourceSession, err := s.Get(ctx, sourceSessionID)
+ if err != nil {
+ return Session{}, fmt.Errorf("getting source session: %w", err)
+ }
+
+ newSession, err := s.Create(ctx, "Forked: "+sourceSession.Title)
+ if err != nil {
+ return Session{}, fmt.Errorf("creating session: %w", err)
+ }
+
+ for i := 0; i <= targetIndex; i++ {
+ _, err = messageSvc.Copy(ctx, newSession.ID, messages[i])
+ if err != nil {
+ _ = s.Delete(ctx, newSession.ID)
+ return Session{}, fmt.Errorf("copying message: %w", err)
+ }
+ }
+
+ return newSession, nil
+}
+
func (s *service) Get(ctx context.Context, id string) (Session, error) {
dbSession, err := s.q.GetSessionByID(ctx, id)
if err != nil {
@@ -0,0 +1,83 @@
+package session
+
+import (
+ "context"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFork(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ conn, err := db.Connect(t.Context(), t.TempDir())
+ require.NoError(t, err)
+ defer conn.Close()
+
+ q := db.New(conn)
+ svc := NewService(q, conn)
+ msgSvc := message.NewService(q)
+
+ sourceSession, err := svc.Create(ctx, "Source Session")
+ require.NoError(t, err)
+
+ for i := 0; i < 5; i++ {
+ _, err = msgSvc.Create(ctx, sourceSession.ID, message.CreateMessageParams{
+ Role: message.User,
+ Parts: []message.ContentPart{
+ message.TextContent{Text: "Test message"},
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ getMessages := func(sessionID string) []message.Message {
+ msgs, err := msgSvc.List(ctx, sessionID)
+ require.NoError(t, err)
+ return msgs
+ }
+
+ sourceMessages := getMessages(sourceSession.ID)
+ require.Len(t, sourceMessages, 5)
+
+ targetMessageID := sourceMessages[2].ID
+ newSession, err := svc.Fork(ctx, sourceSession.ID, targetMessageID, msgSvc)
+ require.NoError(t, err)
+ require.NotEmpty(t, newSession.ID)
+ require.NotEqual(t, sourceSession.ID, newSession.ID)
+ require.Contains(t, newSession.Title, "Forked:")
+ require.Contains(t, newSession.Title, sourceSession.Title)
+
+ forkedMessages := getMessages(newSession.ID)
+ require.Len(t, forkedMessages, 3)
+
+ for i, msg := range forkedMessages {
+ require.Equal(t, sourceMessages[i].Role, msg.Role)
+ require.Equal(t, sourceMessages[i].Parts[0], msg.Parts[0])
+ }
+}
+
+func TestForkInvalidMessageID(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ conn, err := db.Connect(t.Context(), t.TempDir())
+ require.NoError(t, err)
+ defer conn.Close()
+
+ q := db.New(conn)
+ svc := NewService(q, conn)
+ msgSvc := message.NewService(q)
+
+ sourceSession, err := svc.Create(ctx, "Source Session")
+ require.NoError(t, err)
+
+ _, err = svc.Fork(ctx, sourceSession.ID, "invalid-id", msgSvc)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "message not found")
+}
@@ -44,6 +44,7 @@ type ActionSelectModel struct {
// Messages for commands
type (
ActionNewSession struct{}
+ ActionForkConversation struct{}
ActionToggleHelp struct{}
ActionToggleCompactMode struct{}
ActionToggleThinking struct{}
@@ -65,18 +65,20 @@ type Commands struct {
customCommands []commands.CustomCommand
mcpPrompts []commands.MCPPrompt
+ canFork bool // whether a user message is selected for forking
}
var _ Dialog = (*Commands)(nil)
// NewCommands creates a new commands dialog.
-func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt) (*Commands, error) {
+func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt, canFork bool) (*Commands, error) {
c := &Commands{
com: com,
selected: SystemCommands,
sessionID: sessionID,
customCommands: customCommands,
mcpPrompts: mcpPrompts,
+ canFork: canFork,
}
help := help.New()
@@ -386,9 +388,18 @@ func (c *Commands) setCommandItems(commandType CommandType) {
func (c *Commands) defaultCommands() []*CommandItem {
commands := []*CommandItem{
NewCommandItem(c.com.Styles, "new_session", "New Session", "ctrl+n", ActionNewSession{}),
+ }
+
+ if c.canFork {
+ commands = append(commands,
+ NewCommandItem(c.com.Styles, "fork_conversation", "Fork Conversation", "", ActionForkConversation{}),
+ )
+ }
+
+ commands = append(commands,
NewCommandItem(c.com.Styles, "switch_session", "Sessions", "ctrl+s", ActionOpenDialog{SessionsID}),
NewCommandItem(c.com.Styles, "switch_model", "Switch Model", "ctrl+l", ActionOpenDialog{ModelsID}),
- }
+ )
// Only show compact command if there's an active session
if c.sessionID != "" {
@@ -246,6 +246,11 @@ func (m *Chat) SelectedItemInView() bool {
return m.list.SelectedItemInView()
}
+// SelectedItem returns the currently selected item in the chat list.
+func (m *Chat) SelectedItem() list.Item {
+ return m.list.SelectedItem()
+}
+
func (m *Chat) isSelectable(index int) bool {
item := m.list.ItemAt(index)
if item == nil {
@@ -41,6 +41,7 @@ import (
"github.com/charmbracelet/crush/internal/ui/common"
"github.com/charmbracelet/crush/internal/ui/completions"
"github.com/charmbracelet/crush/internal/ui/dialog"
+ "github.com/charmbracelet/crush/internal/ui/list"
"github.com/charmbracelet/crush/internal/ui/logo"
"github.com/charmbracelet/crush/internal/ui/styles"
"github.com/charmbracelet/crush/internal/uiutil"
@@ -1158,6 +1159,41 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
cmds = append(cmds, m.initializeProject())
m.dialog.CloseDialog(dialog.CommandsID)
+ case dialog.ActionForkConversation:
+ if m.isAgentBusy() {
+ cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
+ break
+ }
+ if m.session == nil {
+ cmds = append(cmds, uiutil.ReportWarn("No session to fork..."))
+ break
+ }
+ selectedItem := m.chat.SelectedItem()
+ if selectedItem == nil {
+ cmds = append(cmds, uiutil.ReportWarn("No message selected..."))
+ break
+ }
+ if _, ok := selectedItem.(*chat.UserMessageItem); !ok {
+ cmds = append(cmds, uiutil.ReportWarn("Can only fork from user messages..."))
+ break
+ }
+ messageID, ok := m.getMessageIDFromItem(selectedItem)
+ if !ok {
+ cmds = append(cmds, uiutil.ReportWarn("Cannot fork from selected item..."))
+ break
+ }
+ cmds = append(cmds, func() tea.Msg {
+ newSession, err := m.com.App.Sessions.Fork(context.Background(), m.session.ID, messageID, m.com.App.Messages)
+ if err != nil {
+ return uiutil.ReportError(err)()
+ }
+ return loadSessionMsg{
+ session: &newSession,
+ files: []SessionFile{},
+ }
+ })
+ m.dialog.CloseDialog(dialog.CommandsID)
+
case dialog.ActionSelectModel:
if m.isAgentBusy() {
cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
@@ -2489,6 +2525,22 @@ func (m *UI) hasSession() bool {
return m.session != nil && m.session.ID != ""
}
+// getMessageIDFromItem gets the message ID from a selected chat item.
+func (m *UI) getMessageIDFromItem(item list.Item) (string, bool) {
+ if toolMsg, ok := item.(chat.ToolMessageItem); ok {
+ return toolMsg.MessageID(), true
+ }
+ if msgItem, ok := item.(chat.MessageItem); ok {
+ itemID := msgItem.ID()
+ if strings.HasSuffix(itemID, ":assistant-info") {
+ baseID := strings.TrimSuffix(itemID, ":assistant-info")
+ return baseID, true
+ }
+ return itemID, true
+ }
+ return "", false
+}
+
// mimeOf detects the MIME type of the given content.
func mimeOf(content []byte) string {
mimeBufferSize := min(512, len(content))
@@ -2699,10 +2751,17 @@ func (m *UI) openModelsDialog() tea.Cmd {
// openCommandsDialog opens the commands dialog.
func (m *UI) openCommandsDialog() tea.Cmd {
+ canFork := false
+ selectedItem := m.chat.SelectedItem()
+ if selectedItem != nil {
+ if _, ok := selectedItem.(*chat.UserMessageItem); ok {
+ canFork = true
+ }
+ }
+
if m.dialog.ContainsDialog(dialog.CommandsID) {
- // Bring to front
- m.dialog.BringToFront(dialog.CommandsID)
- return nil
+ // Close and reopen to refresh state
+ m.dialog.CloseDialog(dialog.CommandsID)
}
sessionID := ""
@@ -2710,7 +2769,7 @@ func (m *UI) openCommandsDialog() tea.Cmd {
sessionID = m.session.ID
}
- commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts)
+ commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts, canFork)
if err != nil {
return uiutil.ReportError(err)
}