1package session
2
3import (
4 "database/sql"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/db"
9 "github.com/stretchr/testify/require"
10)
11
12func TestMarshalModels(t *testing.T) {
13 t.Parallel()
14
15 t.Run("empty", func(t *testing.T) {
16 t.Parallel()
17 result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{})
18 require.NoError(t, err)
19 require.Equal(t, "", result)
20 })
21
22 t.Run("nil", func(t *testing.T) {
23 t.Parallel()
24 result, err := marshalModels(nil)
25 require.NoError(t, err)
26 require.Equal(t, "", result)
27 })
28
29 t.Run("single entry", func(t *testing.T) {
30 t.Parallel()
31 models := map[config.SelectedModelType]config.SelectedModel{
32 config.SelectedModelTypeLarge: {
33 Model: "claude-sonnet-4-20250514",
34 Provider: "anthropic",
35 },
36 }
37 result, err := marshalModels(models)
38 require.NoError(t, err)
39 require.Contains(t, result, "claude-sonnet-4-20250514")
40 require.Contains(t, result, "anthropic")
41 })
42
43 t.Run("round-trip", func(t *testing.T) {
44 t.Parallel()
45 temp := 0.7
46 topP := 0.9
47 topK := int64(50)
48 freqPen := 0.1
49 presPen := 0.2
50 models := map[config.SelectedModelType]config.SelectedModel{
51 config.SelectedModelTypeLarge: {
52 Model: "gpt-4o",
53 Provider: "openai",
54 ReasoningEffort: "high",
55 Think: true,
56 MaxTokens: 4096,
57 Temperature: &temp,
58 TopP: &topP,
59 TopK: &topK,
60 FrequencyPenalty: &freqPen,
61 PresencePenalty: &presPen,
62 ProviderOptions: map[string]any{"key": "value"},
63 },
64 config.SelectedModelTypeSmall: {
65 Model: "gpt-4o-mini",
66 Provider: "openai",
67 },
68 }
69 data, err := marshalModels(models)
70 require.NoError(t, err)
71 result, err := unmarshalModels(data)
72 require.NoError(t, err)
73 require.Equal(t, models, result)
74 })
75}
76
77func TestUnmarshalModels(t *testing.T) {
78 t.Parallel()
79
80 t.Run("empty string", func(t *testing.T) {
81 t.Parallel()
82 result, err := unmarshalModels("")
83 require.NoError(t, err)
84 require.Nil(t, result)
85 })
86
87 t.Run("valid JSON", func(t *testing.T) {
88 t.Parallel()
89 data := `{"large":{"model":"gpt-4o","provider":"openai"}}`
90 result, err := unmarshalModels(data)
91 require.NoError(t, err)
92 require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
93 require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
94 })
95
96 t.Run("invalid JSON", func(t *testing.T) {
97 t.Parallel()
98 _, err := unmarshalModels("{invalid}")
99 require.Error(t, err)
100 })
101}
102
103func TestFromDBItemWithModels(t *testing.T) {
104 t.Parallel()
105
106 t.Run("null models", func(t *testing.T) {
107 t.Parallel()
108 item := testDBSession()
109 item.Models = sql.NullString{Valid: false}
110 result := service{}.fromDBItem(item)
111 require.Nil(t, result.Models)
112 })
113
114 t.Run("empty models", func(t *testing.T) {
115 t.Parallel()
116 item := testDBSession()
117 item.Models = sql.NullString{String: "", Valid: true}
118 result := service{}.fromDBItem(item)
119 require.Nil(t, result.Models)
120 })
121
122 t.Run("valid models", func(t *testing.T) {
123 t.Parallel()
124 item := testDBSession()
125 item.Models = sql.NullString{
126 String: `{"large":{"model":"gpt-4o","provider":"openai"}}`,
127 Valid: true,
128 }
129 result := service{}.fromDBItem(item)
130 require.NotNil(t, result.Models)
131 require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model)
132 })
133
134 t.Run("invalid JSON models", func(t *testing.T) {
135 t.Parallel()
136 item := testDBSession()
137 item.Models = sql.NullString{
138 String: "{invalid}",
139 Valid: true,
140 }
141 result := service{}.fromDBItem(item)
142 require.Nil(t, result.Models)
143 })
144}
145
146func testDBSession() db.Session {
147 return db.Session{
148 ID: "test-id",
149 Title: "Test",
150 }
151}