session_test.go

 1package session
 2
 3import (
 4	"testing"
 5
 6	"github.com/charmbracelet/crush/internal/db"
 7	"github.com/stretchr/testify/require"
 8)
 9
10func TestEstimatedUsageStateSurvivesFetchModifySave(t *testing.T) {
11	dataDir := t.TempDir()
12	t.Cleanup(func() {
13		require.NoError(t, db.Release(dataDir))
14		db.ResetPool()
15	})
16
17	conn, err := db.Connect(t.Context(), dataDir)
18	require.NoError(t, err)
19
20	sessions := NewService(db.New(conn), conn)
21
22	created, err := sessions.Create(t.Context(), "test")
23	require.NoError(t, err)
24	created.PromptTokens = 100
25	created.CompletionTokens = 50
26	created.EstimatedUsage = true
27
28	saved, err := sessions.Save(t.Context(), created)
29	require.NoError(t, err)
30	require.True(t, saved.EstimatedUsage)
31
32	fetched, err := sessions.Get(t.Context(), created.ID)
33	require.NoError(t, err)
34	require.True(t, fetched.EstimatedUsage)
35
36	fetched.Todos = []Todo{{
37		Content:    "Check estimate state",
38		Status:     TodoStatusInProgress,
39		ActiveForm: "Checking estimate state",
40	}}
41
42	updated, err := sessions.Save(t.Context(), fetched)
43	require.NoError(t, err)
44	require.True(t, updated.EstimatedUsage)
45
46	refetched, err := sessions.Get(t.Context(), created.ID)
47	require.NoError(t, err)
48	require.True(t, refetched.EstimatedUsage)
49}
50
51func TestEstimatedUsageStateCanBeClearedByExplicitSave(t *testing.T) {
52	dataDir := t.TempDir()
53	t.Cleanup(func() {
54		require.NoError(t, db.Release(dataDir))
55		db.ResetPool()
56	})
57
58	conn, err := db.Connect(t.Context(), dataDir)
59	require.NoError(t, err)
60
61	sessions := NewService(db.New(conn), conn)
62
63	created, err := sessions.Create(t.Context(), "test")
64	require.NoError(t, err)
65	created.PromptTokens = 100
66	created.CompletionTokens = 50
67	created.EstimatedUsage = true
68
69	saved, err := sessions.Save(t.Context(), created)
70	require.NoError(t, err)
71	require.True(t, saved.EstimatedUsage)
72
73	saved.EstimatedUsage = false
74	updated, err := sessions.Save(t.Context(), saved)
75	require.NoError(t, err)
76	require.False(t, updated.EstimatedUsage)
77
78	refetched, err := sessions.Get(t.Context(), created.ID)
79	require.NoError(t, err)
80	require.False(t, refetched.EstimatedUsage)
81}