1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
// Splash is the component rendered on the splash screen. Besides the logo and
// the info section it can show the onboarding model selection, the API key
// input and the project initialization prompt.
type Splash interface {
	util.Model
	layout.Sizeable
	layout.Help
	// Cursor returns the cursor of the input that is currently showing,
	// translated to splash coordinates, or nil when there is none.
	Cursor() *tea.Cursor
	// SetOnboarding controls whether the splash shows model selection UI
	SetOnboarding(bool)
	// SetProjectInit controls whether the splash shows project initialization prompt
	SetProjectInit(bool)

	// IsShowingAPIKey reports whether the API key input is currently showing.
	IsShowingAPIKey() bool
}
40
const (
	// SplashScreenPaddingY is the vertical padding applied at the top and at
	// the bottom of the splash screen.
	SplashScreenPaddingY = 1 // Padding Y for the splash screen

	// LogoGap is the vertical gap accounted for in the model list height and
	// in the cursor offset when the splash is taller than 35 rows.
	LogoGap = 6
)
46
// OnboardingCompleteMsg is sent when onboarding is complete. Within this
// component it is emitted once a model has been chosen (directly or after an
// API key was saved) and once the project initialization prompt is answered.
type OnboardingCompleteMsg struct{}
49
// splashCmp is the Splash implementation returned by New.
type splashCmp struct {
	width, height int
	keyMap        KeyMap
	// logoRendered caches the rendered logo; SetSize refreshes it when the
	// width or the small-screen state changes.
	logoRendered string

	// State
	isOnboarding     bool // the model selection list is showing
	needsProjectInit bool // the project initialization prompt is pending
	needsAPIKey      bool // the API key input is showing for selectedModel
	selectedNo       bool // the "Nope" button is the selected one

	listHeight    int
	modelList     *models.ModelListComponent
	apiKeyInput   *models.APIKeyInput
	selectedModel *models.ModelOption // model waiting for an API key, nil otherwise
}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "groq",
109 "openrouter",
110 }
111 for _, p := range providers {
112 if slices.Contains(simpleProviders, string(p.ID)) {
113 filteredProviders = append(filteredProviders, p)
114 }
115 }
116 s.modelList.SetProviders(filteredProviders)
117 }
118}
119
// SetProjectInit records whether the project initialization prompt still
// needs to be shown.
func (s *splashCmp) SetProjectInit(needsInit bool) {
	s.needsProjectInit = needsInit
}
123
// GetSize returns the current width and height of the splash component.
func (s *splashCmp) GetSize() (int, int) {
	return s.width, s.height
}
128
// Init returns the batched initialization commands of the model list and of
// the API key input.
func (s *splashCmp) Init() tea.Cmd {
	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136 wasSmallScreen := s.isSmallScreen()
137 rerenderLogo := width != s.width
138 s.height = height
139 s.width = width
140 if rerenderLogo || wasSmallScreen != s.isSmallScreen() {
141 s.logoRendered = s.logoBlock()
142 }
143 // remove padding, logo height, gap, title space
144 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
145 listWidth := min(60, width)
146 return s.modelList.SetSize(listWidth, s.listHeight)
147}
148
149// Update implements SplashPage.
150func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
151 switch msg := msg.(type) {
152 case tea.WindowSizeMsg:
153 return s, s.SetSize(msg.Width, msg.Height)
154 case tea.KeyPressMsg:
155 switch {
156 case key.Matches(msg, s.keyMap.Back):
157 if s.needsAPIKey {
158 // Go back to model selection
159 s.needsAPIKey = false
160 s.selectedModel = nil
161 return s, nil
162 }
163 case key.Matches(msg, s.keyMap.Select):
164 if s.isOnboarding && !s.needsAPIKey {
165 modelInx := s.modelList.SelectedIndex()
166 items := s.modelList.Items()
167 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
168 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
169 cmd := s.setPreferredModel(selectedItem)
170 s.isOnboarding = false
171 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
172 } else {
173 // Provider not configured, show API key input
174 s.needsAPIKey = true
175 s.selectedModel = &selectedItem
176 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
177 return s, nil
178 }
179 } else if s.needsAPIKey {
180 // Handle API key submission
181 apiKey := s.apiKeyInput.Value()
182 if apiKey != "" {
183 return s, s.saveAPIKeyAndContinue(apiKey)
184 }
185 } else if s.needsProjectInit {
186 return s, s.initializeProject()
187 }
188 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
189 if s.needsProjectInit {
190 s.selectedNo = !s.selectedNo
191 return s, nil
192 }
193 case key.Matches(msg, s.keyMap.Yes):
194 if s.needsProjectInit {
195 return s, s.initializeProject()
196 }
197 case key.Matches(msg, s.keyMap.No):
198 s.selectedNo = true
199 return s, s.initializeProject()
200 default:
201 if s.needsAPIKey {
202 u, cmd := s.apiKeyInput.Update(msg)
203 s.apiKeyInput = u.(*models.APIKeyInput)
204 return s, cmd
205 } else if s.isOnboarding {
206 u, cmd := s.modelList.Update(msg)
207 s.modelList = u
208 return s, cmd
209 }
210 }
211 case tea.PasteMsg:
212 if s.needsAPIKey {
213 u, cmd := s.apiKeyInput.Update(msg)
214 s.apiKeyInput = u.(*models.APIKeyInput)
215 return s, cmd
216 } else if s.isOnboarding {
217 var cmd tea.Cmd
218 s.modelList, cmd = s.modelList.Update(msg)
219 return s, cmd
220 }
221 }
222 return s, nil
223}
224
225func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
226 if s.selectedModel == nil {
227 return util.ReportError(fmt.Errorf("no model selected"))
228 }
229
230 cfg := config.Get()
231 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
232 if err != nil {
233 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
234 }
235
236 // Reset API key state and continue with model selection
237 s.needsAPIKey = false
238 cmd := s.setPreferredModel(*s.selectedModel)
239 s.isOnboarding = false
240 s.selectedModel = nil
241
242 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
243}
244
245func (s *splashCmp) initializeProject() tea.Cmd {
246 s.needsProjectInit = false
247
248 if err := config.MarkProjectInitialized(); err != nil {
249 return util.ReportError(err)
250 }
251 var cmds []tea.Cmd
252
253 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
254 if !s.selectedNo {
255 cmds = append(cmds,
256 util.CmdHandler(chat.SessionClearedMsg{}),
257 util.CmdHandler(chat.SendMsg{
258 Text: prompt.Initialize(),
259 }),
260 )
261 }
262 return tea.Sequence(cmds...)
263}
264
265func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
266 cfg := config.Get()
267 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
268 if model == nil {
269 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
270 }
271
272 selectedModel := config.SelectedModel{
273 Model: selectedItem.Model.ID,
274 Provider: string(selectedItem.Provider.ID),
275 ReasoningEffort: model.DefaultReasoningEffort,
276 MaxTokens: model.DefaultMaxTokens,
277 }
278
279 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
280 if err != nil {
281 return util.ReportError(err)
282 }
283
284 // Now lets automatically setup the small model
285 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
286 if err != nil {
287 return util.ReportError(err)
288 }
289 if knownProvider == nil {
290 // for local provider we just use the same model
291 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
292 if err != nil {
293 return util.ReportError(err)
294 }
295 } else {
296 smallModel := knownProvider.DefaultSmallModelID
297 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
298 // should never happen
299 if model == nil {
300 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
301 if err != nil {
302 return util.ReportError(err)
303 }
304 return nil
305 }
306 smallSelectedModel := config.SelectedModel{
307 Model: smallModel,
308 Provider: string(selectedItem.Provider.ID),
309 ReasoningEffort: model.DefaultReasoningEffort,
310 MaxTokens: model.DefaultMaxTokens,
311 }
312 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
313 if err != nil {
314 return util.ReportError(err)
315 }
316 }
317 cfg.SetupAgents()
318 return nil
319}
320
321func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
322 providers, err := config.Providers()
323 if err != nil {
324 return nil, err
325 }
326 for _, p := range providers {
327 if p.ID == providerID {
328 return &p, nil
329 }
330 }
331 return nil, nil
332}
333
334func (s *splashCmp) isProviderConfigured(providerID string) bool {
335 cfg := config.Get()
336 if _, ok := cfg.Providers[providerID]; ok {
337 return true
338 }
339 return false
340}
341
342func (s *splashCmp) View() string {
343 t := styles.CurrentTheme()
344 var content string
345 if s.needsAPIKey {
346 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
347 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
348 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
349 lipgloss.JoinVertical(
350 lipgloss.Left,
351 apiKeyView,
352 ),
353 )
354 content = lipgloss.JoinVertical(
355 lipgloss.Left,
356 s.logoRendered,
357 apiKeySelector,
358 )
359 } else if s.isOnboarding {
360 modelListView := s.modelList.View()
361 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
362 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
363 lipgloss.JoinVertical(
364 lipgloss.Left,
365 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
366 "",
367 modelListView,
368 ),
369 )
370 content = lipgloss.JoinVertical(
371 lipgloss.Left,
372 s.logoRendered,
373 modelSelector,
374 )
375 } else if s.needsProjectInit {
376 titleStyle := t.S().Base.Foreground(t.FgBase)
377 bodyStyle := t.S().Base.Foreground(t.FgMuted)
378 shortcutStyle := t.S().Base.Foreground(t.Success)
379
380 initText := lipgloss.JoinVertical(
381 lipgloss.Left,
382 titleStyle.Render("Would you like to initialize this project?"),
383 "",
384 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
385 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
386 "",
387 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
388 "",
389 bodyStyle.Render("Would you like to initialize now?"),
390 )
391
392 yesButton := core.SelectableButton(core.ButtonOpts{
393 Text: "Yep!",
394 UnderlineIndex: 0,
395 Selected: !s.selectedNo,
396 })
397
398 noButton := core.SelectableButton(core.ButtonOpts{
399 Text: "Nope",
400 UnderlineIndex: 0,
401 Selected: s.selectedNo,
402 })
403
404 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
405 infoSection := s.infoSection()
406
407 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
408
409 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
410 lipgloss.JoinVertical(
411 lipgloss.Left,
412 initText,
413 "",
414 buttons,
415 ),
416 )
417
418 content = lipgloss.JoinVertical(
419 lipgloss.Left,
420 s.logoRendered,
421 infoSection,
422 initContent,
423 )
424 } else {
425 parts := []string{
426 s.logoRendered,
427 s.infoSection(),
428 }
429 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
430 }
431
432 return t.S().Base.
433 Width(s.width).
434 Height(s.height).
435 PaddingTop(SplashScreenPaddingY).
436 PaddingBottom(SplashScreenPaddingY).
437 Render(content)
438}
439
440func (s *splashCmp) Cursor() *tea.Cursor {
441 if s.needsAPIKey {
442 cursor := s.apiKeyInput.Cursor()
443 if cursor != nil {
444 return s.moveCursor(cursor)
445 }
446 } else if s.isOnboarding {
447 cursor := s.modelList.Cursor()
448 if cursor != nil {
449 return s.moveCursor(cursor)
450 }
451 } else {
452 return nil
453 }
454 return nil
455}
456
457func (s *splashCmp) isSmallScreen() bool {
458 // Consider a screen small if either the width is less than 40 or if the
459 // height is less than 20
460 return s.width < 40 || s.height < 20
461}
462
463func (s *splashCmp) infoSection() string {
464 t := styles.CurrentTheme()
465 infoStyle := t.S().Base.PaddingLeft(2)
466 if s.isSmallScreen() {
467 infoStyle = infoStyle.MarginTop(1)
468 }
469 return infoStyle.Render(
470 lipgloss.JoinVertical(
471 lipgloss.Left,
472 s.cwd(),
473 "",
474 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
475 "",
476 ),
477 )
478}
479
480func (s *splashCmp) logoBlock() string {
481 t := styles.CurrentTheme()
482 logoStyle := t.S().Base.Padding(0, 2).Width(s.width)
483 if s.isSmallScreen() {
484 // If the width is too small, render a smaller version of the logo
485 // NOTE: 20 is not correct because [splashCmp.height] is not the
486 // *actual* window height, instead, it is the height of the splash
487 // component and that depends on other variables like compact mode and
488 // the height of the editor.
489 return logoStyle.Render(
490 logo.SmallRender(s.width - logoStyle.GetHorizontalFrameSize()),
491 )
492 }
493 return logoStyle.Render(
494 logo.Render(version.Version, false, logo.Opts{
495 FieldColor: t.Primary,
496 TitleColorA: t.Secondary,
497 TitleColorB: t.Primary,
498 CharmColor: t.Secondary,
499 VersionColor: t.Primary,
500 Width: s.width - logoStyle.GetHorizontalFrameSize(),
501 }),
502 )
503}
504
505func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
506 if cursor == nil {
507 return nil
508 }
509 // Calculate the correct Y offset based on current state
510 logoHeight := lipgloss.Height(s.logoRendered)
511 if s.needsAPIKey {
512 infoSectionHeight := lipgloss.Height(s.infoSection())
513 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
514 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
515 offset := baseOffset + remainingHeight
516 cursor.Y += offset
517 cursor.X = cursor.X + 1
518 } else if s.isOnboarding {
519 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
520 cursor.Y += offset
521 cursor.X = cursor.X + 1
522 }
523
524 return cursor
525}
526
527func (s *splashCmp) logoGap() int {
528 if s.height > 35 {
529 return LogoGap
530 }
531 return 0
532}
533
534// Bindings implements SplashPage.
535func (s *splashCmp) Bindings() []key.Binding {
536 if s.needsAPIKey {
537 return []key.Binding{
538 s.keyMap.Select,
539 s.keyMap.Back,
540 }
541 } else if s.isOnboarding {
542 return []key.Binding{
543 s.keyMap.Select,
544 s.keyMap.Next,
545 s.keyMap.Previous,
546 }
547 } else if s.needsProjectInit {
548 return []key.Binding{
549 s.keyMap.Select,
550 s.keyMap.Yes,
551 s.keyMap.No,
552 s.keyMap.Tab,
553 s.keyMap.LeftRight,
554 }
555 }
556 return []key.Binding{}
557}
558
559func (s *splashCmp) getMaxInfoWidth() int {
560 return min(s.width-2, 40) // 2 for left padding
561}
562
563func (s *splashCmp) cwd() string {
564 cwd := config.Get().WorkingDir()
565 t := styles.CurrentTheme()
566 homeDir, err := os.UserHomeDir()
567 if err == nil && cwd != homeDir {
568 cwd = strings.ReplaceAll(cwd, homeDir, "~")
569 }
570 maxWidth := s.getMaxInfoWidth()
571 return t.S().Muted.Width(maxWidth).Render(cwd)
572}
573
574func LSPList(maxWidth int) []string {
575 t := styles.CurrentTheme()
576 lspList := []string{}
577 lsp := config.Get().LSP.Sorted()
578 if len(lsp) == 0 {
579 return []string{t.S().Base.Foreground(t.Border).Render("None")}
580 }
581 for _, l := range lsp {
582 iconColor := t.Success
583 if l.LSP.Disabled {
584 iconColor = t.FgMuted
585 }
586 lspList = append(lspList,
587 core.Status(
588 core.StatusOpts{
589 IconColor: iconColor,
590 Title: l.Name,
591 Description: l.LSP.Command,
592 },
593 maxWidth,
594 ),
595 )
596 }
597 return lspList
598}
599
600func (s *splashCmp) lspBlock() string {
601 t := styles.CurrentTheme()
602 maxWidth := s.getMaxInfoWidth() / 2
603 section := t.S().Subtle.Render("LSPs")
604 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
605 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
606 lipgloss.JoinVertical(
607 lipgloss.Left,
608 lspList...,
609 ),
610 )
611}
612
613func MCPList(maxWidth int) []string {
614 t := styles.CurrentTheme()
615 mcpList := []string{}
616 mcps := config.Get().MCP.Sorted()
617 if len(mcps) == 0 {
618 return []string{t.S().Base.Foreground(t.Border).Render("None")}
619 }
620 for _, l := range mcps {
621 iconColor := t.Success
622 if l.MCP.Disabled {
623 iconColor = t.FgMuted
624 }
625 mcpList = append(mcpList,
626 core.Status(
627 core.StatusOpts{
628 IconColor: iconColor,
629 Title: l.Name,
630 Description: l.MCP.Command,
631 },
632 maxWidth,
633 ),
634 )
635 }
636 return mcpList
637}
638
639func (s *splashCmp) mcpBlock() string {
640 t := styles.CurrentTheme()
641 maxWidth := s.getMaxInfoWidth() / 2
642 section := t.S().Subtle.Render("MCPs")
643 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
644 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
645 lipgloss.JoinVertical(
646 lipgloss.Left,
647 mcpList...,
648 ),
649 )
650}
651
// IsShowingAPIKey reports whether the API key input is currently showing.
func (s *splashCmp) IsShowingAPIKey() bool {
	return s.needsAPIKey
}