1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/fur/provider"
10 "github.com/charmbracelet/crush/internal/tui/components/completions"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15 "github.com/charmbracelet/lipgloss/v2"
16)
17
// ModelListComponent wraps a filterable list that presents the models of each
// provider, grouped under one section header per provider.
type ModelListComponent struct {
	// list is the underlying list component that all rendering, sizing and
	// message handling is delegated to.
	list list.ListModel
	// modelType records which model slot is being edited; SetModelType
	// compares it against LargeModelType to choose the large or small
	// selected model from the configuration.
	modelType int
	// providers holds the providers listed after any custom ones. It is
	// filled from config.Providers() in Init when empty, or supplied through
	// SetProviders.
	providers []provider.Provider
}
23
24func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
25 modelList := list.New(
26 list.WithFilterable(true),
27 list.WithKeyMap(keyMap),
28 list.WithInputStyle(inputStyle),
29 list.WithFilterPlaceholder(inputPlaceholder),
30 list.WithWrapNavigation(true),
31 )
32
33 return &ModelListComponent{
34 list: modelList,
35 modelType: LargeModelType,
36 }
37}
38
39func (m *ModelListComponent) Init() tea.Cmd {
40 var cmds []tea.Cmd
41 if len(m.providers) == 0 {
42 providers, err := config.Providers()
43 m.providers = providers
44 if err != nil {
45 cmds = append(cmds, util.ReportError(err))
46 }
47 }
48 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
49 return tea.Batch(cmds...)
50}
51
// Update forwards msg to the underlying list, stores the list value it returns
// and passes the list's command back to the caller together with the receiver.
func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
	u, cmd := m.list.Update(msg)
	// Unchecked assertion: this panics if the value returned by the list's
	// Update does not satisfy list.ListModel.
	m.list = u.(list.ListModel)
	return m, cmd
}
57
// View returns the view produced by the underlying list.
func (m *ModelListComponent) View() tea.View {
	return m.list.View()
}
61
// SetSize forwards the given width and height to the underlying list and
// returns the command the list produces.
func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
	return m.list.SetSize(width, height)
}
65
// Items returns the items currently held by the underlying list. After
// SetModelType has run, these include the provider section headers as well as
// the model entries.
func (m *ModelListComponent) Items() []util.Model {
	return m.list.Items()
}
69
// SelectedIndex returns the selected index reported by the underlying list.
// NOTE(review): presumably an index into the slice returned by Items — confirm
// against the list implementation.
func (m *ModelListComponent) SelectedIndex() int {
	return m.list.SelectedIndex()
}
73
74func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
75 t := styles.CurrentTheme()
76 m.modelType = modelType
77
78 modelItems := []util.Model{}
79 selectIndex := 0
80
81 cfg := config.Get()
82 var currentModel config.SelectedModel
83 if m.modelType == LargeModelType {
84 currentModel = cfg.Models[config.SelectedModelTypeLarge]
85 } else {
86 currentModel = cfg.Models[config.SelectedModelTypeSmall]
87 }
88
89 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
90 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
91
92 // Create a map to track which providers we've already added
93 addedProviders := make(map[string]bool)
94
95 // First, add any configured providers that are not in the known providers list
96 // These should appear at the top of the list
97 knownProviders := provider.KnownProviders()
98 for providerID, providerConfig := range cfg.Providers {
99 if providerConfig.Disable {
100 continue
101 }
102
103 // Check if this provider is not in the known providers list
104 if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
105 // Convert config provider to provider.Provider format
106 configProvider := provider.Provider{
107 Name: string(providerID), // Use provider ID as name for unknown providers
108 ID: provider.InferenceProvider(providerID),
109 Models: make([]provider.Model, len(providerConfig.Models)),
110 }
111
112 // Convert models
113 for i, model := range providerConfig.Models {
114 configProvider.Models[i] = provider.Model{
115 ID: model.ID,
116 Name: model.Name,
117 CostPer1MIn: model.CostPer1MIn,
118 CostPer1MOut: model.CostPer1MOut,
119 CostPer1MInCached: model.CostPer1MInCached,
120 CostPer1MOutCached: model.CostPer1MOutCached,
121 ContextWindow: model.ContextWindow,
122 DefaultMaxTokens: model.DefaultMaxTokens,
123 CanReason: model.CanReason,
124 HasReasoningEffort: model.HasReasoningEffort,
125 DefaultReasoningEffort: model.DefaultReasoningEffort,
126 SupportsImages: model.SupportsImages,
127 }
128 }
129
130 // Add this unknown provider to the list
131 name := configProvider.Name
132 if name == "" {
133 name = string(configProvider.ID)
134 }
135 section := commands.NewItemSection(name)
136 section.SetInfo(configured)
137 modelItems = append(modelItems, section)
138 for _, model := range configProvider.Models {
139 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
140 Provider: configProvider,
141 Model: model,
142 }))
143 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
144 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
145 }
146 }
147 addedProviders[providerID] = true
148 }
149 }
150
151 // Then add the known providers from the predefined list
152 for _, provider := range m.providers {
153 // Skip if we already added this provider as an unknown provider
154 if addedProviders[string(provider.ID)] {
155 continue
156 }
157
158 // Check if this provider is configured and not disabled
159 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
160 continue
161 }
162
163 name := provider.Name
164 if name == "" {
165 name = string(provider.ID)
166 }
167
168 section := commands.NewItemSection(name)
169 if _, ok := cfg.Providers[string(provider.ID)]; ok {
170 section.SetInfo(configured)
171 }
172 modelItems = append(modelItems, section)
173 for _, model := range provider.Models {
174 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
175 Provider: provider,
176 Model: model,
177 }))
178 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
179 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
180 }
181 }
182 }
183
184 return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
185}
186
// GetModelType returns the model type most recently stored by SetModelType, or
// LargeModelType if the component came from NewModelListComponent and
// SetModelType has not been called with another value.
func (m *ModelListComponent) GetModelType() int {
	return m.modelType
}
191
// SetInputPlaceholder forwards placeholder to the underlying list as its
// filter placeholder.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetFilterPlaceholder(placeholder)
}
195
// SetProviders replaces the providers held by the component. The slice is
// stored as-is, without copying. Supplying a non-empty slice before Init makes
// Init skip loading providers from config.Providers(). The list items are not
// rebuilt here; they change the next time SetModelType runs.
func (m *ModelListComponent) SetProviders(providers []provider.Provider) {
	m.providers = providers
}