feat: fork conversation

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/message/message.go           | 12 ++++
internal/session/session.go           | 45 +++++++++++++++
internal/session/session_fork_test.go | 83 +++++++++++++++++++++++++++++
internal/ui/dialog/actions.go         |  1 
internal/ui/dialog/commands.go        | 15 ++++
internal/ui/model/chat.go             |  5 +
internal/ui/model/ui.go               | 67 ++++++++++++++++++++++-
7 files changed, 222 insertions(+), 6 deletions(-)

Detailed changes

internal/message/message.go 🔗

@@ -28,6 +28,7 @@ type Service interface {
 	List(ctx context.Context, sessionID string) ([]Message, error)
 	ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
 	ListAllUserMessages(ctx context.Context) ([]Message, error)
+	Copy(ctx context.Context, sessionID string, message Message) (Message, error)
 	Delete(ctx context.Context, id string) error
 	DeleteSessionMessages(ctx context.Context, sessionID string) error
 }
@@ -95,6 +96,17 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
 	return message, nil
 }
 
+func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) {
+	params := CreateMessageParams{
+		Role:             message.Role,
+		Parts:            message.Parts,
+		Model:            message.Model,
+		Provider:         message.Provider,
+		IsSummaryMessage: message.IsSummaryMessage,
+	}
+	return s.Create(ctx, sessionID, params)
+}
+
 func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
 	messages, err := s.List(ctx, sessionID)
 	if err != nil {

internal/session/session.go 🔗

@@ -10,6 +10,7 @@ import (
 
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/event"
+	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/pubsub"
 	"github.com/google/uuid"
 )
@@ -42,6 +43,11 @@ type Session struct {
 	UpdatedAt        int64
 }
 
+type ForkMessageService interface {
+	List(ctx context.Context, sessionID string) ([]message.Message, error)
+	Copy(ctx context.Context, sessionID string, message message.Message) (message.Message, error)
+}
+
 type Service interface {
 	pubsub.Subscriber[Session]
 	Create(ctx context.Context, title string) (Session, error)
@@ -52,6 +58,7 @@ type Service interface {
 	Save(ctx context.Context, session Session) (Session, error)
 	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 	Delete(ctx context.Context, id string) error
+	Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error)
 
 	// Agent tool session management
 	CreateAgentToolSessionID(messageID, toolCallID string) string
@@ -139,6 +146,44 @@ func (s *service) Delete(ctx context.Context, id string) error {
 	return nil
 }
 
+func (s *service) Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) {
+	messages, err := messageSvc.List(ctx, sourceSessionID)
+	if err != nil {
+		return Session{}, fmt.Errorf("listing messages: %w", err)
+	}
+
+	targetIndex := -1
+	for i, msg := range messages {
+		if msg.ID == upToMessageID {
+			targetIndex = i
+			break
+		}
+	}
+	if targetIndex == -1 {
+		return Session{}, fmt.Errorf("message not found: %s", upToMessageID)
+	}
+
+	sourceSession, err := s.Get(ctx, sourceSessionID)
+	if err != nil {
+		return Session{}, fmt.Errorf("getting source session: %w", err)
+	}
+
+	newSession, err := s.Create(ctx, "Forked: "+sourceSession.Title)
+	if err != nil {
+		return Session{}, fmt.Errorf("creating session: %w", err)
+	}
+
+	for i := 0; i <= targetIndex; i++ {
+		_, err = messageSvc.Copy(ctx, newSession.ID, messages[i])
+		if err != nil {
+			_ = s.Delete(ctx, newSession.ID)
+			return Session{}, fmt.Errorf("copying message: %w", err)
+		}
+	}
+
+	return newSession, nil
+}
+
 func (s *service) Get(ctx context.Context, id string) (Session, error) {
 	dbSession, err := s.q.GetSessionByID(ctx, id)
 	if err != nil {

internal/session/session_fork_test.go 🔗

@@ -0,0 +1,83 @@
+package session
+
+import (
+	"context"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/db"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/stretchr/testify/require"
+)
+
+func TestFork(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+
+	conn, err := db.Connect(t.Context(), t.TempDir())
+	require.NoError(t, err)
+	defer conn.Close()
+
+	q := db.New(conn)
+	svc := NewService(q, conn)
+	msgSvc := message.NewService(q)
+
+	sourceSession, err := svc.Create(ctx, "Source Session")
+	require.NoError(t, err)
+
+	for i := 0; i < 5; i++ {
+		_, err = msgSvc.Create(ctx, sourceSession.ID, message.CreateMessageParams{
+			Role: message.User,
+			Parts: []message.ContentPart{
+				message.TextContent{Text: "Test message"},
+			},
+		})
+		require.NoError(t, err)
+	}
+
+	getMessages := func(sessionID string) []message.Message {
+		msgs, err := msgSvc.List(ctx, sessionID)
+		require.NoError(t, err)
+		return msgs
+	}
+
+	sourceMessages := getMessages(sourceSession.ID)
+	require.Len(t, sourceMessages, 5)
+
+	targetMessageID := sourceMessages[2].ID
+	newSession, err := svc.Fork(ctx, sourceSession.ID, targetMessageID, msgSvc)
+	require.NoError(t, err)
+	require.NotEmpty(t, newSession.ID)
+	require.NotEqual(t, sourceSession.ID, newSession.ID)
+	require.Contains(t, newSession.Title, "Forked:")
+	require.Contains(t, newSession.Title, sourceSession.Title)
+
+	forkedMessages := getMessages(newSession.ID)
+	require.Len(t, forkedMessages, 3)
+
+	for i, msg := range forkedMessages {
+		require.Equal(t, sourceMessages[i].Role, msg.Role)
+		require.Equal(t, sourceMessages[i].Parts[0], msg.Parts[0])
+	}
+}
+
+func TestForkInvalidMessageID(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+
+	conn, err := db.Connect(t.Context(), t.TempDir())
+	require.NoError(t, err)
+	defer conn.Close()
+
+	q := db.New(conn)
+	svc := NewService(q, conn)
+	msgSvc := message.NewService(q)
+
+	sourceSession, err := svc.Create(ctx, "Source Session")
+	require.NoError(t, err)
+
+	_, err = svc.Fork(ctx, sourceSession.ID, "invalid-id", msgSvc)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "message not found")
+}

internal/ui/dialog/actions.go 🔗

@@ -44,6 +44,7 @@ type ActionSelectModel struct {
 // Messages for commands
 type (
 	ActionNewSession        struct{}
+	ActionForkConversation  struct{}
 	ActionToggleHelp        struct{}
 	ActionToggleCompactMode struct{}
 	ActionToggleThinking    struct{}

internal/ui/dialog/commands.go 🔗

@@ -65,18 +65,20 @@ type Commands struct {
 
 	customCommands []commands.CustomCommand
 	mcpPrompts     []commands.MCPPrompt
+	canFork        bool // whether a user message is selected for forking
 }
 
 var _ Dialog = (*Commands)(nil)
 
 // NewCommands creates a new commands dialog.
-func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt) (*Commands, error) {
+func NewCommands(com *common.Common, sessionID string, customCommands []commands.CustomCommand, mcpPrompts []commands.MCPPrompt, canFork bool) (*Commands, error) {
 	c := &Commands{
 		com:            com,
 		selected:       SystemCommands,
 		sessionID:      sessionID,
 		customCommands: customCommands,
 		mcpPrompts:     mcpPrompts,
+		canFork:        canFork,
 	}
 
 	help := help.New()
@@ -386,9 +388,18 @@ func (c *Commands) setCommandItems(commandType CommandType) {
 func (c *Commands) defaultCommands() []*CommandItem {
 	commands := []*CommandItem{
 		NewCommandItem(c.com.Styles, "new_session", "New Session", "ctrl+n", ActionNewSession{}),
+	}
+
+	if c.canFork {
+		commands = append(commands,
+			NewCommandItem(c.com.Styles, "fork_conversation", "Fork Conversation", "", ActionForkConversation{}),
+		)
+	}
+
+	commands = append(commands,
 		NewCommandItem(c.com.Styles, "switch_session", "Sessions", "ctrl+s", ActionOpenDialog{SessionsID}),
 		NewCommandItem(c.com.Styles, "switch_model", "Switch Model", "ctrl+l", ActionOpenDialog{ModelsID}),
-	}
+	)
 
 	// Only show compact command if there's an active session
 	if c.sessionID != "" {

internal/ui/model/chat.go 🔗

@@ -246,6 +246,11 @@ func (m *Chat) SelectedItemInView() bool {
 	return m.list.SelectedItemInView()
 }
 
+// SelectedItem returns the currently selected item in the chat list.
+func (m *Chat) SelectedItem() list.Item {
+	return m.list.SelectedItem()
+}
+
 func (m *Chat) isSelectable(index int) bool {
 	item := m.list.ItemAt(index)
 	if item == nil {

internal/ui/model/ui.go 🔗

@@ -41,6 +41,7 @@ import (
 	"github.com/charmbracelet/crush/internal/ui/common"
 	"github.com/charmbracelet/crush/internal/ui/completions"
 	"github.com/charmbracelet/crush/internal/ui/dialog"
+	"github.com/charmbracelet/crush/internal/ui/list"
 	"github.com/charmbracelet/crush/internal/ui/logo"
 	"github.com/charmbracelet/crush/internal/ui/styles"
 	"github.com/charmbracelet/crush/internal/uiutil"
@@ -1158,6 +1159,41 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 		cmds = append(cmds, m.initializeProject())
 		m.dialog.CloseDialog(dialog.CommandsID)
 
+	case dialog.ActionForkConversation:
+		if m.isAgentBusy() {
+			cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
+			break
+		}
+		if m.session == nil {
+			cmds = append(cmds, uiutil.ReportWarn("No session to fork..."))
+			break
+		}
+		selectedItem := m.chat.SelectedItem()
+		if selectedItem == nil {
+			cmds = append(cmds, uiutil.ReportWarn("No message selected..."))
+			break
+		}
+		if _, ok := selectedItem.(*chat.UserMessageItem); !ok {
+			cmds = append(cmds, uiutil.ReportWarn("Can only fork from user messages..."))
+			break
+		}
+		messageID, ok := m.getMessageIDFromItem(selectedItem)
+		if !ok {
+			cmds = append(cmds, uiutil.ReportWarn("Cannot fork from selected item..."))
+			break
+		}
+		cmds = append(cmds, func() tea.Msg {
+			newSession, err := m.com.App.Sessions.Fork(context.Background(), m.session.ID, messageID, m.com.App.Messages)
+			if err != nil {
+				return uiutil.ReportError(err)()
+			}
+			return loadSessionMsg{
+				session: &newSession,
+				files:   []SessionFile{},
+			}
+		})
+		m.dialog.CloseDialog(dialog.CommandsID)
+
 	case dialog.ActionSelectModel:
 		if m.isAgentBusy() {
 			cmds = append(cmds, uiutil.ReportWarn("Agent is busy, please wait..."))
@@ -2489,6 +2525,22 @@ func (m *UI) hasSession() bool {
 	return m.session != nil && m.session.ID != ""
 }
 
+// getMessageIDFromItem gets the message ID from a selected chat item.
+func (m *UI) getMessageIDFromItem(item list.Item) (string, bool) {
+	if toolMsg, ok := item.(chat.ToolMessageItem); ok {
+		return toolMsg.MessageID(), true
+	}
+	if msgItem, ok := item.(chat.MessageItem); ok {
+		itemID := msgItem.ID()
+		if strings.HasSuffix(itemID, ":assistant-info") {
+			baseID := strings.TrimSuffix(itemID, ":assistant-info")
+			return baseID, true
+		}
+		return itemID, true
+	}
+	return "", false
+}
+
 // mimeOf detects the MIME type of the given content.
 func mimeOf(content []byte) string {
 	mimeBufferSize := min(512, len(content))
@@ -2699,10 +2751,17 @@ func (m *UI) openModelsDialog() tea.Cmd {
 
 // openCommandsDialog opens the commands dialog.
 func (m *UI) openCommandsDialog() tea.Cmd {
+	canFork := false
+	selectedItem := m.chat.SelectedItem()
+	if selectedItem != nil {
+		if _, ok := selectedItem.(*chat.UserMessageItem); ok {
+			canFork = true
+		}
+	}
+
 	if m.dialog.ContainsDialog(dialog.CommandsID) {
-		// Bring to front
-		m.dialog.BringToFront(dialog.CommandsID)
-		return nil
+		// Close and reopen to refresh state
+		m.dialog.CloseDialog(dialog.CommandsID)
 	}
 
 	sessionID := ""
@@ -2710,7 +2769,7 @@ func (m *UI) openCommandsDialog() tea.Cmd {
 		sessionID = m.session.ID
 	}
 
-	commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts)
+	commands, err := dialog.NewCommands(m.com, sessionID, m.customCommands, m.mcpPrompts, canFork)
 	if err != nil {
 		return uiutil.ReportError(err)
 	}