session_test.go

  1package session
  2
  3import (
  4	"database/sql"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/db"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12func TestMarshalModels(t *testing.T) {
 13	t.Parallel()
 14
 15	t.Run("empty", func(t *testing.T) {
 16		t.Parallel()
 17		result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{})
 18		require.NoError(t, err)
 19		require.Equal(t, "", result)
 20	})
 21
 22	t.Run("nil", func(t *testing.T) {
 23		t.Parallel()
 24		result, err := marshalModels(nil)
 25		require.NoError(t, err)
 26		require.Equal(t, "", result)
 27	})
 28
 29	t.Run("single entry", func(t *testing.T) {
 30		t.Parallel()
 31		models := map[config.SelectedModelType]config.SelectedModel{
 32			config.SelectedModelTypeLarge: {
 33				Model:    "claude-sonnet-4-20250514",
 34				Provider: "anthropic",
 35			},
 36		}
 37		result, err := marshalModels(models)
 38		require.NoError(t, err)
 39		require.Contains(t, result, "claude-sonnet-4-20250514")
 40		require.Contains(t, result, "anthropic")
 41	})
 42
 43	t.Run("round-trip", func(t *testing.T) {
 44		t.Parallel()
 45		temp := 0.7
 46		topP := 0.9
 47		topK := int64(50)
 48		freqPen := 0.1
 49		presPen := 0.2
 50		models := map[config.SelectedModelType]config.SelectedModel{
 51			config.SelectedModelTypeLarge: {
 52				Model:            "gpt-4o",
 53				Provider:         "openai",
 54				ReasoningEffort:  "high",
 55				Think:            true,
 56				MaxTokens:        4096,
 57				Temperature:      &temp,
 58				TopP:             &topP,
 59				TopK:             &topK,
 60				FrequencyPenalty: &freqPen,
 61				PresencePenalty:  &presPen,
 62				ProviderOptions:  map[string]any{"key": "value"},
 63			},
 64			config.SelectedModelTypeSmall: {
 65				Model:    "gpt-4o-mini",
 66				Provider: "openai",
 67			},
 68		}
 69		data, err := marshalModels(models)
 70		require.NoError(t, err)
 71		result, err := unmarshalModels(data)
 72		require.NoError(t, err)
 73		require.Equal(t, models, result)
 74	})
 75}
 76
 77func TestUnmarshalModels(t *testing.T) {
 78	t.Parallel()
 79
 80	t.Run("empty string", func(t *testing.T) {
 81		t.Parallel()
 82		result, err := unmarshalModels("")
 83		require.NoError(t, err)
 84		require.Nil(t, result)
 85	})
 86
 87	t.Run("valid JSON", func(t *testing.T) {
 88		t.Parallel()
 89		data := `{"large":{"model":"gpt-4o","provider":"openai"}}`
 90		result, err := unmarshalModels(data)
 91		require.NoError(t, err)
 92		require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
 93		require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
 94	})
 95
 96	t.Run("invalid JSON", func(t *testing.T) {
 97		t.Parallel()
 98		_, err := unmarshalModels("{invalid}")
 99		require.Error(t, err)
100	})
101}
102
103func TestFromDBItemWithModels(t *testing.T) {
104	t.Parallel()
105
106	t.Run("null models", func(t *testing.T) {
107		t.Parallel()
108		item := testDBSession()
109		item.Models = sql.NullString{Valid: false}
110		result := service{}.fromDBItem(item)
111		require.Nil(t, result.Models)
112	})
113
114	t.Run("empty models", func(t *testing.T) {
115		t.Parallel()
116		item := testDBSession()
117		item.Models = sql.NullString{String: "", Valid: true}
118		result := service{}.fromDBItem(item)
119		require.Nil(t, result.Models)
120	})
121
122	t.Run("valid models", func(t *testing.T) {
123		t.Parallel()
124		item := testDBSession()
125		item.Models = sql.NullString{
126			String: `{"large":{"model":"gpt-4o","provider":"openai"}}`,
127			Valid:  true,
128		}
129		result := service{}.fromDBItem(item)
130		require.NotNil(t, result.Models)
131		require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model)
132	})
133
134	t.Run("invalid JSON models", func(t *testing.T) {
135		t.Parallel()
136		item := testDBSession()
137		item.Models = sql.NullString{
138			String: "{invalid}",
139			Valid:  true,
140		}
141		result := service{}.fromDBItem(item)
142		require.Nil(t, result.Models)
143	})
144}
145
146func testDBSession() db.Session {
147	return db.Session{
148		ID:    "test-id",
149		Title: "Test",
150	}
151}