1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
// Splash is the splash-screen component. Besides the logo and the project
// info section it can show the onboarding model picker, an API key prompt for
// the chosen model's provider, and the project initialization prompt.
type Splash interface {
	util.Model
	layout.Sizeable
	layout.Help
	// Cursor returns the text cursor position of the active input in
	// splash-screen coordinates (the implementation returns nil when no
	// input is active).
	Cursor() *tea.Cursor
	// SetOnboarding controls whether the splash shows model selection UI
	SetOnboarding(bool)
	// SetProjectInit controls whether the splash shows project initialization prompt
	SetProjectInit(bool)

	// IsShowingAPIKey reports whether the API key input is currently shown.
	IsShowingAPIKey() bool
}

const (
	// SplashScreenPaddingY is the number of blank rows rendered above and
	// below the splash content.
	SplashScreenPaddingY = 1 // Padding Y for the splash screen

	// LogoGap is the extra vertical space reserved below the logo for the
	// model picker when the terminal is tall enough (see logoGap).
	LogoGap = 6
)

// OnboardingCompleteMsg is emitted when the splash finishes an onboarding
// step: after a model is chosen (and its provider API key saved, if one was
// needed) or after the project initialization prompt is answered.
type OnboardingCompleteMsg struct{}
49
// splashCmp implements Splash. The state flags decide which screen View draws;
// they are checked in the order needsAPIKey, isOnboarding, needsProjectInit,
// with the plain logo + info screen as the fallback.
type splashCmp struct {
	width, height int
	keyMap        KeyMap
	// logoRendered caches the rendered logo; SetSize re-renders it only when
	// the width changes.
	logoRendered string

	// State
	// isOnboarding: the onboarding model picker is active.
	isOnboarding bool
	// needsProjectInit: the project initialization prompt is pending.
	needsProjectInit bool
	// needsAPIKey: the API key prompt for selectedModel's provider is shown.
	needsAPIKey bool
	// selectedNo: "Nope" is highlighted on the initialization prompt.
	selectedNo bool

	// listHeight is the number of rows given to modelList, computed in SetSize.
	listHeight  int
	modelList   *models.ModelListComponent
	apiKeyInput *models.APIKeyInput
	// selectedModel is the model waiting for its provider's API key; nil
	// outside the API key step.
	selectedModel *models.ModelOption
}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "openrouter",
109 }
110 for _, p := range providers {
111 if slices.Contains(simpleProviders, string(p.ID)) {
112 filteredProviders = append(filteredProviders, p)
113 }
114 }
115 s.modelList.SetProviders(filteredProviders)
116 }
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120 s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125 return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135 s.height = height
136 if width != s.width {
137 s.width = width
138 s.logoRendered = s.logoBlock()
139 }
140 // remove padding, logo height, gap, title space
141 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142 listWidth := min(60, width)
143 return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148 switch msg := msg.(type) {
149 case tea.WindowSizeMsg:
150 return s, s.SetSize(msg.Width, msg.Height)
151 case tea.KeyPressMsg:
152 switch {
153 case key.Matches(msg, s.keyMap.Back):
154 if s.needsAPIKey {
155 // Go back to model selection
156 s.needsAPIKey = false
157 s.selectedModel = nil
158 return s, nil
159 }
160 case key.Matches(msg, s.keyMap.Select):
161 if s.isOnboarding && !s.needsAPIKey {
162 modelInx := s.modelList.SelectedIndex()
163 items := s.modelList.Items()
164 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166 cmd := s.setPreferredModel(selectedItem)
167 s.isOnboarding = false
168 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169 } else {
170 // Provider not configured, show API key input
171 s.needsAPIKey = true
172 s.selectedModel = &selectedItem
173 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174 return s, nil
175 }
176 } else if s.needsAPIKey {
177 // Handle API key submission
178 apiKey := s.apiKeyInput.Value()
179 if apiKey != "" {
180 return s, s.saveAPIKeyAndContinue(apiKey)
181 }
182 } else if s.needsProjectInit {
183 return s, s.initializeProject()
184 }
185 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186 if s.needsProjectInit {
187 s.selectedNo = !s.selectedNo
188 return s, nil
189 }
190 case key.Matches(msg, s.keyMap.Yes):
191 if s.needsProjectInit {
192 return s, s.initializeProject()
193 }
194 case key.Matches(msg, s.keyMap.No):
195 s.selectedNo = true
196 return s, s.initializeProject()
197 default:
198 if s.needsAPIKey {
199 u, cmd := s.apiKeyInput.Update(msg)
200 s.apiKeyInput = u.(*models.APIKeyInput)
201 return s, cmd
202 } else if s.isOnboarding {
203 u, cmd := s.modelList.Update(msg)
204 s.modelList = u
205 return s, cmd
206 }
207 }
208 case tea.PasteMsg:
209 if s.needsAPIKey {
210 u, cmd := s.apiKeyInput.Update(msg)
211 s.apiKeyInput = u.(*models.APIKeyInput)
212 return s, cmd
213 } else if s.isOnboarding {
214 var cmd tea.Cmd
215 s.modelList, cmd = s.modelList.Update(msg)
216 return s, cmd
217 }
218 }
219 return s, nil
220}
221
222func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
223 if s.selectedModel == nil {
224 return util.ReportError(fmt.Errorf("no model selected"))
225 }
226
227 cfg := config.Get()
228 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
229 if err != nil {
230 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
231 }
232
233 // Reset API key state and continue with model selection
234 s.needsAPIKey = false
235 cmd := s.setPreferredModel(*s.selectedModel)
236 s.isOnboarding = false
237 s.selectedModel = nil
238
239 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
240}
241
242func (s *splashCmp) initializeProject() tea.Cmd {
243 s.needsProjectInit = false
244
245 if err := config.MarkProjectInitialized(); err != nil {
246 return util.ReportError(err)
247 }
248 var cmds []tea.Cmd
249
250 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
251 if !s.selectedNo {
252 cmds = append(cmds,
253 util.CmdHandler(chat.SessionClearedMsg{}),
254 util.CmdHandler(chat.SendMsg{
255 Text: prompt.Initialize(),
256 }),
257 )
258 }
259 return tea.Sequence(cmds...)
260}
261
262func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
263 cfg := config.Get()
264 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
265 if model == nil {
266 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
267 }
268
269 selectedModel := config.SelectedModel{
270 Model: selectedItem.Model.ID,
271 Provider: string(selectedItem.Provider.ID),
272 ReasoningEffort: model.DefaultReasoningEffort,
273 MaxTokens: model.DefaultMaxTokens,
274 }
275
276 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
277 if err != nil {
278 return util.ReportError(err)
279 }
280
281 // Now lets automatically setup the small model
282 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
283 if err != nil {
284 return util.ReportError(err)
285 }
286 if knownProvider == nil {
287 // for local provider we just use the same model
288 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
289 if err != nil {
290 return util.ReportError(err)
291 }
292 } else {
293 smallModel := knownProvider.DefaultSmallModelID
294 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
295 // should never happen
296 if model == nil {
297 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
298 if err != nil {
299 return util.ReportError(err)
300 }
301 return nil
302 }
303 smallSelectedModel := config.SelectedModel{
304 Model: smallModel,
305 Provider: string(selectedItem.Provider.ID),
306 ReasoningEffort: model.DefaultReasoningEffort,
307 MaxTokens: model.DefaultMaxTokens,
308 }
309 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
310 if err != nil {
311 return util.ReportError(err)
312 }
313 }
314 return nil
315}
316
317func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
318 providers, err := config.Providers()
319 if err != nil {
320 return nil, err
321 }
322 for _, p := range providers {
323 if p.ID == providerID {
324 return &p, nil
325 }
326 }
327 return nil, nil
328}
329
330func (s *splashCmp) isProviderConfigured(providerID string) bool {
331 cfg := config.Get()
332 if _, ok := cfg.Providers[providerID]; ok {
333 return true
334 }
335 return false
336}
337
338func (s *splashCmp) View() string {
339 t := styles.CurrentTheme()
340 var content string
341 if s.needsAPIKey {
342 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
343 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
344 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
345 lipgloss.JoinVertical(
346 lipgloss.Left,
347 apiKeyView,
348 ),
349 )
350 content = lipgloss.JoinVertical(
351 lipgloss.Left,
352 s.logoRendered,
353 apiKeySelector,
354 )
355 } else if s.isOnboarding {
356 modelListView := s.modelList.View()
357 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
358 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
359 lipgloss.JoinVertical(
360 lipgloss.Left,
361 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
362 "",
363 modelListView,
364 ),
365 )
366 content = lipgloss.JoinVertical(
367 lipgloss.Left,
368 s.logoRendered,
369 modelSelector,
370 )
371 } else if s.needsProjectInit {
372 titleStyle := t.S().Base.Foreground(t.FgBase)
373 bodyStyle := t.S().Base.Foreground(t.FgMuted)
374 shortcutStyle := t.S().Base.Foreground(t.Success)
375
376 initText := lipgloss.JoinVertical(
377 lipgloss.Left,
378 titleStyle.Render("Would you like to initialize this project?"),
379 "",
380 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
381 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
382 "",
383 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
384 "",
385 bodyStyle.Render("Would you like to initialize now?"),
386 )
387
388 yesButton := core.SelectableButton(core.ButtonOpts{
389 Text: "Yep!",
390 UnderlineIndex: 0,
391 Selected: !s.selectedNo,
392 })
393
394 noButton := core.SelectableButton(core.ButtonOpts{
395 Text: "Nope",
396 UnderlineIndex: 0,
397 Selected: s.selectedNo,
398 })
399
400 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
401 infoSection := s.infoSection()
402
403 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
404
405 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
406 lipgloss.JoinVertical(
407 lipgloss.Left,
408 initText,
409 "",
410 buttons,
411 ),
412 )
413
414 content = lipgloss.JoinVertical(
415 lipgloss.Left,
416 s.logoRendered,
417 infoSection,
418 initContent,
419 )
420 } else {
421 parts := []string{
422 s.logoRendered,
423 s.infoSection(),
424 }
425 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
426 }
427
428 return t.S().Base.
429 Width(s.width).
430 Height(s.height).
431 PaddingTop(SplashScreenPaddingY).
432 PaddingBottom(SplashScreenPaddingY).
433 Render(content)
434}
435
436func (s *splashCmp) Cursor() *tea.Cursor {
437 if s.needsAPIKey {
438 cursor := s.apiKeyInput.Cursor()
439 if cursor != nil {
440 return s.moveCursor(cursor)
441 }
442 } else if s.isOnboarding {
443 cursor := s.modelList.Cursor()
444 if cursor != nil {
445 return s.moveCursor(cursor)
446 }
447 } else {
448 return nil
449 }
450 return nil
451}
452
453func (s *splashCmp) infoSection() string {
454 t := styles.CurrentTheme()
455 return t.S().Base.PaddingLeft(2).Render(
456 lipgloss.JoinVertical(
457 lipgloss.Left,
458 s.cwd(),
459 "",
460 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
461 "",
462 ),
463 )
464}
465
466func (s *splashCmp) logoBlock() string {
467 t := styles.CurrentTheme()
468 return t.S().Base.Padding(0, 2).Width(s.width).Render(
469 logo.Render(version.Version, false, logo.Opts{
470 FieldColor: t.Primary,
471 TitleColorA: t.Secondary,
472 TitleColorB: t.Primary,
473 CharmColor: t.Secondary,
474 VersionColor: t.Primary,
475 Width: s.width - 4,
476 }),
477 )
478}
479
480func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
481 if cursor == nil {
482 return nil
483 }
484 // Calculate the correct Y offset based on current state
485 logoHeight := lipgloss.Height(s.logoRendered)
486 if s.needsAPIKey {
487 infoSectionHeight := lipgloss.Height(s.infoSection())
488 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
489 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
490 offset := baseOffset + remainingHeight
491 cursor.Y += offset
492 cursor.X = cursor.X + 1
493 } else if s.isOnboarding {
494 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
495 cursor.Y += offset
496 cursor.X = cursor.X + 1
497 }
498
499 return cursor
500}
501
502func (s *splashCmp) logoGap() int {
503 if s.height > 35 {
504 return LogoGap
505 }
506 return 0
507}
508
509// Bindings implements SplashPage.
510func (s *splashCmp) Bindings() []key.Binding {
511 if s.needsAPIKey {
512 return []key.Binding{
513 s.keyMap.Select,
514 s.keyMap.Back,
515 }
516 } else if s.isOnboarding {
517 return []key.Binding{
518 s.keyMap.Select,
519 s.keyMap.Next,
520 s.keyMap.Previous,
521 }
522 } else if s.needsProjectInit {
523 return []key.Binding{
524 s.keyMap.Select,
525 s.keyMap.Yes,
526 s.keyMap.No,
527 s.keyMap.Tab,
528 s.keyMap.LeftRight,
529 }
530 }
531 return []key.Binding{}
532}
533
534func (s *splashCmp) getMaxInfoWidth() int {
535 return min(s.width-2, 40) // 2 for left padding
536}
537
538func (s *splashCmp) cwd() string {
539 cwd := config.Get().WorkingDir()
540 t := styles.CurrentTheme()
541 homeDir, err := os.UserHomeDir()
542 if err == nil && cwd != homeDir {
543 cwd = strings.ReplaceAll(cwd, homeDir, "~")
544 }
545 maxWidth := s.getMaxInfoWidth()
546 return t.S().Muted.Width(maxWidth).Render(cwd)
547}
548
549func LSPList(maxWidth int) []string {
550 t := styles.CurrentTheme()
551 lspList := []string{}
552 lsp := config.Get().LSP.Sorted()
553 if len(lsp) == 0 {
554 return []string{t.S().Base.Foreground(t.Border).Render("None")}
555 }
556 for _, l := range lsp {
557 iconColor := t.Success
558 if l.LSP.Disabled {
559 iconColor = t.FgMuted
560 }
561 lspList = append(lspList,
562 core.Status(
563 core.StatusOpts{
564 IconColor: iconColor,
565 Title: l.Name,
566 Description: l.LSP.Command,
567 },
568 maxWidth,
569 ),
570 )
571 }
572 return lspList
573}
574
575func (s *splashCmp) lspBlock() string {
576 t := styles.CurrentTheme()
577 maxWidth := s.getMaxInfoWidth() / 2
578 section := t.S().Subtle.Render("LSPs")
579 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
580 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
581 lipgloss.JoinVertical(
582 lipgloss.Left,
583 lspList...,
584 ),
585 )
586}
587
588func MCPList(maxWidth int) []string {
589 t := styles.CurrentTheme()
590 mcpList := []string{}
591 mcps := config.Get().MCP.Sorted()
592 if len(mcps) == 0 {
593 return []string{t.S().Base.Foreground(t.Border).Render("None")}
594 }
595 for _, l := range mcps {
596 iconColor := t.Success
597 if l.MCP.Disabled {
598 iconColor = t.FgMuted
599 }
600 mcpList = append(mcpList,
601 core.Status(
602 core.StatusOpts{
603 IconColor: iconColor,
604 Title: l.Name,
605 Description: l.MCP.Command,
606 },
607 maxWidth,
608 ),
609 )
610 }
611 return mcpList
612}
613
614func (s *splashCmp) mcpBlock() string {
615 t := styles.CurrentTheme()
616 maxWidth := s.getMaxInfoWidth() / 2
617 section := t.S().Subtle.Render("MCPs")
618 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
619 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
620 lipgloss.JoinVertical(
621 lipgloss.Left,
622 mcpList...,
623 ),
624 )
625}
626
627func (s *splashCmp) IsShowingAPIKey() bool {
628 return s.needsAPIKey
629}