provider.go

 1package app
 2
import (
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	xstrings "github.com/charmbracelet/x/exp/strings"
)
10
11// parseModelStr parses a model string into provider filter and model ID.
12// Format: "model-name" or "provider/model-name" or "synthetic/moonshot/kimi-k2".
13// This function only checks if the first component is a valid provider name; if not,
14// it treats the entire string as a model ID (which may contain slashes).
15func parseModelStr(providers map[string]config.ProviderConfig, modelStr string) (providerFilter, modelID string) {
16	parts := strings.Split(modelStr, "/")
17	if len(parts) == 1 {
18		return "", parts[0]
19	}
20	// Check if the first part is a valid provider name
21	if _, ok := providers[parts[0]]; ok {
22		return parts[0], strings.Join(parts[1:], "/")
23	}
24
25	// First part is not a valid provider, treat entire string as model ID
26	return "", modelStr
27}
28
// modelMatch identifies one configured model that satisfied a lookup: the
// provider it was found under together with that provider's model ID.
type modelMatch struct {
	// provider is the key of the provider in the configuration map.
	provider string
	// modelID is the model's ID as listed by that provider.
	modelID string
}
34
35func findModels(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch, error) {
36	largeProviderFilter, largeModelID := parseModelStr(providers, largeModel)
37	smallProviderFilter, smallModelID := parseModelStr(providers, smallModel)
38
39	// Validate provider filters exist.
40	for _, pf := range []struct {
41		filter, label string
42	}{
43		{largeProviderFilter, "large"},
44		{smallProviderFilter, "small"},
45	} {
46		if pf.filter != "" {
47			if _, ok := providers[pf.filter]; !ok {
48				return nil, nil, fmt.Errorf("%s model: provider %q not found in configuration. Use 'crush models' to list available models", pf.label, pf.filter)
49			}
50		}
51	}
52
53	// Find matching models in a single pass.
54	var largeMatches, smallMatches []modelMatch
55	for name, provider := range providers {
56		if provider.Disable {
57			continue
58		}
59		for _, m := range provider.Models {
60			if filter(largeModelID, largeProviderFilter, m.ID, name) {
61				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
62			}
63			if filter(smallModelID, smallProviderFilter, m.ID, name) {
64				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
65			}
66		}
67	}
68
69	return largeMatches, smallMatches, nil
70}
71
// filter reports whether model, offered by provider, satisfies a lookup.
//
// modelFilter must be non-empty and equal to model; an empty modelFilter
// never matches. providerFilter is optional: when empty, any provider is
// accepted, otherwise it must equal provider.
func filter(modelFilter, providerFilter, model, provider string) bool {
	if modelFilter == "" || model != modelFilter {
		return false
	}
	if providerFilter != "" && provider != providerFilter {
		return false
	}
	return true
}
76
77// Validate and return a single match.
78func validateMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
79	switch {
80	case len(matches) == 0:
81		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
82	case len(matches) > 1:
83		names := make([]string, len(matches))
84		for i, m := range matches {
85			names[i] = m.provider
86		}
87		return modelMatch{}, fmt.Errorf(
88			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
89			label,
90			modelID,
91			xstrings.EnglishJoin(names, true),
92		)
93	}
94	return matches[0], nil
95}