1package proto
2
3import (
4 "encoding/json"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/stretchr/testify/require"
9)
10
11func TestModelsRoundTrip(t *testing.T) {
12 t.Parallel()
13
14 t.Run("nil models", func(t *testing.T) {
15 t.Parallel()
16 s := Session{
17 ID: "test-id",
18 Title: "Test",
19 }
20 data, err := json.Marshal(s)
21 require.NoError(t, err)
22 var decoded Session
23 require.NoError(t, json.Unmarshal(data, &decoded))
24 require.Nil(t, decoded.Models)
25 })
26
27 t.Run("populated models", func(t *testing.T) {
28 t.Parallel()
29 temp := 0.7
30 s := Session{
31 ID: "test-id",
32 Title: "Test",
33 Models: map[SelectedModelType]SelectedModel{
34 SelectedModelTypeLarge: {
35 Model: "gpt-4o",
36 Provider: "openai",
37 ReasoningEffort: "high",
38 Think: true,
39 MaxTokens: 4096,
40 Temperature: &temp,
41 ProviderOptions: map[string]any{"key": "value"},
42 },
43 SelectedModelTypeSmall: {
44 Model: "gpt-4o-mini",
45 Provider: "openai",
46 },
47 },
48 }
49 data, err := json.Marshal(s)
50 require.NoError(t, err)
51 var decoded Session
52 require.NoError(t, json.Unmarshal(data, &decoded))
53 require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model)
54 require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider)
55 require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort)
56 require.True(t, decoded.Models[SelectedModelTypeLarge].Think)
57 require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens)
58 require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature)
59 require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature)
60 require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model)
61 })
62
63 t.Run("empty map models", func(t *testing.T) {
64 t.Parallel()
65 s := Session{
66 ID: "test-id",
67 Title: "Test",
68 Models: map[SelectedModelType]SelectedModel{},
69 }
70 data, err := json.Marshal(s)
71 require.NoError(t, err)
72 var decoded Session
73 require.NoError(t, json.Unmarshal(data, &decoded))
74 // Empty map with omitempty is dropped during marshaling.
75 require.Nil(t, decoded.Models)
76 })
77}
78
79func TestProtoToDomainRoundTrip(t *testing.T) {
80 t.Parallel()
81
82 t.Run("models through proto", func(t *testing.T) {
83 t.Parallel()
84 temp := 0.7
85 domainModels := map[config.SelectedModelType]config.SelectedModel{
86 config.SelectedModelTypeLarge: {
87 Model: "gpt-4o",
88 Provider: "openai",
89 ReasoningEffort: "high",
90 Think: true,
91 MaxTokens: 4096,
92 Temperature: &temp,
93 ProviderOptions: map[string]any{"key": "value"},
94 },
95 }
96
97 // Domain → Proto
98 protoModels := convertModelsToProtoLocal(domainModels)
99 require.Equal(t, SelectedModelTypeLarge, SelectedModelType(config.SelectedModelTypeLarge))
100 require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model)
101 require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider)
102
103 // Proto → Domain
104 result := convertModelsFromProtoLocal(protoModels)
105 require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
106 require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
107 require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort)
108 require.True(t, result[config.SelectedModelTypeLarge].Think)
109 require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens)
110 require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature)
111 require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature)
112 })
113
114 t.Run("nil models round-trip", func(t *testing.T) {
115 t.Parallel()
116 protoModels := convertModelsToProtoLocal(nil)
117 require.Nil(t, protoModels)
118
119 domainModels := convertModelsFromProtoLocal(nil)
120 require.Nil(t, domainModels)
121 })
122}
123
124func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel {
125 if models == nil {
126 return nil
127 }
128 result := make(map[SelectedModelType]SelectedModel, len(models))
129 for k, v := range models {
130 result[SelectedModelType(k)] = SelectedModel{
131 Model: v.Model,
132 Provider: v.Provider,
133 ReasoningEffort: v.ReasoningEffort,
134 Think: v.Think,
135 MaxTokens: v.MaxTokens,
136 Temperature: v.Temperature,
137 TopP: v.TopP,
138 TopK: v.TopK,
139 FrequencyPenalty: v.FrequencyPenalty,
140 PresencePenalty: v.PresencePenalty,
141 ProviderOptions: v.ProviderOptions,
142 }
143 }
144 return result
145}
146
147func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel {
148 if models == nil {
149 return nil
150 }
151 result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
152 for k, v := range models {
153 result[config.SelectedModelType(k)] = config.SelectedModel{
154 Model: v.Model,
155 Provider: v.Provider,
156 ReasoningEffort: v.ReasoningEffort,
157 Think: v.Think,
158 MaxTokens: v.MaxTokens,
159 Temperature: v.Temperature,
160 TopP: v.TopP,
161 TopK: v.TopK,
162 FrequencyPenalty: v.FrequencyPenalty,
163 PresencePenalty: v.PresencePenalty,
164 ProviderOptions: v.ProviderOptions,
165 }
166 }
167 return result
168}