1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
27type Splash interface {
28 util.Model
29 layout.Sizeable
30 layout.Help
31 Cursor() *tea.Cursor
32 // SetOnboarding controls whether the splash shows model selection UI
33 SetOnboarding(bool)
34 // SetProjectInit controls whether the splash shows project initialization prompt
35 SetProjectInit(bool)
36
37 // Showing API key input
38 IsShowingAPIKey() bool
39}
40
// SplashScreenPaddingY is the number of blank rows rendered above and below
// the splash content.
const SplashScreenPaddingY = 1

// LogoGap is the number of extra rows reserved between the logo and the
// model picker on terminals taller than 35 rows.
const LogoGap = 6

// OnboardingCompleteMsg is emitted once the splash has finished its
// onboarding or project-initialization flow.
type OnboardingCompleteMsg struct{}
49
50type splashCmp struct {
51 width, height int
52 keyMap KeyMap
53 logoRendered string
54
55 // State
56 isOnboarding bool
57 needsProjectInit bool
58 needsAPIKey bool
59 selectedNo bool
60
61 listHeight int
62 modelList *models.ModelListComponent
63 apiKeyInput *models.APIKeyInput
64 selectedModel *models.ModelOption
65}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "openrouter",
109 }
110 for _, p := range providers {
111 if slices.Contains(simpleProviders, string(p.ID)) {
112 filteredProviders = append(filteredProviders, p)
113 }
114 }
115 s.modelList.SetProviders(filteredProviders)
116 }
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120 s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125 return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135 s.height = height
136 if width != s.width {
137 s.width = width
138 s.logoRendered = s.logoBlock()
139 }
140 // remove padding, logo height, gap, title space
141 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142 listWidth := min(60, width)
143 return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148 switch msg := msg.(type) {
149 case tea.WindowSizeMsg:
150 return s, s.SetSize(msg.Width, msg.Height)
151 case tea.KeyPressMsg:
152 switch {
153 case key.Matches(msg, s.keyMap.Back):
154 if s.needsAPIKey {
155 // Go back to model selection
156 s.needsAPIKey = false
157 s.selectedModel = nil
158 return s, nil
159 }
160 case key.Matches(msg, s.keyMap.Select):
161 if s.isOnboarding && !s.needsAPIKey {
162 modelInx := s.modelList.SelectedIndex()
163 items := s.modelList.Items()
164 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166 cmd := s.setPreferredModel(selectedItem)
167 s.isOnboarding = false
168 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169 } else {
170 // Provider not configured, show API key input
171 s.needsAPIKey = true
172 s.selectedModel = &selectedItem
173 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174 return s, nil
175 }
176 } else if s.needsAPIKey {
177 // Handle API key submission
178 apiKey := s.apiKeyInput.Value()
179 if apiKey != "" {
180 return s, s.saveAPIKeyAndContinue(apiKey)
181 }
182 } else if s.needsProjectInit {
183 return s, s.initializeProject()
184 }
185 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186 if s.needsProjectInit {
187 s.selectedNo = !s.selectedNo
188 return s, nil
189 }
190 case key.Matches(msg, s.keyMap.Yes):
191 if s.needsProjectInit {
192 return s, s.initializeProject()
193 }
194 case key.Matches(msg, s.keyMap.No):
195 if s.needsProjectInit {
196 s.needsProjectInit = false
197 return s, util.CmdHandler(OnboardingCompleteMsg{})
198 }
199 default:
200 if s.needsAPIKey {
201 u, cmd := s.apiKeyInput.Update(msg)
202 s.apiKeyInput = u.(*models.APIKeyInput)
203 return s, cmd
204 } else if s.isOnboarding {
205 u, cmd := s.modelList.Update(msg)
206 s.modelList = u
207 return s, cmd
208 }
209 }
210 case tea.PasteMsg:
211 if s.needsAPIKey {
212 u, cmd := s.apiKeyInput.Update(msg)
213 s.apiKeyInput = u.(*models.APIKeyInput)
214 return s, cmd
215 } else if s.isOnboarding {
216 var cmd tea.Cmd
217 s.modelList, cmd = s.modelList.Update(msg)
218 return s, cmd
219 }
220 }
221 return s, nil
222}
223
224func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
225 if s.selectedModel == nil {
226 return util.ReportError(fmt.Errorf("no model selected"))
227 }
228
229 cfg := config.Get()
230 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
231 if err != nil {
232 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
233 }
234
235 // Reset API key state and continue with model selection
236 s.needsAPIKey = false
237 cmd := s.setPreferredModel(*s.selectedModel)
238 s.isOnboarding = false
239 s.selectedModel = nil
240
241 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
242}
243
244func (s *splashCmp) initializeProject() tea.Cmd {
245 s.needsProjectInit = false
246
247 if err := config.MarkProjectInitialized(); err != nil {
248 return util.ReportError(err)
249 }
250 var cmds []tea.Cmd
251
252 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253 if !s.selectedNo {
254 cmds = append(cmds,
255 util.CmdHandler(chat.SessionClearedMsg{}),
256 util.CmdHandler(chat.SendMsg{
257 Text: prompt.Initialize(),
258 }),
259 )
260 }
261 return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265 cfg := config.Get()
266 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267 if model == nil {
268 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269 }
270
271 selectedModel := config.SelectedModel{
272 Model: selectedItem.Model.ID,
273 Provider: string(selectedItem.Provider.ID),
274 ReasoningEffort: model.DefaultReasoningEffort,
275 MaxTokens: model.DefaultMaxTokens,
276 }
277
278 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279 if err != nil {
280 return util.ReportError(err)
281 }
282
283 // Now lets automatically setup the small model
284 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285 if err != nil {
286 return util.ReportError(err)
287 }
288 if knownProvider == nil {
289 // for local provider we just use the same model
290 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291 if err != nil {
292 return util.ReportError(err)
293 }
294 } else {
295 smallModel := knownProvider.DefaultSmallModelID
296 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297 // should never happen
298 if model == nil {
299 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300 if err != nil {
301 return util.ReportError(err)
302 }
303 return nil
304 }
305 smallSelectedModel := config.SelectedModel{
306 Model: smallModel,
307 Provider: string(selectedItem.Provider.ID),
308 ReasoningEffort: model.DefaultReasoningEffort,
309 MaxTokens: model.DefaultMaxTokens,
310 }
311 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312 if err != nil {
313 return util.ReportError(err)
314 }
315 }
316 cfg.SetupAgents()
317 return nil
318}
319
320func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
321 providers, err := config.Providers()
322 if err != nil {
323 return nil, err
324 }
325 for _, p := range providers {
326 if p.ID == providerID {
327 return &p, nil
328 }
329 }
330 return nil, nil
331}
332
333func (s *splashCmp) isProviderConfigured(providerID string) bool {
334 cfg := config.Get()
335 if _, ok := cfg.Providers[providerID]; ok {
336 return true
337 }
338 return false
339}
340
341func (s *splashCmp) View() string {
342 t := styles.CurrentTheme()
343 var content string
344 if s.needsAPIKey {
345 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
346 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
347 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
348 lipgloss.JoinVertical(
349 lipgloss.Left,
350 apiKeyView,
351 ),
352 )
353 content = lipgloss.JoinVertical(
354 lipgloss.Left,
355 s.logoRendered,
356 apiKeySelector,
357 )
358 } else if s.isOnboarding {
359 modelListView := s.modelList.View()
360 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
361 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
362 lipgloss.JoinVertical(
363 lipgloss.Left,
364 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
365 "",
366 modelListView,
367 ),
368 )
369 content = lipgloss.JoinVertical(
370 lipgloss.Left,
371 s.logoRendered,
372 modelSelector,
373 )
374 } else if s.needsProjectInit {
375 titleStyle := t.S().Base.Foreground(t.FgBase)
376 bodyStyle := t.S().Base.Foreground(t.FgMuted)
377 shortcutStyle := t.S().Base.Foreground(t.Success)
378
379 initText := lipgloss.JoinVertical(
380 lipgloss.Left,
381 titleStyle.Render("Would you like to initialize this project?"),
382 "",
383 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
384 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
385 "",
386 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
387 "",
388 bodyStyle.Render("Would you like to initialize now?"),
389 )
390
391 yesButton := core.SelectableButton(core.ButtonOpts{
392 Text: "Yep!",
393 UnderlineIndex: 0,
394 Selected: !s.selectedNo,
395 })
396
397 noButton := core.SelectableButton(core.ButtonOpts{
398 Text: "Nope",
399 UnderlineIndex: 0,
400 Selected: s.selectedNo,
401 })
402
403 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
404 infoSection := s.infoSection()
405
406 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
407
408 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
409 lipgloss.JoinVertical(
410 lipgloss.Left,
411 initText,
412 "",
413 buttons,
414 ),
415 )
416
417 content = lipgloss.JoinVertical(
418 lipgloss.Left,
419 s.logoRendered,
420 infoSection,
421 initContent,
422 )
423 } else {
424 parts := []string{
425 s.logoRendered,
426 s.infoSection(),
427 }
428 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
429 }
430
431 return t.S().Base.
432 Width(s.width).
433 Height(s.height).
434 PaddingTop(SplashScreenPaddingY).
435 PaddingBottom(SplashScreenPaddingY).
436 Render(content)
437}
438
439func (s *splashCmp) Cursor() *tea.Cursor {
440 if s.needsAPIKey {
441 cursor := s.apiKeyInput.Cursor()
442 if cursor != nil {
443 return s.moveCursor(cursor)
444 }
445 } else if s.isOnboarding {
446 cursor := s.modelList.Cursor()
447 if cursor != nil {
448 return s.moveCursor(cursor)
449 }
450 } else {
451 return nil
452 }
453 return nil
454}
455
456func (s *splashCmp) infoSection() string {
457 t := styles.CurrentTheme()
458 return t.S().Base.PaddingLeft(2).Render(
459 lipgloss.JoinVertical(
460 lipgloss.Left,
461 s.cwd(),
462 "",
463 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
464 "",
465 ),
466 )
467}
468
469func (s *splashCmp) logoBlock() string {
470 t := styles.CurrentTheme()
471 return t.S().Base.Padding(0, 2).Width(s.width).Render(
472 logo.Render(version.Version, false, logo.Opts{
473 FieldColor: t.Primary,
474 TitleColorA: t.Secondary,
475 TitleColorB: t.Primary,
476 CharmColor: t.Secondary,
477 VersionColor: t.Primary,
478 Width: s.width - 4,
479 }),
480 )
481}
482
483func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
484 if cursor == nil {
485 return nil
486 }
487 // Calculate the correct Y offset based on current state
488 logoHeight := lipgloss.Height(s.logoRendered)
489 if s.needsAPIKey {
490 infoSectionHeight := lipgloss.Height(s.infoSection())
491 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
492 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
493 offset := baseOffset + remainingHeight
494 cursor.Y += offset
495 cursor.X = cursor.X + 1
496 } else if s.isOnboarding {
497 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
498 cursor.Y += offset
499 cursor.X = cursor.X + 1
500 }
501
502 return cursor
503}
504
505func (s *splashCmp) logoGap() int {
506 if s.height > 35 {
507 return LogoGap
508 }
509 return 0
510}
511
512// Bindings implements SplashPage.
513func (s *splashCmp) Bindings() []key.Binding {
514 if s.needsAPIKey {
515 return []key.Binding{
516 s.keyMap.Select,
517 s.keyMap.Back,
518 }
519 } else if s.isOnboarding {
520 return []key.Binding{
521 s.keyMap.Select,
522 s.keyMap.Next,
523 s.keyMap.Previous,
524 }
525 } else if s.needsProjectInit {
526 return []key.Binding{
527 s.keyMap.Select,
528 s.keyMap.Yes,
529 s.keyMap.No,
530 s.keyMap.Tab,
531 s.keyMap.LeftRight,
532 }
533 }
534 return []key.Binding{}
535}
536
537func (s *splashCmp) getMaxInfoWidth() int {
538 return min(s.width-2, 40) // 2 for left padding
539}
540
541func (s *splashCmp) cwd() string {
542 cwd := config.Get().WorkingDir()
543 t := styles.CurrentTheme()
544 homeDir, err := os.UserHomeDir()
545 if err == nil && cwd != homeDir {
546 cwd = strings.ReplaceAll(cwd, homeDir, "~")
547 }
548 maxWidth := s.getMaxInfoWidth()
549 return t.S().Muted.Width(maxWidth).Render(cwd)
550}
551
552func LSPList(maxWidth int) []string {
553 t := styles.CurrentTheme()
554 lspList := []string{}
555 lsp := config.Get().LSP.Sorted()
556 if len(lsp) == 0 {
557 return []string{t.S().Base.Foreground(t.Border).Render("None")}
558 }
559 for _, l := range lsp {
560 iconColor := t.Success
561 if l.LSP.Disabled {
562 iconColor = t.FgMuted
563 }
564 lspList = append(lspList,
565 core.Status(
566 core.StatusOpts{
567 IconColor: iconColor,
568 Title: l.Name,
569 Description: l.LSP.Command,
570 },
571 maxWidth,
572 ),
573 )
574 }
575 return lspList
576}
577
578func (s *splashCmp) lspBlock() string {
579 t := styles.CurrentTheme()
580 maxWidth := s.getMaxInfoWidth() / 2
581 section := t.S().Subtle.Render("LSPs")
582 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
583 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
584 lipgloss.JoinVertical(
585 lipgloss.Left,
586 lspList...,
587 ),
588 )
589}
590
591func MCPList(maxWidth int) []string {
592 t := styles.CurrentTheme()
593 mcpList := []string{}
594 mcps := config.Get().MCP.Sorted()
595 if len(mcps) == 0 {
596 return []string{t.S().Base.Foreground(t.Border).Render("None")}
597 }
598 for _, l := range mcps {
599 iconColor := t.Success
600 if l.MCP.Disabled {
601 iconColor = t.FgMuted
602 }
603 mcpList = append(mcpList,
604 core.Status(
605 core.StatusOpts{
606 IconColor: iconColor,
607 Title: l.Name,
608 Description: l.MCP.Command,
609 },
610 maxWidth,
611 ),
612 )
613 }
614 return mcpList
615}
616
617func (s *splashCmp) mcpBlock() string {
618 t := styles.CurrentTheme()
619 maxWidth := s.getMaxInfoWidth() / 2
620 section := t.S().Subtle.Render("MCPs")
621 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
622 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
623 lipgloss.JoinVertical(
624 lipgloss.Left,
625 mcpList...,
626 ),
627 )
628}
629
630func (s *splashCmp) IsShowingAPIKey() bool {
631 return s.needsAPIKey
632}