session_test.go

package proto

import (
	"encoding/json"
	"testing"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/stretchr/testify/require"
)

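// TestModelsRoundTrip verifies that Session.Models survives a JSON
// marshal/unmarshal cycle for nil, populated, and empty-map values.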
func TestModelsRoundTrip(t *testing.T) {
	t.Parallel()

	t.Run("nil models", func(t *testing.T) {
		t.Parallel()
		s := Session{
			ID:    "test-id",
			Title: "Test",
		}
		data, err := json.Marshal(s)
		require.NoError(t, err)
		var decoded Session
		require.NoError(t, json.Unmarshal(data, &decoded))
		require.Nil(t, decoded.Models)
	})

	t.Run("populated models", func(t *testing.T) {
		t.Parallel()
		temp := 0.7
		s := Session{
			ID:    "test-id",
			Title: "Test",
			Models: map[SelectedModelType]SelectedModel{
				SelectedModelTypeLarge: {
					Model:           "gpt-4o",
					Provider:        "openai",
					ReasoningEffort: "high",
					Think:           true,
					MaxTokens:       4096,
					Temperature:     &temp,
					ProviderOptions: map[string]any{"key": "value"},
				},
				SelectedModelTypeSmall: {
					Model:    "gpt-4o-mini",
					Provider: "openai",
				},
			},
		}
		data, err := json.Marshal(s)
		require.NoError(t, err)
		var decoded Session
		require.NoError(t, json.Unmarshal(data, &decoded))
		require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model)
		require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider)
		require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort)
		require.True(t, decoded.Models[SelectedModelTypeLarge].Think)
		require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens)
		require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature)
		require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature)
		require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model)
	})

	t.Run("empty map models", func(t *testing.T) {
		t.Parallel()
		s := Session{
			ID:     "test-id",
			Title:  "Test",
			Models: map[SelectedModelType]SelectedModel{},
		}
		data, err := json.Marshal(s)
		require.NoError(t, err)
		var decoded Session
		require.NoError(t, json.Unmarshal(data, &decoded))
		// Empty map with omitempty is dropped during marshaling, so it
		// decodes back as nil rather than an empty map.
		require.Nil(t, decoded.Models)
	})
}

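// TestProtoToDomainRoundTrip verifies that selected-model maps convert
// correctly between the config (domain) and proto representations.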
func TestProtoToDomainRoundTrip(t *testing.T) {
	t.Parallel()

	t.Run("models through proto", func(t *testing.T) {
		t.Parallel()
		temp := 0.7
		domainModels := map[config.SelectedModelType]config.SelectedModel{
			config.SelectedModelTypeLarge: {
				Model:           "gpt-4o",
				Provider:        "openai",
				ReasoningEffort: "high",
				Think:           true,
				MaxTokens:       4096,
				Temperature:     &temp,
				ProviderOptions: map[string]any{"key": "value"},
			},
		}

		// Domain → Proto
		protoModels := convertModelsToProtoLocal(domainModels)
		require.Equal(t, SelectedModelTypeLarge, SelectedModelType(config.SelectedModelTypeLarge))
		require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model)
		require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider)

		// Proto → Domain
		result := convertModelsFromProtoLocal(protoModels)
		require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
		require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
		require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort)
		require.True(t, result[config.SelectedModelTypeLarge].Think)
		require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens)
		require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature)
		require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature)
	})

	t.Run("nil models round-trip", func(t *testing.T) {
		t.Parallel()
		protoModels := convertModelsToProtoLocal(nil)
		require.Nil(t, protoModels)

		domainModels := convertModelsFromProtoLocal(nil)
		require.Nil(t, domainModels)
	})
}

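// convertModelsToProtoLocal is a test-local helper that converts domain
// selected models to their proto equivalents, preserving nil input and
// copying each field across unchanged.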
func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel {
	if models == nil {
		return nil
	}
	result := make(map[SelectedModelType]SelectedModel, len(models))
	for k, v := range models {
		result[SelectedModelType(k)] = SelectedModel{
			Model:            v.Model,
			Provider:         v.Provider,
			ReasoningEffort:  v.ReasoningEffort,
			Think:            v.Think,
			MaxTokens:        v.MaxTokens,
			Temperature:      v.Temperature,
			TopP:             v.TopP,
			TopK:             v.TopK,
			FrequencyPenalty: v.FrequencyPenalty,
			PresencePenalty:  v.PresencePenalty,
			ProviderOptions:  v.ProviderOptions,
		}
	}
	return result
}

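// convertModelsFromProtoLocal is the inverse test-local helper: it converts
// proto selected models back to the domain representation, preserving nil
// input and copying each field across unchanged.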
func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel {
	if models == nil {
		return nil
	}
	result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
	for k, v := range models {
		result[config.SelectedModelType(k)] = config.SelectedModel{
			Model:            v.Model,
			Provider:         v.Provider,
			ReasoningEffort:  v.ReasoningEffort,
			Think:            v.Think,
			MaxTokens:        v.MaxTokens,
			Temperature:      v.Temperature,
			TopP:             v.TopP,
			TopK:             v.TopK,
			FrequencyPenalty: v.FrequencyPenalty,
			PresencePenalty:  v.PresencePenalty,
			ProviderOptions:  v.ProviderOptions,
		}
	}
	return result
}