1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
27type Splash interface {
28 util.Model
29 layout.Sizeable
30 layout.Help
31 Cursor() *tea.Cursor
32 // SetOnboarding controls whether the splash shows model selection UI
33 SetOnboarding(bool)
34 // SetProjectInit controls whether the splash shows project initialization prompt
35 SetProjectInit(bool)
36
37 // Showing API key input
38 IsShowingAPIKey() bool
39}
40
const (
	// SplashScreenPaddingY is the number of blank rows rendered above and
	// below the splash content.
	SplashScreenPaddingY = 1

	// LogoGap is the number of rows reserved between the logo and the model
	// picker on terminals taller than 35 rows.
	LogoGap = 6
)
46
// OnboardingCompleteMsg is emitted by the splash screen once a model has been
// chosen (and its API key saved, if one was needed) or once the project
// initialization prompt has been answered.
type OnboardingCompleteMsg struct{}
49
50type splashCmp struct {
51 width, height int
52 keyMap KeyMap
53 logoRendered string
54
55 // State
56 isOnboarding bool
57 needsProjectInit bool
58 needsAPIKey bool
59 selectedNo bool
60
61 listHeight int
62 modelList *models.ModelListComponent
63 apiKeyInput *models.APIKeyInput
64 selectedModel *models.ModelOption
65}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
// SetOnboarding toggles the model selection screen. When turning it on, the
// model list is filled with the providers offered during onboarding
// (anthropic, openai, gemini, xai and openrouter), in catalogue order.
func (s *splashCmp) SetOnboarding(onboarding bool) {
	s.isOnboarding = onboarding
	if !onboarding {
		return
	}

	catalogue, err := config.Providers()
	if err != nil {
		// NOTE(review): this setter cannot report errors, so a failed catalogue
		// load leaves the model list as it was (SetProviders is not called).
		return
	}

	onboardingIDs := []string{"anthropic", "openai", "gemini", "xai", "openrouter"}
	offered := []provider.Provider{}
	for _, candidate := range catalogue {
		if !slices.Contains(onboardingIDs, string(candidate.ID)) {
			continue
		}
		offered = append(offered, candidate)
	}
	s.modelList.SetProviders(offered)
}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120 s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125 return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135 s.height = height
136 if width != s.width {
137 s.width = width
138 s.logoRendered = s.logoBlock()
139 }
140 // remove padding, logo height, gap, title space
141 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142 listWidth := min(60, width)
143 return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148 switch msg := msg.(type) {
149 case tea.WindowSizeMsg:
150 return s, s.SetSize(msg.Width, msg.Height)
151 case tea.KeyPressMsg:
152 switch {
153 case key.Matches(msg, s.keyMap.Back):
154 if s.needsAPIKey {
155 // Go back to model selection
156 s.needsAPIKey = false
157 s.selectedModel = nil
158 return s, nil
159 }
160 case key.Matches(msg, s.keyMap.Select):
161 if s.isOnboarding && !s.needsAPIKey {
162 modelInx := s.modelList.SelectedIndex()
163 items := s.modelList.Items()
164 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166 cmd := s.setPreferredModel(selectedItem)
167 s.isOnboarding = false
168 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169 } else {
170 // Provider not configured, show API key input
171 s.needsAPIKey = true
172 s.selectedModel = &selectedItem
173 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174 return s, nil
175 }
176 } else if s.needsAPIKey {
177 // Handle API key submission
178 apiKey := s.apiKeyInput.Value()
179 if apiKey != "" {
180 return s, s.saveAPIKeyAndContinue(apiKey)
181 }
182 } else if s.needsProjectInit {
183 return s, s.initializeProject()
184 }
185 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186 if s.needsProjectInit {
187 s.selectedNo = !s.selectedNo
188 return s, nil
189 }
190 case key.Matches(msg, s.keyMap.Yes):
191 if s.needsProjectInit {
192 return s, s.initializeProject()
193 }
194 case key.Matches(msg, s.keyMap.No):
195 s.selectedNo = true
196 return s, s.initializeProject()
197 default:
198 if s.needsAPIKey {
199 u, cmd := s.apiKeyInput.Update(msg)
200 s.apiKeyInput = u.(*models.APIKeyInput)
201 return s, cmd
202 } else if s.isOnboarding {
203 u, cmd := s.modelList.Update(msg)
204 s.modelList = u
205 return s, cmd
206 }
207 }
208 case tea.PasteMsg:
209 if s.needsAPIKey {
210 u, cmd := s.apiKeyInput.Update(msg)
211 s.apiKeyInput = u.(*models.APIKeyInput)
212 return s, cmd
213 } else if s.isOnboarding {
214 var cmd tea.Cmd
215 s.modelList, cmd = s.modelList.Update(msg)
216 return s, cmd
217 }
218 }
219 return s, nil
220}
221
// saveAPIKeyAndContinue stores apiKey for the provider of the model picked on
// the selection screen. It then makes that model the preferred one and
// finishes onboarding. Errors come back as error-reporting commands.
func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
	chosen := s.selectedModel
	if chosen == nil {
		return util.ReportError(fmt.Errorf("no model selected"))
	}

	providerID := string(chosen.Provider.ID)
	if err := config.Get().SetProviderAPIKey(providerID, apiKey); err != nil {
		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
	}

	// Leave the API key screen and apply the chosen model.
	s.needsAPIKey = false
	modelCmd := s.setPreferredModel(*chosen)
	s.isOnboarding = false
	s.selectedModel = nil

	return tea.Batch(modelCmd, util.CmdHandler(OnboardingCompleteMsg{}))
}
241
242func (s *splashCmp) initializeProject() tea.Cmd {
243 s.needsProjectInit = false
244
245 if err := config.MarkProjectInitialized(); err != nil {
246 return util.ReportError(err)
247 }
248 var cmds []tea.Cmd
249
250 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
251 if !s.selectedNo {
252 cmds = append(cmds,
253 util.CmdHandler(chat.SessionClearedMsg{}),
254 util.CmdHandler(chat.SendMsg{
255 Text: prompt.Initialize(),
256 }),
257 )
258 }
259 return tea.Sequence(cmds...)
260}
261
// setPreferredModel makes selectedItem the preferred large model, picks a
// small model to go with it, and then rebuilds the agents via
// cfg.SetupAgents.
//
// The small model is the provider's DefaultSmallModelID when the provider is
// in the known provider catalogue and that model is configured. Otherwise
// (e.g. a provider missing from the catalogue) the large model is reused.
// Failures are returned as error-reporting commands; nil means success.
func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
	cfg := config.Get()
	providerID := string(selectedItem.Provider.ID)

	model := cfg.GetModel(providerID, selectedItem.Model.ID)
	if model == nil {
		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
	}

	largeModel := config.SelectedModel{
		Model:           selectedItem.Model.ID,
		Provider:        providerID,
		ReasoningEffort: model.DefaultReasoningEffort,
		MaxTokens:       model.DefaultMaxTokens,
	}
	if err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, largeModel); err != nil {
		return util.ReportError(err)
	}

	// Now pick the small model automatically.
	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
	if err != nil {
		return util.ReportError(err)
	}
	smallModel := largeModel
	if knownProvider != nil {
		smallID := knownProvider.DefaultSmallModelID
		if smallCfg := cfg.GetModel(providerID, smallID); smallCfg != nil {
			smallModel = config.SelectedModel{
				Model:           smallID,
				Provider:        providerID,
				ReasoningEffort: smallCfg.DefaultReasoningEffort,
				MaxTokens:       smallCfg.DefaultMaxTokens,
			}
		}
	}
	if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil {
		return util.ReportError(err)
	}

	// Reached on every successful path, including the fallback where the
	// provider's default small model is not configured, so the agents always
	// pick up the new selection.
	cfg.SetupAgents()
	return nil
}
317
// getProvider looks providerID up in the known provider catalogue. It returns
// a copy of the matching entry, or (nil, nil) when the provider is not in the
// catalogue. A catalogue load error is returned as-is.
func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
	catalogue, err := config.Providers()
	if err != nil {
		return nil, err
	}
	idx := slices.IndexFunc(catalogue, func(p provider.Provider) bool {
		return p.ID == providerID
	})
	if idx < 0 {
		return nil, nil
	}
	found := catalogue[idx]
	return &found, nil
}
330
331func (s *splashCmp) isProviderConfigured(providerID string) bool {
332 cfg := config.Get()
333 if _, ok := cfg.Providers[providerID]; ok {
334 return true
335 }
336 return false
337}
338
339func (s *splashCmp) View() string {
340 t := styles.CurrentTheme()
341 var content string
342 if s.needsAPIKey {
343 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
344 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
345 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
346 lipgloss.JoinVertical(
347 lipgloss.Left,
348 apiKeyView,
349 ),
350 )
351 content = lipgloss.JoinVertical(
352 lipgloss.Left,
353 s.logoRendered,
354 apiKeySelector,
355 )
356 } else if s.isOnboarding {
357 modelListView := s.modelList.View()
358 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
359 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
360 lipgloss.JoinVertical(
361 lipgloss.Left,
362 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
363 "",
364 modelListView,
365 ),
366 )
367 content = lipgloss.JoinVertical(
368 lipgloss.Left,
369 s.logoRendered,
370 modelSelector,
371 )
372 } else if s.needsProjectInit {
373 titleStyle := t.S().Base.Foreground(t.FgBase)
374 bodyStyle := t.S().Base.Foreground(t.FgMuted)
375 shortcutStyle := t.S().Base.Foreground(t.Success)
376
377 initText := lipgloss.JoinVertical(
378 lipgloss.Left,
379 titleStyle.Render("Would you like to initialize this project?"),
380 "",
381 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
382 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
383 "",
384 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
385 "",
386 bodyStyle.Render("Would you like to initialize now?"),
387 )
388
389 yesButton := core.SelectableButton(core.ButtonOpts{
390 Text: "Yep!",
391 UnderlineIndex: 0,
392 Selected: !s.selectedNo,
393 })
394
395 noButton := core.SelectableButton(core.ButtonOpts{
396 Text: "Nope",
397 UnderlineIndex: 0,
398 Selected: s.selectedNo,
399 })
400
401 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
402 infoSection := s.infoSection()
403
404 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
405
406 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
407 lipgloss.JoinVertical(
408 lipgloss.Left,
409 initText,
410 "",
411 buttons,
412 ),
413 )
414
415 content = lipgloss.JoinVertical(
416 lipgloss.Left,
417 s.logoRendered,
418 infoSection,
419 initContent,
420 )
421 } else {
422 parts := []string{
423 s.logoRendered,
424 s.infoSection(),
425 }
426 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
427 }
428
429 return t.S().Base.
430 Width(s.width).
431 Height(s.height).
432 PaddingTop(SplashScreenPaddingY).
433 PaddingBottom(SplashScreenPaddingY).
434 Render(content)
435}
436
437func (s *splashCmp) Cursor() *tea.Cursor {
438 if s.needsAPIKey {
439 cursor := s.apiKeyInput.Cursor()
440 if cursor != nil {
441 return s.moveCursor(cursor)
442 }
443 } else if s.isOnboarding {
444 cursor := s.modelList.Cursor()
445 if cursor != nil {
446 return s.moveCursor(cursor)
447 }
448 } else {
449 return nil
450 }
451 return nil
452}
453
454func (s *splashCmp) infoSection() string {
455 t := styles.CurrentTheme()
456 return t.S().Base.PaddingLeft(2).Render(
457 lipgloss.JoinVertical(
458 lipgloss.Left,
459 s.cwd(),
460 "",
461 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
462 "",
463 ),
464 )
465}
466
467func (s *splashCmp) logoBlock() string {
468 t := styles.CurrentTheme()
469 return t.S().Base.Padding(0, 2).Width(s.width).Render(
470 logo.Render(version.Version, false, logo.Opts{
471 FieldColor: t.Primary,
472 TitleColorA: t.Secondary,
473 TitleColorB: t.Primary,
474 CharmColor: t.Secondary,
475 VersionColor: t.Primary,
476 Width: s.width - 4,
477 }),
478 )
479}
480
481func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
482 if cursor == nil {
483 return nil
484 }
485 // Calculate the correct Y offset based on current state
486 logoHeight := lipgloss.Height(s.logoRendered)
487 if s.needsAPIKey {
488 infoSectionHeight := lipgloss.Height(s.infoSection())
489 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
490 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
491 offset := baseOffset + remainingHeight
492 cursor.Y += offset
493 cursor.X = cursor.X + 1
494 } else if s.isOnboarding {
495 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
496 cursor.Y += offset
497 cursor.X = cursor.X + 1
498 }
499
500 return cursor
501}
502
503func (s *splashCmp) logoGap() int {
504 if s.height > 35 {
505 return LogoGap
506 }
507 return 0
508}
509
510// Bindings implements SplashPage.
511func (s *splashCmp) Bindings() []key.Binding {
512 if s.needsAPIKey {
513 return []key.Binding{
514 s.keyMap.Select,
515 s.keyMap.Back,
516 }
517 } else if s.isOnboarding {
518 return []key.Binding{
519 s.keyMap.Select,
520 s.keyMap.Next,
521 s.keyMap.Previous,
522 }
523 } else if s.needsProjectInit {
524 return []key.Binding{
525 s.keyMap.Select,
526 s.keyMap.Yes,
527 s.keyMap.No,
528 s.keyMap.Tab,
529 s.keyMap.LeftRight,
530 }
531 }
532 return []key.Binding{}
533}
534
535func (s *splashCmp) getMaxInfoWidth() int {
536 return min(s.width-2, 40) // 2 for left padding
537}
538
// cwd renders the working directory in the muted style at the info section
// width. A path inside the user's home directory is shown with a leading "~".
// The home directory itself is shown in full.
func (s *splashCmp) cwd() string {
	dir := config.Get().WorkingDir()
	t := styles.CurrentTheme()
	if home, err := os.UserHomeDir(); err == nil && home != "" && dir != home {
		// Only shorten a leading home prefix that ends at a path separator. A
		// plain substring replace would also rewrite sibling directories
		// (home /home/al turning /home/alice into ~ice) and later occurrences.
		if rest, ok := strings.CutPrefix(dir, home); ok && strings.HasPrefix(rest, string(os.PathSeparator)) {
			dir = "~" + rest
		}
	}
	return t.S().Muted.Width(s.getMaxInfoWidth()).Render(dir)
}
549
550func LSPList(maxWidth int) []string {
551 t := styles.CurrentTheme()
552 lspList := []string{}
553 lsp := config.Get().LSP.Sorted()
554 if len(lsp) == 0 {
555 return []string{t.S().Base.Foreground(t.Border).Render("None")}
556 }
557 for _, l := range lsp {
558 iconColor := t.Success
559 if l.LSP.Disabled {
560 iconColor = t.FgMuted
561 }
562 lspList = append(lspList,
563 core.Status(
564 core.StatusOpts{
565 IconColor: iconColor,
566 Title: l.Name,
567 Description: l.LSP.Command,
568 },
569 maxWidth,
570 ),
571 )
572 }
573 return lspList
574}
575
576func (s *splashCmp) lspBlock() string {
577 t := styles.CurrentTheme()
578 maxWidth := s.getMaxInfoWidth() / 2
579 section := t.S().Subtle.Render("LSPs")
580 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
581 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
582 lipgloss.JoinVertical(
583 lipgloss.Left,
584 lspList...,
585 ),
586 )
587}
588
589func MCPList(maxWidth int) []string {
590 t := styles.CurrentTheme()
591 mcpList := []string{}
592 mcps := config.Get().MCP.Sorted()
593 if len(mcps) == 0 {
594 return []string{t.S().Base.Foreground(t.Border).Render("None")}
595 }
596 for _, l := range mcps {
597 iconColor := t.Success
598 if l.MCP.Disabled {
599 iconColor = t.FgMuted
600 }
601 mcpList = append(mcpList,
602 core.Status(
603 core.StatusOpts{
604 IconColor: iconColor,
605 Title: l.Name,
606 Description: l.MCP.Command,
607 },
608 maxWidth,
609 ),
610 )
611 }
612 return mcpList
613}
614
615func (s *splashCmp) mcpBlock() string {
616 t := styles.CurrentTheme()
617 maxWidth := s.getMaxInfoWidth() / 2
618 section := t.S().Subtle.Render("MCPs")
619 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
620 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
621 lipgloss.JoinVertical(
622 lipgloss.Left,
623 mcpList...,
624 ),
625 )
626}
627
628func (s *splashCmp) IsShowingAPIKey() bool {
629 return s.needsAPIKey
630}