1package models
  2
  3import (
  4	"github.com/charmbracelet/bubbles/v2/help"
  5	"github.com/charmbracelet/bubbles/v2/key"
  6	tea "github.com/charmbracelet/bubbletea/v2"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/tui/components/completions"
 10	"github.com/charmbracelet/crush/internal/tui/components/core"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 13	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 14	"github.com/charmbracelet/crush/internal/tui/styles"
 15	"github.com/charmbracelet/crush/internal/tui/util"
 16	"github.com/charmbracelet/lipgloss/v2"
 17)
 18
 19const (
 20	ModelsDialogID dialogs.DialogID = "models"
 21
 22	defaultWidth = 60
 23)
 24
 25const (
 26	LargeModelType int = iota
 27	SmallModelType
 28)
 29
// ModelSelectedMsg is emitted by the model dialog when the user confirms a
// model. Model identifies the chosen provider/model pair, and ModelType says
// which configured model (config.LargeModel or config.SmallModel) the choice
// applies to.
type ModelSelectedMsg struct {
	Model     config.PreferredModel
	ModelType config.ModelType
}
 35
 36// CloseModelDialogMsg is sent when a model is selected
 37type CloseModelDialogMsg struct{}
 38
// ModelDialog is the interface of the model selection dialog returned by
// NewModelDialogCmp. It adds no methods beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 43
// ModelOption is the value attached to each model entry in the dialog's list:
// a model together with the provider that offers it. Update reads it back to
// build the ModelSelectedMsg.
type ModelOption struct {
	Provider provider.Provider
	Model    provider.Model
}
 48
// modelDialogCmp is the concrete ModelDialog.
//
// Fields:
//   - width: the dialog's own width (set to defaultWidth by the constructor).
//   - wWidth, wHeight: the latest window size from tea.WindowSizeMsg; used
//     to position the dialog and to cap the list height.
//   - modelList: filterable list of provider section headers and model entries.
//   - keyMap, help: the dialog's key bindings and the help line rendered from them.
//   - modelType: LargeModelType or SmallModelType; selects which configured
//     model the dialog edits.
type modelDialogCmp struct {
	width   int
	wWidth  int
	wHeight int

	modelList list.ListModel
	keyMap    KeyMap
	help      help.Model
	modelType int
}
 59
 60func NewModelDialogCmp() ModelDialog {
 61	listKeyMap := list.DefaultKeyMap()
 62	keyMap := DefaultKeyMap()
 63
 64	listKeyMap.Down.SetEnabled(false)
 65	listKeyMap.Up.SetEnabled(false)
 66	listKeyMap.HalfPageDown.SetEnabled(false)
 67	listKeyMap.HalfPageUp.SetEnabled(false)
 68	listKeyMap.Home.SetEnabled(false)
 69	listKeyMap.End.SetEnabled(false)
 70
 71	listKeyMap.DownOneItem = keyMap.Next
 72	listKeyMap.UpOneItem = keyMap.Previous
 73
 74	t := styles.CurrentTheme()
 75	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 76	modelList := list.New(
 77		list.WithFilterable(true),
 78		list.WithKeyMap(listKeyMap),
 79		list.WithInputStyle(inputStyle),
 80		list.WithWrapNavigation(true),
 81	)
 82	help := help.New()
 83	help.Styles = t.S().Help
 84
 85	return &modelDialogCmp{
 86		modelList: modelList,
 87		width:     defaultWidth,
 88		keyMap:    DefaultKeyMap(),
 89		help:      help,
 90		modelType: LargeModelType,
 91	}
 92}
 93
 94func (m *modelDialogCmp) Init() tea.Cmd {
 95	m.SetModelType(m.modelType)
 96	return m.modelList.Init()
 97}
 98
 99func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
100	switch msg := msg.(type) {
101	case tea.WindowSizeMsg:
102		m.wWidth = msg.Width
103		m.wHeight = msg.Height
104		m.SetModelType(m.modelType)
105		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
106	case tea.KeyPressMsg:
107		switch {
108		case key.Matches(msg, m.keyMap.Select):
109			selectedItemInx := m.modelList.SelectedIndex()
110			if selectedItemInx == list.NoSelection {
111				return m, nil
112			}
113			items := m.modelList.Items()
114			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
115
116			var modelType config.ModelType
117			if m.modelType == LargeModelType {
118				modelType = config.LargeModel
119			} else {
120				modelType = config.SmallModel
121			}
122
123			return m, tea.Sequence(
124				util.CmdHandler(dialogs.CloseDialogMsg{}),
125				util.CmdHandler(ModelSelectedMsg{
126					Model: config.PreferredModel{
127						ModelID:  selectedItem.Model.ID,
128						Provider: selectedItem.Provider.ID,
129					},
130					ModelType: modelType,
131				}),
132			)
133		case key.Matches(msg, m.keyMap.Tab):
134			if m.modelType == LargeModelType {
135				return m, m.SetModelType(SmallModelType)
136			} else {
137				return m, m.SetModelType(LargeModelType)
138			}
139		case key.Matches(msg, m.keyMap.Close):
140			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
141		default:
142			u, cmd := m.modelList.Update(msg)
143			m.modelList = u.(list.ListModel)
144			return m, cmd
145		}
146	}
147	return m, nil
148}
149
150func (m *modelDialogCmp) View() tea.View {
151	t := styles.CurrentTheme()
152	listView := m.modelList.View()
153	radio := m.modelTypeRadio()
154	content := lipgloss.JoinVertical(
155		lipgloss.Left,
156		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
157		listView.String(),
158		"",
159		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
160	)
161	v := tea.NewView(m.style().Render(content))
162	if listView.Cursor() != nil {
163		c := m.moveCursor(listView.Cursor())
164		v.SetCursor(c)
165	}
166	return v
167}
168
169func (m *modelDialogCmp) style() lipgloss.Style {
170	t := styles.CurrentTheme()
171	return t.S().Base.
172		Width(m.width).
173		Border(lipgloss.RoundedBorder()).
174		BorderForeground(t.BorderFocus)
175}
176
177func (m *modelDialogCmp) listWidth() int {
178	return defaultWidth - 2 // 4 for padding
179}
180
181func (m *modelDialogCmp) listHeight() int {
182	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
183	return min(listHeigh, m.wHeight/2)
184}
185
186func (m *modelDialogCmp) Position() (int, int) {
187	row := m.wHeight/4 - 2 // just a bit above the center
188	col := m.wWidth / 2
189	col -= m.width / 2
190	return row, col
191}
192
193func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
194	row, col := m.Position()
195	offset := row + 3 // Border + title
196	cursor.Y += offset
197	cursor.X = cursor.X + col + 2
198	return cursor
199}
200
201func (m *modelDialogCmp) ID() dialogs.DialogID {
202	return ModelsDialogID
203}
204
205func (m *modelDialogCmp) modelTypeRadio() string {
206	t := styles.CurrentTheme()
207	choices := []string{"Large", "Small"}
208	iconSelected := "◉"
209	iconUnselected := "○"
210	if m.modelType == LargeModelType {
211		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
212	}
213	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
214}
215
216func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
217	m.modelType = modelType
218
219	providers := config.Providers()
220	modelItems := []util.Model{}
221	selectIndex := 0
222
223	cfg := config.Get()
224	var currentModel config.PreferredModel
225	if m.modelType == LargeModelType {
226		currentModel = cfg.Models.Large
227	} else {
228		currentModel = cfg.Models.Small
229	}
230
231	for _, provider := range providers {
232		name := provider.Name
233		if name == "" {
234			name = string(provider.ID)
235		}
236		modelItems = append(modelItems, commands.NewItemSection(name))
237		for _, model := range provider.Models {
238			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
239				Provider: provider,
240				Model:    model,
241			}))
242			if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
243				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
244			}
245		}
246	}
247
248	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
249}