1package session
2
3import (
4 "testing"
5
6 "github.com/charmbracelet/crush/internal/db"
7 "github.com/stretchr/testify/require"
8)
9
10func TestEstimatedUsageStateSurvivesFetchModifySave(t *testing.T) {
11 dataDir := t.TempDir()
12 t.Cleanup(func() {
13 require.NoError(t, db.Release(dataDir))
14 db.ResetPool()
15 })
16
17 conn, err := db.Connect(t.Context(), dataDir)
18 require.NoError(t, err)
19
20 sessions := NewService(db.New(conn), conn)
21
22 created, err := sessions.Create(t.Context(), "test")
23 require.NoError(t, err)
24 created.PromptTokens = 100
25 created.CompletionTokens = 50
26 created.EstimatedUsage = true
27
28 saved, err := sessions.Save(t.Context(), created)
29 require.NoError(t, err)
30 require.True(t, saved.EstimatedUsage)
31
32 fetched, err := sessions.Get(t.Context(), created.ID)
33 require.NoError(t, err)
34 require.True(t, fetched.EstimatedUsage)
35
36 fetched.Todos = []Todo{{
37 Content: "Check estimate state",
38 Status: TodoStatusInProgress,
39 ActiveForm: "Checking estimate state",
40 }}
41
42 updated, err := sessions.Save(t.Context(), fetched)
43 require.NoError(t, err)
44 require.True(t, updated.EstimatedUsage)
45
46 refetched, err := sessions.Get(t.Context(), created.ID)
47 require.NoError(t, err)
48 require.True(t, refetched.EstimatedUsage)
49}
50
51func TestEstimatedUsageStateCanBeClearedByExplicitSave(t *testing.T) {
52 dataDir := t.TempDir()
53 t.Cleanup(func() {
54 require.NoError(t, db.Release(dataDir))
55 db.ResetPool()
56 })
57
58 conn, err := db.Connect(t.Context(), dataDir)
59 require.NoError(t, err)
60
61 sessions := NewService(db.New(conn), conn)
62
63 created, err := sessions.Create(t.Context(), "test")
64 require.NoError(t, err)
65 created.PromptTokens = 100
66 created.CompletionTokens = 50
67 created.EstimatedUsage = true
68
69 saved, err := sessions.Save(t.Context(), created)
70 require.NoError(t, err)
71 require.True(t, saved.EstimatedUsage)
72
73 saved.EstimatedUsage = false
74 updated, err := sessions.Save(t.Context(), saved)
75 require.NoError(t, err)
76 require.False(t, updated.EstimatedUsage)
77
78 refetched, err := sessions.Get(t.Context(), created.ID)
79 require.NoError(t, err)
80 require.False(t, refetched.EstimatedUsage)
81}