session_fork_test.go

 1package session
 2
 3import (
 4	"context"
 5	"testing"
 6
 7	"github.com/charmbracelet/crush/internal/db"
 8	"github.com/charmbracelet/crush/internal/message"
 9	"github.com/stretchr/testify/require"
10)
11
12func TestFork(t *testing.T) {
13	t.Parallel()
14
15	ctx := context.Background()
16
17	conn, err := db.Connect(t.Context(), t.TempDir())
18	require.NoError(t, err)
19	defer conn.Close()
20
21	q := db.New(conn)
22	svc := NewService(q, conn)
23	msgSvc := message.NewService(q)
24
25	sourceSession, err := svc.Create(ctx, "Source Session")
26	require.NoError(t, err)
27
28	for range 5 {
29		_, err = msgSvc.Create(ctx, sourceSession.ID, message.CreateMessageParams{
30			Role: message.User,
31			Parts: []message.ContentPart{
32				message.TextContent{Text: "Test message"},
33			},
34		})
35		require.NoError(t, err)
36	}
37
38	getMessages := func(sessionID string) []message.Message {
39		msgs, err := msgSvc.List(ctx, sessionID)
40		require.NoError(t, err)
41		return msgs
42	}
43
44	sourceMessages := getMessages(sourceSession.ID)
45	require.Len(t, sourceMessages, 5)
46
47	targetMessageID := sourceMessages[2].ID
48	newSession, err := svc.Fork(ctx, sourceSession.ID, targetMessageID, msgSvc)
49	require.NoError(t, err)
50	require.NotEmpty(t, newSession.ID)
51	require.NotEqual(t, sourceSession.ID, newSession.ID)
52	require.Contains(t, newSession.Title, "Forked:")
53	require.Contains(t, newSession.Title, sourceSession.Title)
54
55	forkedMessages := getMessages(newSession.ID)
56	require.Len(t, forkedMessages, 2)
57
58	for i, msg := range forkedMessages {
59		require.Equal(t, sourceMessages[i].Role, msg.Role)
60		require.Equal(t, sourceMessages[i].Parts[0], msg.Parts[0])
61	}
62}
63
64func TestForkInvalidMessageID(t *testing.T) {
65	t.Parallel()
66
67	ctx := context.Background()
68
69	conn, err := db.Connect(t.Context(), t.TempDir())
70	require.NoError(t, err)
71	defer conn.Close()
72
73	q := db.New(conn)
74	svc := NewService(q, conn)
75	msgSvc := message.NewService(q)
76
77	sourceSession, err := svc.Create(ctx, "Source Session")
78	require.NoError(t, err)
79
80	_, err = svc.Fork(ctx, sourceSession.ID, "invalid-id", msgSvc)
81	require.Error(t, err)
82	require.Contains(t, err.Error(), "message not found")
83}