provider_test.go

  1package app
  2
  3import (
  4	"testing"
  5
  6	"github.com/charmbracelet/catwalk/pkg/catwalk"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/stretchr/testify/require"
  9)
 10
 11func TestParseModelStr(t *testing.T) {
 12	tests := []struct {
 13		name            string
 14		modelStr        string
 15		expectedFilter  string
 16		expectedModelID string
 17		setupProviders  func() map[string]config.ProviderConfig
 18	}{
 19		{
 20			name:            "simple model with no slashes",
 21			modelStr:        "gpt-4o",
 22			expectedFilter:  "",
 23			expectedModelID: "gpt-4o",
 24			setupProviders:  setupMockProviders,
 25		},
 26		{
 27			name:            "valid provider and model",
 28			modelStr:        "openai/gpt-4o",
 29			expectedFilter:  "openai",
 30			expectedModelID: "gpt-4o",
 31			setupProviders:  setupMockProviders,
 32		},
 33		{
 34			name:            "model with multiple slashes and first part is invalid provider",
 35			modelStr:        "moonshot/kimi-k2",
 36			expectedFilter:  "",
 37			expectedModelID: "moonshot/kimi-k2",
 38			setupProviders:  setupMockProviders,
 39		},
 40		{
 41			name:            "full path with valid provider and model with slashes",
 42			modelStr:        "synthetic/moonshot/kimi-k2",
 43			expectedFilter:  "synthetic",
 44			expectedModelID: "moonshot/kimi-k2",
 45			setupProviders:  setupMockProvidersWithSlashes,
 46		},
 47		{
 48			name:            "empty model string",
 49			modelStr:        "",
 50			expectedFilter:  "",
 51			expectedModelID: "",
 52			setupProviders:  setupMockProviders,
 53		},
 54		{
 55			name:            "model with trailing slash but valid provider",
 56			modelStr:        "openai/",
 57			expectedFilter:  "openai",
 58			expectedModelID: "",
 59			setupProviders:  setupMockProviders,
 60		},
 61	}
 62
 63	for _, tt := range tests {
 64		t.Run(tt.name, func(t *testing.T) {
 65			providers := tt.setupProviders()
 66			filter, modelID := parseModelStr(providers, tt.modelStr)
 67
 68			require.Equal(t, tt.expectedFilter, filter, "provider filter mismatch")
 69			require.Equal(t, tt.expectedModelID, modelID, "model ID mismatch")
 70		})
 71	}
 72}
 73
 74func setupMockProviders() map[string]config.ProviderConfig {
 75	return map[string]config.ProviderConfig{
 76		"openai": {
 77			ID:     "openai",
 78			Name:   "OpenAI",
 79			Models: []catwalk.Model{{ID: "gpt-4o"}, {ID: "gpt-4o-mini"}},
 80		},
 81		"anthropic": {
 82			ID:     "anthropic",
 83			Name:   "Anthropic",
 84			Models: []catwalk.Model{{ID: "claude-3-sonnet"}, {ID: "claude-3-opus"}},
 85		},
 86	}
 87}
 88
 89func setupMockProvidersWithSlashes() map[string]config.ProviderConfig {
 90	return map[string]config.ProviderConfig{
 91		"synthetic": {
 92			ID:   "synthetic",
 93			Name: "Synthetic",
 94			Models: []catwalk.Model{
 95				{ID: "moonshot/kimi-k2"},
 96				{ID: "deepseek/deepseek-chat"},
 97			},
 98		},
 99		"openai": {
100			ID:     "openai",
101			Name:   "OpenAI",
102			Models: []catwalk.Model{{ID: "gpt-4o"}},
103		},
104	}
105}
106
107func TestFindModels(t *testing.T) {
108	tests := []struct {
109		name             string
110		modelStr         string
111		expectedProvider string
112		expectedModelID  string
113		expectError      bool
114		errorContains    string
115		setupProviders   func() map[string]config.ProviderConfig
116	}{
117		{
118			name:             "simple model found in one provider",
119			modelStr:         "gpt-4o",
120			expectedProvider: "openai",
121			expectedModelID:  "gpt-4o",
122			expectError:      false,
123			setupProviders:   setupMockProviders,
124		},
125		{
126			name:             "model with slashes in ID",
127			modelStr:         "moonshot/kimi-k2",
128			expectedProvider: "synthetic",
129			expectedModelID:  "moonshot/kimi-k2",
130			expectError:      false,
131			setupProviders:   setupMockProvidersWithSlashes,
132		},
133		{
134			name:             "provider and model with slashes in ID",
135			modelStr:         "synthetic/moonshot/kimi-k2",
136			expectedProvider: "synthetic",
137			expectedModelID:  "moonshot/kimi-k2",
138			expectError:      false,
139			setupProviders:   setupMockProvidersWithSlashes,
140		},
141		{
142			name:           "model not found",
143			modelStr:       "nonexistent-model",
144			expectError:    true,
145			errorContains:  "not found",
146			setupProviders: setupMockProviders,
147		},
148		{
149			name:           "invalid provider specified",
150			modelStr:       "nonexistent-provider/gpt-4o",
151			expectError:    true,
152			errorContains:  "provider",
153			setupProviders: setupMockProviders,
154		},
155		{
156			name:          "model found in multiple providers without provider filter",
157			modelStr:      "shared-model",
158			expectError:   true,
159			errorContains: "multiple providers",
160			setupProviders: func() map[string]config.ProviderConfig {
161				return map[string]config.ProviderConfig{
162					"openai": {
163						ID:     "openai",
164						Models: []catwalk.Model{{ID: "shared-model"}},
165					},
166					"anthropic": {
167						ID:     "anthropic",
168						Models: []catwalk.Model{{ID: "shared-model"}},
169					},
170				}
171			},
172		},
173		{
174			name:           "empty model string",
175			modelStr:       "",
176			expectError:    true,
177			errorContains:  "not found",
178			setupProviders: setupMockProviders,
179		},
180	}
181
182	for _, tt := range tests {
183		t.Run(tt.name, func(t *testing.T) {
184			providers := tt.setupProviders()
185
186			// Use findModels with the model as "large" and empty "small".
187			matches, _, err := findModels(providers, tt.modelStr, "")
188			if err != nil {
189				if tt.expectError {
190					require.Contains(t, err.Error(), tt.errorContains)
191				} else {
192					require.NoError(t, err)
193				}
194				return
195			}
196
197			// Validate the matches.
198			match, err := validateMatches(matches, tt.modelStr, "large")
199
200			if tt.expectError {
201				require.Error(t, err)
202				require.Contains(t, err.Error(), tt.errorContains)
203			} else {
204				require.NoError(t, err)
205				require.Equal(t, tt.expectedProvider, match.provider)
206				require.Equal(t, tt.expectedModelID, match.modelID)
207			}
208		})
209	}
210}