1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
27type Splash interface {
28 util.Model
29 layout.Sizeable
30 layout.Help
31 Cursor() *tea.Cursor
32 // SetOnboarding controls whether the splash shows model selection UI
33 SetOnboarding(bool)
34 // SetProjectInit controls whether the splash shows project initialization prompt
35 SetProjectInit(bool)
36
37 // Showing API key input
38 IsShowingAPIKey() bool
39}
40
// SplashScreenPaddingY is the number of blank rows rendered above and below
// the splash content.
const SplashScreenPaddingY = 1

// LogoGap is the number of rows reserved below the logo on tall terminals;
// it is applied through the logoGap method.
const LogoGap = 6

// OnboardingCompleteMsg is sent when onboarding is complete.
type OnboardingCompleteMsg struct{}
49
50type splashCmp struct {
51 width, height int
52 keyMap KeyMap
53 logoRendered string
54
55 // State
56 isOnboarding bool
57 needsProjectInit bool
58 needsAPIKey bool
59 selectedNo bool
60
61 listHeight int
62 modelList *models.ModelListComponent
63 apiKeyInput *models.APIKeyInput
64 selectedModel *models.ModelOption
65}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "groq",
109 "openrouter",
110 }
111 for _, p := range providers {
112 if slices.Contains(simpleProviders, string(p.ID)) {
113 filteredProviders = append(filteredProviders, p)
114 }
115 }
116 s.modelList.SetProviders(filteredProviders)
117 }
118}
119
120func (s *splashCmp) SetProjectInit(needsInit bool) {
121 s.needsProjectInit = needsInit
122}
123
124// GetSize implements SplashPage.
125func (s *splashCmp) GetSize() (int, int) {
126 return s.width, s.height
127}
128
129// Init implements SplashPage.
130func (s *splashCmp) Init() tea.Cmd {
131 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
132}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136 s.height = height
137 if width != s.width {
138 s.width = width
139 s.logoRendered = s.logoBlock()
140 }
141 // remove padding, logo height, gap, title space
142 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
143 listWidth := min(60, width)
144 return s.modelList.SetSize(listWidth, s.listHeight)
145}
146
147// Update implements SplashPage.
148func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
149 switch msg := msg.(type) {
150 case tea.WindowSizeMsg:
151 return s, s.SetSize(msg.Width, msg.Height)
152 case tea.KeyPressMsg:
153 switch {
154 case key.Matches(msg, s.keyMap.Back):
155 if s.needsAPIKey {
156 // Go back to model selection
157 s.needsAPIKey = false
158 s.selectedModel = nil
159 return s, nil
160 }
161 case key.Matches(msg, s.keyMap.Select):
162 if s.isOnboarding && !s.needsAPIKey {
163 modelInx := s.modelList.SelectedIndex()
164 items := s.modelList.Items()
165 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
166 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
167 cmd := s.setPreferredModel(selectedItem)
168 s.isOnboarding = false
169 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
170 } else {
171 // Provider not configured, show API key input
172 s.needsAPIKey = true
173 s.selectedModel = &selectedItem
174 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175 return s, nil
176 }
177 } else if s.needsAPIKey {
178 // Handle API key submission
179 apiKey := s.apiKeyInput.Value()
180 if apiKey != "" {
181 return s, s.saveAPIKeyAndContinue(apiKey)
182 }
183 } else if s.needsProjectInit {
184 return s, s.initializeProject()
185 }
186 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
187 if s.needsProjectInit {
188 s.selectedNo = !s.selectedNo
189 return s, nil
190 }
191 case key.Matches(msg, s.keyMap.Yes):
192 if s.needsProjectInit {
193 return s, s.initializeProject()
194 }
195 case key.Matches(msg, s.keyMap.No):
196 s.selectedNo = true
197 return s, s.initializeProject()
198 default:
199 if s.needsAPIKey {
200 u, cmd := s.apiKeyInput.Update(msg)
201 s.apiKeyInput = u.(*models.APIKeyInput)
202 return s, cmd
203 } else if s.isOnboarding {
204 u, cmd := s.modelList.Update(msg)
205 s.modelList = u
206 return s, cmd
207 }
208 }
209 case tea.PasteMsg:
210 if s.needsAPIKey {
211 u, cmd := s.apiKeyInput.Update(msg)
212 s.apiKeyInput = u.(*models.APIKeyInput)
213 return s, cmd
214 } else if s.isOnboarding {
215 var cmd tea.Cmd
216 s.modelList, cmd = s.modelList.Update(msg)
217 return s, cmd
218 }
219 }
220 return s, nil
221}
222
223func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
224 if s.selectedModel == nil {
225 return util.ReportError(fmt.Errorf("no model selected"))
226 }
227
228 cfg := config.Get()
229 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
230 if err != nil {
231 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
232 }
233
234 // Reset API key state and continue with model selection
235 s.needsAPIKey = false
236 cmd := s.setPreferredModel(*s.selectedModel)
237 s.isOnboarding = false
238 s.selectedModel = nil
239
240 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
241}
242
243func (s *splashCmp) initializeProject() tea.Cmd {
244 s.needsProjectInit = false
245
246 if err := config.MarkProjectInitialized(); err != nil {
247 return util.ReportError(err)
248 }
249 var cmds []tea.Cmd
250
251 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
252 if !s.selectedNo {
253 cmds = append(cmds,
254 util.CmdHandler(chat.SessionClearedMsg{}),
255 util.CmdHandler(chat.SendMsg{
256 Text: prompt.Initialize(),
257 }),
258 )
259 }
260 return tea.Sequence(cmds...)
261}
262
263func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
264 cfg := config.Get()
265 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
266 if model == nil {
267 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
268 }
269
270 selectedModel := config.SelectedModel{
271 Model: selectedItem.Model.ID,
272 Provider: string(selectedItem.Provider.ID),
273 ReasoningEffort: model.DefaultReasoningEffort,
274 MaxTokens: model.DefaultMaxTokens,
275 }
276
277 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
278 if err != nil {
279 return util.ReportError(err)
280 }
281
282 // Now lets automatically setup the small model
283 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
284 if err != nil {
285 return util.ReportError(err)
286 }
287 if knownProvider == nil {
288 // for local provider we just use the same model
289 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
290 if err != nil {
291 return util.ReportError(err)
292 }
293 } else {
294 smallModel := knownProvider.DefaultSmallModelID
295 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
296 // should never happen
297 if model == nil {
298 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
299 if err != nil {
300 return util.ReportError(err)
301 }
302 return nil
303 }
304 smallSelectedModel := config.SelectedModel{
305 Model: smallModel,
306 Provider: string(selectedItem.Provider.ID),
307 ReasoningEffort: model.DefaultReasoningEffort,
308 MaxTokens: model.DefaultMaxTokens,
309 }
310 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
311 if err != nil {
312 return util.ReportError(err)
313 }
314 }
315 cfg.SetupAgents()
316 return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320 providers, err := config.Providers()
321 if err != nil {
322 return nil, err
323 }
324 for _, p := range providers {
325 if p.ID == providerID {
326 return &p, nil
327 }
328 }
329 return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333 cfg := config.Get()
334 if _, ok := cfg.Providers[providerID]; ok {
335 return true
336 }
337 return false
338}
339
340func (s *splashCmp) View() string {
341 t := styles.CurrentTheme()
342 var content string
343 if s.needsAPIKey {
344 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
345 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
346 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
347 lipgloss.JoinVertical(
348 lipgloss.Left,
349 apiKeyView,
350 ),
351 )
352 content = lipgloss.JoinVertical(
353 lipgloss.Left,
354 s.logoRendered,
355 apiKeySelector,
356 )
357 } else if s.isOnboarding {
358 modelListView := s.modelList.View()
359 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
361 lipgloss.JoinVertical(
362 lipgloss.Left,
363 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
364 "",
365 modelListView,
366 ),
367 )
368 content = lipgloss.JoinVertical(
369 lipgloss.Left,
370 s.logoRendered,
371 modelSelector,
372 )
373 } else if s.needsProjectInit {
374 titleStyle := t.S().Base.Foreground(t.FgBase)
375 bodyStyle := t.S().Base.Foreground(t.FgMuted)
376 shortcutStyle := t.S().Base.Foreground(t.Success)
377
378 initText := lipgloss.JoinVertical(
379 lipgloss.Left,
380 titleStyle.Render("Would you like to initialize this project?"),
381 "",
382 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
383 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
384 "",
385 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
386 "",
387 bodyStyle.Render("Would you like to initialize now?"),
388 )
389
390 yesButton := core.SelectableButton(core.ButtonOpts{
391 Text: "Yep!",
392 UnderlineIndex: 0,
393 Selected: !s.selectedNo,
394 })
395
396 noButton := core.SelectableButton(core.ButtonOpts{
397 Text: "Nope",
398 UnderlineIndex: 0,
399 Selected: s.selectedNo,
400 })
401
402 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
403 infoSection := s.infoSection()
404
405 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
406
407 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
408 lipgloss.JoinVertical(
409 lipgloss.Left,
410 initText,
411 "",
412 buttons,
413 ),
414 )
415
416 content = lipgloss.JoinVertical(
417 lipgloss.Left,
418 s.logoRendered,
419 infoSection,
420 initContent,
421 )
422 } else {
423 parts := []string{
424 s.logoRendered,
425 s.infoSection(),
426 }
427 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
428 }
429
430 return t.S().Base.
431 Width(s.width).
432 Height(s.height).
433 PaddingTop(SplashScreenPaddingY).
434 PaddingBottom(SplashScreenPaddingY).
435 Render(content)
436}
437
438func (s *splashCmp) Cursor() *tea.Cursor {
439 if s.needsAPIKey {
440 cursor := s.apiKeyInput.Cursor()
441 if cursor != nil {
442 return s.moveCursor(cursor)
443 }
444 } else if s.isOnboarding {
445 cursor := s.modelList.Cursor()
446 if cursor != nil {
447 return s.moveCursor(cursor)
448 }
449 } else {
450 return nil
451 }
452 return nil
453}
454
455func (s *splashCmp) infoSection() string {
456 t := styles.CurrentTheme()
457 return t.S().Base.PaddingLeft(2).Render(
458 lipgloss.JoinVertical(
459 lipgloss.Left,
460 s.cwd(),
461 "",
462 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
463 "",
464 ),
465 )
466}
467
468func (s *splashCmp) logoBlock() string {
469 t := styles.CurrentTheme()
470 return t.S().Base.Padding(0, 2).Width(s.width).Render(
471 logo.Render(version.Version, false, logo.Opts{
472 FieldColor: t.Primary,
473 TitleColorA: t.Secondary,
474 TitleColorB: t.Primary,
475 CharmColor: t.Secondary,
476 VersionColor: t.Primary,
477 Width: s.width - 4,
478 }),
479 )
480}
481
482func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
483 if cursor == nil {
484 return nil
485 }
486 // Calculate the correct Y offset based on current state
487 logoHeight := lipgloss.Height(s.logoRendered)
488 if s.needsAPIKey {
489 infoSectionHeight := lipgloss.Height(s.infoSection())
490 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
491 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
492 offset := baseOffset + remainingHeight
493 cursor.Y += offset
494 cursor.X = cursor.X + 1
495 } else if s.isOnboarding {
496 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
497 cursor.Y += offset
498 cursor.X = cursor.X + 1
499 }
500
501 return cursor
502}
503
504func (s *splashCmp) logoGap() int {
505 if s.height > 35 {
506 return LogoGap
507 }
508 return 0
509}
510
511// Bindings implements SplashPage.
512func (s *splashCmp) Bindings() []key.Binding {
513 if s.needsAPIKey {
514 return []key.Binding{
515 s.keyMap.Select,
516 s.keyMap.Back,
517 }
518 } else if s.isOnboarding {
519 return []key.Binding{
520 s.keyMap.Select,
521 s.keyMap.Next,
522 s.keyMap.Previous,
523 }
524 } else if s.needsProjectInit {
525 return []key.Binding{
526 s.keyMap.Select,
527 s.keyMap.Yes,
528 s.keyMap.No,
529 s.keyMap.Tab,
530 s.keyMap.LeftRight,
531 }
532 }
533 return []key.Binding{}
534}
535
536func (s *splashCmp) getMaxInfoWidth() int {
537 return min(s.width-2, 40) // 2 for left padding
538}
539
540func (s *splashCmp) cwd() string {
541 cwd := config.Get().WorkingDir()
542 t := styles.CurrentTheme()
543 homeDir, err := os.UserHomeDir()
544 if err == nil && cwd != homeDir {
545 cwd = strings.ReplaceAll(cwd, homeDir, "~")
546 }
547 maxWidth := s.getMaxInfoWidth()
548 return t.S().Muted.Width(maxWidth).Render(cwd)
549}
550
551func LSPList(maxWidth int) []string {
552 t := styles.CurrentTheme()
553 lspList := []string{}
554 lsp := config.Get().LSP.Sorted()
555 if len(lsp) == 0 {
556 return []string{t.S().Base.Foreground(t.Border).Render("None")}
557 }
558 for _, l := range lsp {
559 iconColor := t.Success
560 if l.LSP.Disabled {
561 iconColor = t.FgMuted
562 }
563 lspList = append(lspList,
564 core.Status(
565 core.StatusOpts{
566 IconColor: iconColor,
567 Title: l.Name,
568 Description: l.LSP.Command,
569 },
570 maxWidth,
571 ),
572 )
573 }
574 return lspList
575}
576
577func (s *splashCmp) lspBlock() string {
578 t := styles.CurrentTheme()
579 maxWidth := s.getMaxInfoWidth() / 2
580 section := t.S().Subtle.Render("LSPs")
581 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
582 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
583 lipgloss.JoinVertical(
584 lipgloss.Left,
585 lspList...,
586 ),
587 )
588}
589
590func MCPList(maxWidth int) []string {
591 t := styles.CurrentTheme()
592 mcpList := []string{}
593 mcps := config.Get().MCP.Sorted()
594 if len(mcps) == 0 {
595 return []string{t.S().Base.Foreground(t.Border).Render("None")}
596 }
597 for _, l := range mcps {
598 iconColor := t.Success
599 if l.MCP.Disabled {
600 iconColor = t.FgMuted
601 }
602 mcpList = append(mcpList,
603 core.Status(
604 core.StatusOpts{
605 IconColor: iconColor,
606 Title: l.Name,
607 Description: l.MCP.Command,
608 },
609 maxWidth,
610 ),
611 )
612 }
613 return mcpList
614}
615
616func (s *splashCmp) mcpBlock() string {
617 t := styles.CurrentTheme()
618 maxWidth := s.getMaxInfoWidth() / 2
619 section := t.S().Subtle.Render("MCPs")
620 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
621 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
622 lipgloss.JoinVertical(
623 lipgloss.Left,
624 mcpList...,
625 ),
626 )
627}
628
629func (s *splashCmp) IsShowingAPIKey() bool {
630 return s.needsAPIKey
631}