1package splash
2
3import (
4 "fmt"
5 "log/slog"
6 "slices"
7
8 "github.com/charmbracelet/bubbles/v2/key"
9 tea "github.com/charmbracelet/bubbletea/v2"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/fur/provider"
12 "github.com/charmbracelet/crush/internal/tui/components/chat"
13 "github.com/charmbracelet/crush/internal/tui/components/completions"
14 "github.com/charmbracelet/crush/internal/tui/components/core"
15 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
16 "github.com/charmbracelet/crush/internal/tui/components/core/list"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
18 "github.com/charmbracelet/crush/internal/tui/components/logo"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/charmbracelet/lipgloss/v2"
23)
24
// Splash is the start-up screen component. Besides the logo it can show a
// model-selection onboarding step and a project-initialization prompt.
type Splash interface {
	util.Model
	layout.Sizeable
	layout.Help
	// Cursor returns the cursor for the currently focused input, if any; it
	// may be nil when no input is active.
	Cursor() *tea.Cursor
	// SetOnboarding controls whether the splash shows model selection UI
	SetOnboarding(bool)
	// SetProjectInit controls whether the splash shows project initialization prompt
	SetProjectInit(bool)
}
35
// Padding applied around the whole splash screen in View.
const (
	SplashScreenPaddingX = 2 // Padding X for the splash screen: columns on the left and on the right
	SplashScreenPaddingY = 1 // Padding Y for the splash screen: rows at the top and at the bottom
)
40
// OnboardingCompleteMsg is sent when onboarding is complete: after a model
// has been chosen (and, if needed, its API key saved), or after the
// project-initialization prompt has been answered.
type OnboardingCompleteMsg struct{}
43
// splashCmp is the Splash implementation. Besides the logo it can show one of
// three steps; View checks them in this order: API key entry (needsAPIKey),
// model selection (isOnboarding), project-initialization prompt
// (needsProjectInit).
type splashCmp struct {
	width, height int // last size passed to SetSize
	keyMap        KeyMap
	logoRendered  string // logo rendered by SetSize for the current width

	// State
	isOnboarding     bool // model selection step is shown (see SetOnboarding)
	needsProjectInit bool // project-initialization prompt is shown (see SetProjectInit)
	needsAPIKey      bool // API key prompt for selectedModel's provider is shown
	selectedNo       bool // "Nope" is the highlighted button in the project-init prompt

	modelList     *models.ModelListComponent
	apiKeyInput   *models.APIKeyInput
	selectedModel *models.ModelOption // model waiting for its provider's API key; nil otherwise
}
59
60func New() Splash {
61 keyMap := DefaultKeyMap()
62 listKeyMap := list.DefaultKeyMap()
63 listKeyMap.Down.SetEnabled(false)
64 listKeyMap.Up.SetEnabled(false)
65 listKeyMap.HalfPageDown.SetEnabled(false)
66 listKeyMap.HalfPageUp.SetEnabled(false)
67 listKeyMap.Home.SetEnabled(false)
68 listKeyMap.End.SetEnabled(false)
69 listKeyMap.DownOneItem = keyMap.Next
70 listKeyMap.UpOneItem = keyMap.Previous
71
72 t := styles.CurrentTheme()
73 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
74 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
75 apiKeyInput := models.NewAPIKeyInput()
76
77 return &splashCmp{
78 width: 0,
79 height: 0,
80 keyMap: keyMap,
81 logoRendered: "",
82 modelList: modelList,
83 apiKeyInput: apiKeyInput,
84 selectedNo: false,
85 }
86}
87
88func (s *splashCmp) SetOnboarding(onboarding bool) {
89 s.isOnboarding = onboarding
90 if onboarding {
91 providers, err := config.Providers()
92 if err != nil {
93 return
94 }
95 filteredProviders := []provider.Provider{}
96 simpleProviders := []string{
97 "anthropic",
98 "openai",
99 "gemini",
100 "xai",
101 "openrouter",
102 }
103 for _, p := range providers {
104 if slices.Contains(simpleProviders, string(p.ID)) {
105 filteredProviders = append(filteredProviders, p)
106 }
107 }
108 s.modelList.SetProviders(filteredProviders)
109 }
110}
111
112func (s *splashCmp) SetProjectInit(needsInit bool) {
113 s.needsProjectInit = needsInit
114}
115
116// GetSize implements SplashPage.
117func (s *splashCmp) GetSize() (int, int) {
118 return s.width, s.height
119}
120
121// Init implements SplashPage.
122func (s *splashCmp) Init() tea.Cmd {
123 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
124}
125
126// SetSize implements SplashPage.
127func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
128 s.width = width
129 s.height = height
130 s.logoRendered = s.logoBlock()
131 listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
132 listWidth := min(60, width-(SplashScreenPaddingX*2))
133
134 return s.modelList.SetSize(listWidth, listHeigh)
135}
136
137// Update implements SplashPage.
138func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
139 switch msg := msg.(type) {
140 case tea.WindowSizeMsg:
141 return s, s.SetSize(msg.Width, msg.Height)
142 case tea.KeyPressMsg:
143 switch {
144 case key.Matches(msg, s.keyMap.Back):
145 slog.Info("Back key pressed in splash screen")
146 if s.needsAPIKey {
147 // Go back to model selection
148 s.needsAPIKey = false
149 s.selectedModel = nil
150 return s, nil
151 }
152 case key.Matches(msg, s.keyMap.Select):
153 if s.isOnboarding && !s.needsAPIKey {
154 modelInx := s.modelList.SelectedIndex()
155 items := s.modelList.Items()
156 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
157 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
158 cmd := s.setPreferredModel(selectedItem)
159 s.isOnboarding = false
160 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
161 } else {
162 // Provider not configured, show API key input
163 s.needsAPIKey = true
164 s.selectedModel = &selectedItem
165 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
166 return s, nil
167 }
168 } else if s.needsAPIKey {
169 // Handle API key submission
170 apiKey := s.apiKeyInput.Value()
171 if apiKey != "" {
172 return s, s.saveAPIKeyAndContinue(apiKey)
173 }
174 } else if s.needsProjectInit {
175 return s, s.initializeProject()
176 }
177 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
178 if s.needsProjectInit {
179 s.selectedNo = !s.selectedNo
180 return s, nil
181 }
182 case key.Matches(msg, s.keyMap.Yes):
183 if s.needsProjectInit {
184 return s, s.initializeProject()
185 }
186 case key.Matches(msg, s.keyMap.No):
187 if s.needsProjectInit {
188 s.needsProjectInit = false
189 return s, util.CmdHandler(OnboardingCompleteMsg{})
190 }
191 default:
192 if s.needsAPIKey {
193 u, cmd := s.apiKeyInput.Update(msg)
194 s.apiKeyInput = u.(*models.APIKeyInput)
195 return s, cmd
196 } else if s.isOnboarding {
197 u, cmd := s.modelList.Update(msg)
198 s.modelList = u
199 return s, cmd
200 }
201 }
202 case tea.PasteMsg:
203 if s.needsAPIKey {
204 u, cmd := s.apiKeyInput.Update(msg)
205 s.apiKeyInput = u.(*models.APIKeyInput)
206 return s, cmd
207 } else if s.isOnboarding {
208 var cmd tea.Cmd
209 s.modelList, cmd = s.modelList.Update(msg)
210 return s, cmd
211 }
212 }
213 return s, nil
214}
215
216func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
217 if s.selectedModel == nil {
218 return util.ReportError(fmt.Errorf("no model selected"))
219 }
220
221 cfg := config.Get()
222 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
223 if err != nil {
224 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
225 }
226
227 // Reset API key state and continue with model selection
228 s.needsAPIKey = false
229 cmd := s.setPreferredModel(*s.selectedModel)
230 s.isOnboarding = false
231 s.selectedModel = nil
232
233 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
234}
235
236func (s *splashCmp) initializeProject() tea.Cmd {
237 s.needsProjectInit = false
238 prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2391. Build/lint/test commands - especially for running a single test
2402. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
241
242The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
243If there's already a CRUSH.md, improve it.
244If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
245Add the .crush directory to the .gitignore file if it's not already there.`
246
247 if err := config.MarkProjectInitialized(); err != nil {
248 return util.ReportError(err)
249 }
250 var cmds []tea.Cmd
251
252 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253 if !s.selectedNo {
254 cmds = append(cmds,
255 util.CmdHandler(chat.SessionClearedMsg{}),
256 util.CmdHandler(chat.SendMsg{
257 Text: prompt,
258 }),
259 )
260 }
261 return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265 cfg := config.Get()
266 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267 if model == nil {
268 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269 }
270
271 selectedModel := config.SelectedModel{
272 Model: selectedItem.Model.ID,
273 Provider: string(selectedItem.Provider.ID),
274 ReasoningEffort: model.DefaultReasoningEffort,
275 MaxTokens: model.DefaultMaxTokens,
276 }
277
278 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279 if err != nil {
280 return util.ReportError(err)
281 }
282
283 // Now lets automatically setup the small model
284 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285 if err != nil {
286 return util.ReportError(err)
287 }
288 if knownProvider == nil {
289 // for local provider we just use the same model
290 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291 if err != nil {
292 return util.ReportError(err)
293 }
294 } else {
295 smallModel := knownProvider.DefaultSmallModelID
296 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297 // should never happen
298 if model == nil {
299 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300 if err != nil {
301 return util.ReportError(err)
302 }
303 return nil
304 }
305 smallSelectedModel := config.SelectedModel{
306 Model: smallModel,
307 Provider: string(selectedItem.Provider.ID),
308 ReasoningEffort: model.DefaultReasoningEffort,
309 MaxTokens: model.DefaultMaxTokens,
310 }
311 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312 if err != nil {
313 return util.ReportError(err)
314 }
315 }
316 return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320 providers, err := config.Providers()
321 if err != nil {
322 return nil, err
323 }
324 for _, p := range providers {
325 if p.ID == providerID {
326 return &p, nil
327 }
328 }
329 return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333 cfg := config.Get()
334 if _, ok := cfg.Providers[providerID]; ok {
335 return true
336 }
337 return false
338}
339
340func (s *splashCmp) View() string {
341 t := styles.CurrentTheme()
342
343 var content string
344 if s.needsAPIKey {
345 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
346 apiKeyView := s.apiKeyInput.View()
347 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
348 lipgloss.JoinVertical(
349 lipgloss.Left,
350 apiKeyView,
351 ),
352 )
353 content = lipgloss.JoinVertical(
354 lipgloss.Left,
355 s.logoRendered,
356 apiKeySelector,
357 )
358 } else if s.isOnboarding {
359 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360 modelListView := s.modelList.View()
361 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
362 lipgloss.JoinVertical(
363 lipgloss.Left,
364 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
365 "",
366 modelListView,
367 ),
368 )
369 content = lipgloss.JoinVertical(
370 lipgloss.Left,
371 s.logoRendered,
372 modelSelector,
373 )
374 } else if s.needsProjectInit {
375 t := styles.CurrentTheme()
376
377 titleStyle := t.S().Base.Foreground(t.FgBase)
378 bodyStyle := t.S().Base.Foreground(t.FgMuted)
379 shortcutStyle := t.S().Base.Foreground(t.Success)
380
381 initText := lipgloss.JoinVertical(
382 lipgloss.Left,
383 titleStyle.Render("Would you like to initialize this project?"),
384 "",
385 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
386 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
387 "",
388 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
389 "",
390 bodyStyle.Render("Would you like to initialize now?"),
391 )
392
393 yesButton := core.SelectableButton(core.ButtonOpts{
394 Text: "Yep!",
395 UnderlineIndex: 0,
396 Selected: !s.selectedNo,
397 })
398
399 noButton := core.SelectableButton(core.ButtonOpts{
400 Text: "Nope",
401 UnderlineIndex: 0,
402 Selected: s.selectedNo,
403 })
404
405 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
406
407 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
408
409 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
410 lipgloss.JoinVertical(
411 lipgloss.Left,
412 initText,
413 "",
414 buttons,
415 ),
416 )
417
418 content = lipgloss.JoinVertical(
419 lipgloss.Left,
420 s.logoRendered,
421 initContent,
422 )
423 } else {
424 content = s.logoRendered
425 }
426
427 return t.S().Base.
428 Width(s.width).
429 Height(s.height).
430 PaddingTop(SplashScreenPaddingY).
431 PaddingLeft(SplashScreenPaddingX).
432 PaddingRight(SplashScreenPaddingX).
433 PaddingBottom(SplashScreenPaddingY).
434 Render(content)
435}
436
437func (s *splashCmp) Cursor() *tea.Cursor {
438 if s.needsAPIKey {
439 cursor := s.apiKeyInput.Cursor()
440 if cursor != nil {
441 return s.moveCursor(cursor)
442 }
443 } else if s.isOnboarding {
444 cursor := s.modelList.Cursor()
445 if cursor != nil {
446 return s.moveCursor(cursor)
447 }
448 } else {
449 return nil
450 }
451 return nil
452}
453
454func (s *splashCmp) logoBlock() string {
455 t := styles.CurrentTheme()
456 const padding = 2
457 return logo.Render(version.Version, false, logo.Opts{
458 FieldColor: t.Primary,
459 TitleColorA: t.Secondary,
460 TitleColorB: t.Primary,
461 CharmColor: t.Secondary,
462 VersionColor: t.Primary,
463 Width: s.width - (SplashScreenPaddingX * 2),
464 })
465}
466
467func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
468 if cursor == nil {
469 return nil
470 }
471
472 // Calculate the correct Y offset based on current state
473 logoHeight := lipgloss.Height(m.logoRendered)
474 baseOffset := logoHeight + SplashScreenPaddingY
475
476 if m.needsAPIKey {
477 // For API key input, position at the bottom of the remaining space
478 remainingHeight := m.height - logoHeight - (SplashScreenPaddingY * 2)
479 offset := baseOffset + remainingHeight - lipgloss.Height(m.apiKeyInput.View())
480 cursor.Y += offset
481 // API key input already includes prompt in its cursor positioning
482 cursor.X = cursor.X + SplashScreenPaddingX
483 } else if m.isOnboarding {
484 // For model list, use the original calculation
485 listHeight := min(40, m.height-(SplashScreenPaddingY*2)-logoHeight-2)
486 offset := m.height - listHeight
487 cursor.Y += offset
488 // Model list doesn't have a prompt, so add padding + space for list styling
489 cursor.X = cursor.X + SplashScreenPaddingX + 1
490 }
491
492 return cursor
493}
494
495// Bindings implements SplashPage.
496func (s *splashCmp) Bindings() []key.Binding {
497 if s.needsAPIKey {
498 return []key.Binding{
499 s.keyMap.Select,
500 s.keyMap.Back,
501 }
502 } else if s.isOnboarding {
503 return []key.Binding{
504 s.keyMap.Select,
505 s.keyMap.Next,
506 s.keyMap.Previous,
507 }
508 } else if s.needsProjectInit {
509 return []key.Binding{
510 s.keyMap.Select,
511 s.keyMap.Yes,
512 s.keyMap.No,
513 s.keyMap.Tab,
514 s.keyMap.LeftRight,
515 }
516 }
517 return []key.Binding{}
518}