1package session
2
3import (
4 "context"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/message"
9 "github.com/stretchr/testify/require"
10)
11
12func TestFork(t *testing.T) {
13 t.Parallel()
14
15 ctx := context.Background()
16
17 conn, err := db.Connect(t.Context(), t.TempDir())
18 require.NoError(t, err)
19 defer conn.Close()
20
21 q := db.New(conn)
22 svc := NewService(q, conn)
23 msgSvc := message.NewService(q)
24
25 sourceSession, err := svc.Create(ctx, "Source Session")
26 require.NoError(t, err)
27
28 for range 5 {
29 _, err = msgSvc.Create(ctx, sourceSession.ID, message.CreateMessageParams{
30 Role: message.User,
31 Parts: []message.ContentPart{
32 message.TextContent{Text: "Test message"},
33 },
34 })
35 require.NoError(t, err)
36 }
37
38 getMessages := func(sessionID string) []message.Message {
39 msgs, err := msgSvc.List(ctx, sessionID)
40 require.NoError(t, err)
41 return msgs
42 }
43
44 sourceMessages := getMessages(sourceSession.ID)
45 require.Len(t, sourceMessages, 5)
46
47 targetMessageID := sourceMessages[2].ID
48 newSession, err := svc.Fork(ctx, sourceSession.ID, targetMessageID, msgSvc)
49 require.NoError(t, err)
50 require.NotEmpty(t, newSession.ID)
51 require.NotEqual(t, sourceSession.ID, newSession.ID)
52 require.Contains(t, newSession.Title, "Forked:")
53 require.Contains(t, newSession.Title, sourceSession.Title)
54
55 forkedMessages := getMessages(newSession.ID)
56 require.Len(t, forkedMessages, 2)
57
58 for i, msg := range forkedMessages {
59 require.Equal(t, sourceMessages[i].Role, msg.Role)
60 require.Equal(t, sourceMessages[i].Parts[0], msg.Parts[0])
61 }
62}
63
64func TestForkInvalidMessageID(t *testing.T) {
65 t.Parallel()
66
67 ctx := context.Background()
68
69 conn, err := db.Connect(t.Context(), t.TempDir())
70 require.NoError(t, err)
71 defer conn.Close()
72
73 q := db.New(conn)
74 svc := NewService(q, conn)
75 msgSvc := message.NewService(q)
76
77 sourceSession, err := svc.Create(ctx, "Source Session")
78 require.NoError(t, err)
79
80 _, err = svc.Fork(ctx, sourceSession.ID, "invalid-id", msgSvc)
81 require.Error(t, err)
82 require.Contains(t, err.Error(), "message not found")
83}