1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/tui/components/completions"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15 "github.com/charmbracelet/lipgloss/v2"
16)
17
18type ModelListComponent struct {
19 list list.ListModel
20 modelType int
21 providers []catwalk.Provider
22}
23
24func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
25 modelList := list.New(
26 list.WithFilterable(true),
27 list.WithKeyMap(keyMap),
28 list.WithInputStyle(inputStyle),
29 list.WithFilterPlaceholder(inputPlaceholder),
30 list.WithWrapNavigation(true),
31 )
32
33 return &ModelListComponent{
34 list: modelList,
35 modelType: LargeModelType,
36 }
37}
38
39func (m *ModelListComponent) Init() tea.Cmd {
40 var cmds []tea.Cmd
41 if len(m.providers) == 0 {
42 providers, err := config.Providers()
43 m.providers = providers
44 if err != nil {
45 cmds = append(cmds, util.ReportError(err))
46 }
47 }
48 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
49 return tea.Batch(cmds...)
50}
51
52func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
53 u, cmd := m.list.Update(msg)
54 m.list = u.(list.ListModel)
55 return m, cmd
56}
57
58func (m *ModelListComponent) View() string {
59 return m.list.View()
60}
61
62func (m *ModelListComponent) Cursor() *tea.Cursor {
63 return m.list.Cursor()
64}
65
66func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
67 return m.list.SetSize(width, height)
68}
69
70func (m *ModelListComponent) Items() []util.Model {
71 return m.list.Items()
72}
73
74func (m *ModelListComponent) SelectedIndex() int {
75 return m.list.SelectedIndex()
76}
77
78func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
79 t := styles.CurrentTheme()
80 m.modelType = modelType
81
82 modelItems := []util.Model{}
83 // first none section
84 selectIndex := 1
85
86 cfg := config.Get()
87 var currentModel config.SelectedModel
88 if m.modelType == LargeModelType {
89 currentModel = cfg.Models[config.SelectedModelTypeLarge]
90 } else {
91 currentModel = cfg.Models[config.SelectedModelTypeSmall]
92 }
93
94 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
95 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
96
97 // Create a map to track which providers we've already added
98 addedProviders := make(map[string]bool)
99
100 // First, add any configured providers that are not in the known providers list
101 // These should appear at the top of the list
102 knownProviders, err := config.Providers()
103 if err != nil {
104 return util.ReportError(err)
105 }
106 for providerID, providerConfig := range cfg.Providers {
107 if providerConfig.Disable {
108 continue
109 }
110
111 // Check if this provider is not in the known providers list
112 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
113 // Convert config provider to provider.Provider format
114 configProvider := catwalk.Provider{
115 Name: providerConfig.Name,
116 ID: catwalk.InferenceProvider(providerID),
117 Models: make([]catwalk.Model, len(providerConfig.Models)),
118 }
119
120 // Convert models
121 for i, model := range providerConfig.Models {
122 configProvider.Models[i] = catwalk.Model{
123 ID: model.ID,
124 Name: model.Name,
125 CostPer1MIn: model.CostPer1MIn,
126 CostPer1MOut: model.CostPer1MOut,
127 CostPer1MInCached: model.CostPer1MInCached,
128 CostPer1MOutCached: model.CostPer1MOutCached,
129 ContextWindow: model.ContextWindow,
130 DefaultMaxTokens: model.DefaultMaxTokens,
131 CanReason: model.CanReason,
132 HasReasoningEffort: model.HasReasoningEffort,
133 DefaultReasoningEffort: model.DefaultReasoningEffort,
134 SupportsImages: model.SupportsImages,
135 }
136 }
137
138 // Add this unknown provider to the list
139 name := configProvider.Name
140 if name == "" {
141 name = string(configProvider.ID)
142 }
143 section := commands.NewItemSection(name)
144 section.SetInfo(configured)
145 modelItems = append(modelItems, section)
146 for _, model := range configProvider.Models {
147 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
148 Provider: configProvider,
149 Model: model,
150 }))
151 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
152 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
153 }
154 }
155 addedProviders[providerID] = true
156 }
157 }
158
159 // Then add the known providers from the predefined list
160 for _, provider := range m.providers {
161 // Skip if we already added this provider as an unknown provider
162 if addedProviders[string(provider.ID)] {
163 continue
164 }
165
166 // Check if this provider is configured and not disabled
167 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
168 continue
169 }
170
171 name := provider.Name
172 if name == "" {
173 name = string(provider.ID)
174 }
175
176 section := commands.NewItemSection(name)
177 if _, ok := cfg.Providers[string(provider.ID)]; ok {
178 section.SetInfo(configured)
179 }
180 modelItems = append(modelItems, section)
181 for _, model := range provider.Models {
182 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
183 Provider: provider,
184 Model: model,
185 }))
186 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
187 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
188 }
189 }
190 }
191
192 return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
193}
194
195// GetModelType returns the current model type
196func (m *ModelListComponent) GetModelType() int {
197 return m.modelType
198}
199
200func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
201 m.list.SetFilterPlaceholder(placeholder)
202}
203
204func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
205 m.providers = providers
206}