1package app
2
import (
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	xstrings "github.com/charmbracelet/x/exp/strings"
)
10
11// parseModelStr parses a model string into provider filter and model ID.
12// Format: "model-name" or "provider/model-name" or "synthetic/moonshot/kimi-k2".
13// This function only checks if the first component is a valid provider name; if not,
14// it treats the entire string as a model ID (which may contain slashes).
15func parseModelStr(providers map[string]config.ProviderConfig, modelStr string) (providerFilter, modelID string) {
16 parts := strings.Split(modelStr, "/")
17 if len(parts) == 1 {
18 return "", parts[0]
19 }
20 // Check if the first part is a valid provider name
21 if _, ok := providers[parts[0]]; ok {
22 return parts[0], strings.Join(parts[1:], "/")
23 }
24
25 // First part is not a valid provider, treat entire string as model ID
26 return "", modelStr
27}
28
// modelMatch identifies one configured model that satisfied a lookup: the
// provider's name (its key in the providers map) and the model's ID within
// that provider.
type modelMatch struct {
	provider, modelID string
}
34
35func findModels(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch, error) {
36 largeProviderFilter, largeModelID := parseModelStr(providers, largeModel)
37 smallProviderFilter, smallModelID := parseModelStr(providers, smallModel)
38
39 // Validate provider filters exist.
40 for _, pf := range []struct {
41 filter, label string
42 }{
43 {largeProviderFilter, "large"},
44 {smallProviderFilter, "small"},
45 } {
46 if pf.filter != "" {
47 if _, ok := providers[pf.filter]; !ok {
48 return nil, nil, fmt.Errorf("%s model: provider %q not found in configuration. Use 'crush models' to list available models", pf.label, pf.filter)
49 }
50 }
51 }
52
53 // Find matching models in a single pass.
54 var largeMatches, smallMatches []modelMatch
55 for name, provider := range providers {
56 if provider.Disable {
57 continue
58 }
59 for _, m := range provider.Models {
60 if filter(largeModelID, largeProviderFilter, m.ID, name) {
61 largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
62 }
63 if filter(smallModelID, smallProviderFilter, m.ID, name) {
64 smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
65 }
66 }
67 }
68
69 return largeMatches, smallMatches, nil
70}
71
// filter reports whether the model with ID model, offered by provider,
// satisfies a lookup for modelFilter. An empty modelFilter never matches, and
// an empty providerFilter accepts any provider.
func filter(modelFilter, providerFilter, model, provider string) bool {
	if modelFilter == "" || model != modelFilter {
		return false
	}
	return providerFilter == "" || providerFilter == provider
}
76
77// Validate and return a single match.
78func validateMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
79 switch {
80 case len(matches) == 0:
81 return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
82 case len(matches) > 1:
83 names := make([]string, len(matches))
84 for i, m := range matches {
85 names[i] = m.provider
86 }
87 return modelMatch{}, fmt.Errorf(
88 "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
89 label,
90 modelID,
91 xstrings.EnglishJoin(names, true),
92 )
93 }
94 return matches[0], nil
95}