1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
// Splash is the splash/onboarding component. Besides the model, sizing and
// help contracts it exposes the focused input's cursor and lets callers switch
// which screen it shows.
type Splash interface {
	util.Model
	layout.Sizeable
	layout.Help
	// Cursor returns the cursor of the focused input, positioned for the
	// splash. Implementations may return nil when no input is focused.
	Cursor() *tea.Cursor
	// SetOnboarding controls whether the splash shows model selection UI
	SetOnboarding(bool)
	// SetProjectInit controls whether the splash shows project initialization prompt
	SetProjectInit(bool)

	// IsShowingAPIKey reports whether the API key input is currently shown.
	IsShowingAPIKey() bool
}
40
const (
	// SplashScreenPaddingY is the number of padding rows rendered above and
	// below the splash content.
	SplashScreenPaddingY = 1 // Padding Y for the splash screen

	// LogoGap is the extra vertical space reserved under the logo in the
	// model picker when the splash is taller than 35 rows (see logoGap).
	LogoGap = 6
)
46
// OnboardingCompleteMsg is sent when onboarding is complete: after a model has
// been chosen (and, if its provider needed one, an API key saved) or after the
// project initialization prompt has been answered.
type OnboardingCompleteMsg struct{}
49
// splashCmp is the concrete Splash implementation returned by New.
type splashCmp struct {
	width, height int
	keyMap        KeyMap
	// logoRendered caches the rendered logo; it is refreshed by SetSize.
	logoRendered string

	// State
	isOnboarding     bool // the model picker is shown
	needsProjectInit bool // the project initialization prompt is pending
	needsAPIKey      bool // the API key input is shown for selectedModel's provider
	selectedNo       bool // "Nope" is highlighted in the project initialization prompt

	// listHeight is computed by SetSize and handed to modelList.SetSize.
	listHeight  int
	modelList   *models.ModelListComponent
	apiKeyInput *models.APIKeyInput
	// selectedModel is the model waiting for its provider's API key; nil otherwise.
	selectedModel *models.ModelOption
}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "groq",
109 "openrouter",
110 }
111 for _, p := range providers {
112 if slices.Contains(simpleProviders, string(p.ID)) {
113 filteredProviders = append(filteredProviders, p)
114 }
115 }
116 s.modelList.SetProviders(filteredProviders)
117 }
118}
119
120func (s *splashCmp) SetProjectInit(needsInit bool) {
121 s.needsProjectInit = needsInit
122}
123
124// GetSize implements SplashPage.
125func (s *splashCmp) GetSize() (int, int) {
126 return s.width, s.height
127}
128
129// Init implements SplashPage.
130func (s *splashCmp) Init() tea.Cmd {
131 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
132}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136 s.height = height
137 s.width = width
138 s.logoRendered = s.logoBlock()
139 // remove padding, logo height, gap, title space
140 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
141 listWidth := min(60, width)
142 return s.modelList.SetSize(listWidth, s.listHeight)
143}
144
145// Update implements SplashPage.
146func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
147 switch msg := msg.(type) {
148 case tea.WindowSizeMsg:
149 return s, s.SetSize(msg.Width, msg.Height)
150 case tea.KeyPressMsg:
151 switch {
152 case key.Matches(msg, s.keyMap.Back):
153 if s.needsAPIKey {
154 // Go back to model selection
155 s.needsAPIKey = false
156 s.selectedModel = nil
157 return s, nil
158 }
159 case key.Matches(msg, s.keyMap.Select):
160 if s.isOnboarding && !s.needsAPIKey {
161 modelInx := s.modelList.SelectedIndex()
162 items := s.modelList.Items()
163 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
164 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
165 cmd := s.setPreferredModel(selectedItem)
166 s.isOnboarding = false
167 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
168 } else {
169 // Provider not configured, show API key input
170 s.needsAPIKey = true
171 s.selectedModel = &selectedItem
172 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
173 return s, nil
174 }
175 } else if s.needsAPIKey {
176 // Handle API key submission
177 apiKey := s.apiKeyInput.Value()
178 if apiKey != "" {
179 return s, s.saveAPIKeyAndContinue(apiKey)
180 }
181 } else if s.needsProjectInit {
182 return s, s.initializeProject()
183 }
184 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
185 if s.needsProjectInit {
186 s.selectedNo = !s.selectedNo
187 return s, nil
188 }
189 case key.Matches(msg, s.keyMap.Yes):
190 if s.needsProjectInit {
191 return s, s.initializeProject()
192 }
193 case key.Matches(msg, s.keyMap.No):
194 s.selectedNo = true
195 return s, s.initializeProject()
196 default:
197 if s.needsAPIKey {
198 u, cmd := s.apiKeyInput.Update(msg)
199 s.apiKeyInput = u.(*models.APIKeyInput)
200 return s, cmd
201 } else if s.isOnboarding {
202 u, cmd := s.modelList.Update(msg)
203 s.modelList = u
204 return s, cmd
205 }
206 }
207 case tea.PasteMsg:
208 if s.needsAPIKey {
209 u, cmd := s.apiKeyInput.Update(msg)
210 s.apiKeyInput = u.(*models.APIKeyInput)
211 return s, cmd
212 } else if s.isOnboarding {
213 var cmd tea.Cmd
214 s.modelList, cmd = s.modelList.Update(msg)
215 return s, cmd
216 }
217 }
218 return s, nil
219}
220
221func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
222 if s.selectedModel == nil {
223 return util.ReportError(fmt.Errorf("no model selected"))
224 }
225
226 cfg := config.Get()
227 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
228 if err != nil {
229 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
230 }
231
232 // Reset API key state and continue with model selection
233 s.needsAPIKey = false
234 cmd := s.setPreferredModel(*s.selectedModel)
235 s.isOnboarding = false
236 s.selectedModel = nil
237
238 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
239}
240
241func (s *splashCmp) initializeProject() tea.Cmd {
242 s.needsProjectInit = false
243
244 if err := config.MarkProjectInitialized(); err != nil {
245 return util.ReportError(err)
246 }
247 var cmds []tea.Cmd
248
249 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
250 if !s.selectedNo {
251 cmds = append(cmds,
252 util.CmdHandler(chat.SessionClearedMsg{}),
253 util.CmdHandler(chat.SendMsg{
254 Text: prompt.Initialize(),
255 }),
256 )
257 }
258 return tea.Sequence(cmds...)
259}
260
261func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
262 cfg := config.Get()
263 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
264 if model == nil {
265 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
266 }
267
268 selectedModel := config.SelectedModel{
269 Model: selectedItem.Model.ID,
270 Provider: string(selectedItem.Provider.ID),
271 ReasoningEffort: model.DefaultReasoningEffort,
272 MaxTokens: model.DefaultMaxTokens,
273 }
274
275 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
276 if err != nil {
277 return util.ReportError(err)
278 }
279
280 // Now lets automatically setup the small model
281 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
282 if err != nil {
283 return util.ReportError(err)
284 }
285 if knownProvider == nil {
286 // for local provider we just use the same model
287 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
288 if err != nil {
289 return util.ReportError(err)
290 }
291 } else {
292 smallModel := knownProvider.DefaultSmallModelID
293 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
294 // should never happen
295 if model == nil {
296 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
297 if err != nil {
298 return util.ReportError(err)
299 }
300 return nil
301 }
302 smallSelectedModel := config.SelectedModel{
303 Model: smallModel,
304 Provider: string(selectedItem.Provider.ID),
305 ReasoningEffort: model.DefaultReasoningEffort,
306 MaxTokens: model.DefaultMaxTokens,
307 }
308 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
309 if err != nil {
310 return util.ReportError(err)
311 }
312 }
313 cfg.SetupAgents()
314 return nil
315}
316
317func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
318 providers, err := config.Providers()
319 if err != nil {
320 return nil, err
321 }
322 for _, p := range providers {
323 if p.ID == providerID {
324 return &p, nil
325 }
326 }
327 return nil, nil
328}
329
330func (s *splashCmp) isProviderConfigured(providerID string) bool {
331 cfg := config.Get()
332 if _, ok := cfg.Providers[providerID]; ok {
333 return true
334 }
335 return false
336}
337
338func (s *splashCmp) View() string {
339 t := styles.CurrentTheme()
340 var content string
341 if s.needsAPIKey {
342 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
343 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
344 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
345 lipgloss.JoinVertical(
346 lipgloss.Left,
347 apiKeyView,
348 ),
349 )
350 content = lipgloss.JoinVertical(
351 lipgloss.Left,
352 s.logoRendered,
353 apiKeySelector,
354 )
355 } else if s.isOnboarding {
356 modelListView := s.modelList.View()
357 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
358 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
359 lipgloss.JoinVertical(
360 lipgloss.Left,
361 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
362 "",
363 modelListView,
364 ),
365 )
366 content = lipgloss.JoinVertical(
367 lipgloss.Left,
368 s.logoRendered,
369 modelSelector,
370 )
371 } else if s.needsProjectInit {
372 titleStyle := t.S().Base.Foreground(t.FgBase)
373 bodyStyle := t.S().Base.Foreground(t.FgMuted)
374 shortcutStyle := t.S().Base.Foreground(t.Success)
375
376 initText := lipgloss.JoinVertical(
377 lipgloss.Left,
378 titleStyle.Render("Would you like to initialize this project?"),
379 "",
380 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
381 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
382 "",
383 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
384 "",
385 bodyStyle.Render("Would you like to initialize now?"),
386 )
387
388 yesButton := core.SelectableButton(core.ButtonOpts{
389 Text: "Yep!",
390 UnderlineIndex: 0,
391 Selected: !s.selectedNo,
392 })
393
394 noButton := core.SelectableButton(core.ButtonOpts{
395 Text: "Nope",
396 UnderlineIndex: 0,
397 Selected: s.selectedNo,
398 })
399
400 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
401 infoSection := s.infoSection()
402
403 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
404
405 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
406 lipgloss.JoinVertical(
407 lipgloss.Left,
408 initText,
409 "",
410 buttons,
411 ),
412 )
413
414 content = lipgloss.JoinVertical(
415 lipgloss.Left,
416 s.logoRendered,
417 infoSection,
418 initContent,
419 )
420 } else {
421 parts := []string{
422 s.logoRendered,
423 s.infoSection(),
424 }
425 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
426 }
427
428 return t.S().Base.
429 Width(s.width).
430 Height(s.height).
431 PaddingTop(SplashScreenPaddingY).
432 PaddingBottom(SplashScreenPaddingY).
433 Render(content)
434}
435
436func (s *splashCmp) Cursor() *tea.Cursor {
437 if s.needsAPIKey {
438 cursor := s.apiKeyInput.Cursor()
439 if cursor != nil {
440 return s.moveCursor(cursor)
441 }
442 } else if s.isOnboarding {
443 cursor := s.modelList.Cursor()
444 if cursor != nil {
445 return s.moveCursor(cursor)
446 }
447 } else {
448 return nil
449 }
450 return nil
451}
452
453func (s *splashCmp) infoSection() string {
454 t := styles.CurrentTheme()
455 return t.S().Base.PaddingLeft(2).Render(
456 lipgloss.JoinVertical(
457 lipgloss.Left,
458 s.cwd(),
459 "",
460 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
461 "",
462 ),
463 )
464}
465
466func (s *splashCmp) logoBlock() string {
467 t := styles.CurrentTheme()
468 logoStyle := t.S().Base.Padding(0, 2).Width(s.width)
469 if s.width < 40 || s.height < 20 {
470 // If the width is too small, render a smaller version of the logo
471 // NOTE: 20 is not correct because [splashCmp.height] is not the
472 // *actual* window height, instead, it is the height of the splash
473 // component and that depends on other variables like compact mode and
474 // the height of the editor.
475 return logoStyle.Render(
476 logo.SmallRender(s.width - logoStyle.GetHorizontalFrameSize()),
477 )
478 }
479 return logoStyle.Render(
480 logo.Render(version.Version, false, logo.Opts{
481 FieldColor: t.Primary,
482 TitleColorA: t.Secondary,
483 TitleColorB: t.Primary,
484 CharmColor: t.Secondary,
485 VersionColor: t.Primary,
486 Width: s.width - logoStyle.GetHorizontalFrameSize(),
487 }),
488 )
489}
490
491func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
492 if cursor == nil {
493 return nil
494 }
495 // Calculate the correct Y offset based on current state
496 logoHeight := lipgloss.Height(s.logoRendered)
497 if s.needsAPIKey {
498 infoSectionHeight := lipgloss.Height(s.infoSection())
499 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
500 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
501 offset := baseOffset + remainingHeight
502 cursor.Y += offset
503 cursor.X = cursor.X + 1
504 } else if s.isOnboarding {
505 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
506 cursor.Y += offset
507 cursor.X = cursor.X + 1
508 }
509
510 return cursor
511}
512
513func (s *splashCmp) logoGap() int {
514 if s.height > 35 {
515 return LogoGap
516 }
517 return 0
518}
519
520// Bindings implements SplashPage.
521func (s *splashCmp) Bindings() []key.Binding {
522 if s.needsAPIKey {
523 return []key.Binding{
524 s.keyMap.Select,
525 s.keyMap.Back,
526 }
527 } else if s.isOnboarding {
528 return []key.Binding{
529 s.keyMap.Select,
530 s.keyMap.Next,
531 s.keyMap.Previous,
532 }
533 } else if s.needsProjectInit {
534 return []key.Binding{
535 s.keyMap.Select,
536 s.keyMap.Yes,
537 s.keyMap.No,
538 s.keyMap.Tab,
539 s.keyMap.LeftRight,
540 }
541 }
542 return []key.Binding{}
543}
544
545func (s *splashCmp) getMaxInfoWidth() int {
546 return min(s.width-2, 40) // 2 for left padding
547}
548
549func (s *splashCmp) cwd() string {
550 cwd := config.Get().WorkingDir()
551 t := styles.CurrentTheme()
552 homeDir, err := os.UserHomeDir()
553 if err == nil && cwd != homeDir {
554 cwd = strings.ReplaceAll(cwd, homeDir, "~")
555 }
556 maxWidth := s.getMaxInfoWidth()
557 return t.S().Muted.Width(maxWidth).Render(cwd)
558}
559
560func LSPList(maxWidth int) []string {
561 t := styles.CurrentTheme()
562 lspList := []string{}
563 lsp := config.Get().LSP.Sorted()
564 if len(lsp) == 0 {
565 return []string{t.S().Base.Foreground(t.Border).Render("None")}
566 }
567 for _, l := range lsp {
568 iconColor := t.Success
569 if l.LSP.Disabled {
570 iconColor = t.FgMuted
571 }
572 lspList = append(lspList,
573 core.Status(
574 core.StatusOpts{
575 IconColor: iconColor,
576 Title: l.Name,
577 Description: l.LSP.Command,
578 },
579 maxWidth,
580 ),
581 )
582 }
583 return lspList
584}
585
586func (s *splashCmp) lspBlock() string {
587 t := styles.CurrentTheme()
588 maxWidth := s.getMaxInfoWidth() / 2
589 section := t.S().Subtle.Render("LSPs")
590 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
591 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
592 lipgloss.JoinVertical(
593 lipgloss.Left,
594 lspList...,
595 ),
596 )
597}
598
599func MCPList(maxWidth int) []string {
600 t := styles.CurrentTheme()
601 mcpList := []string{}
602 mcps := config.Get().MCP.Sorted()
603 if len(mcps) == 0 {
604 return []string{t.S().Base.Foreground(t.Border).Render("None")}
605 }
606 for _, l := range mcps {
607 iconColor := t.Success
608 if l.MCP.Disabled {
609 iconColor = t.FgMuted
610 }
611 mcpList = append(mcpList,
612 core.Status(
613 core.StatusOpts{
614 IconColor: iconColor,
615 Title: l.Name,
616 Description: l.MCP.Command,
617 },
618 maxWidth,
619 ),
620 )
621 }
622 return mcpList
623}
624
625func (s *splashCmp) mcpBlock() string {
626 t := styles.CurrentTheme()
627 maxWidth := s.getMaxInfoWidth() / 2
628 section := t.S().Subtle.Render("MCPs")
629 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
630 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
631 lipgloss.JoinVertical(
632 lipgloss.Left,
633 mcpList...,
634 ),
635 )
636}
637
638func (s *splashCmp) IsShowingAPIKey() bool {
639 return s.needsAPIKey
640}