1package splash
2
3import (
4 "fmt"
5 "os"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/v2/key"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/fur/provider"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/tui/components/chat"
15 "github.com/charmbracelet/crush/internal/tui/components/completions"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
18 "github.com/charmbracelet/crush/internal/tui/components/core/list"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
20 "github.com/charmbracelet/crush/internal/tui/components/logo"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/charmbracelet/lipgloss/v2"
25)
26
// Splash is the splash-screen component. On top of the embedded model,
// sizing and help contracts it exposes the cursor to display and the
// switches for its onboarding and project-initialization flows.
type Splash interface {
	util.Model
	layout.Sizeable
	layout.Help
	// Cursor returns the cursor to display, or nil when there is none.
	Cursor() *tea.Cursor
	// SetOnboarding controls whether the splash shows the model selection UI.
	SetOnboarding(bool)
	// SetProjectInit controls whether the splash shows the project
	// initialization prompt.
	SetProjectInit(bool)

	// IsShowingAPIKey reports whether the API key input is being shown.
	IsShowingAPIKey() bool
}
40
const (
	// SplashScreenPaddingY is the vertical padding, in rows, applied to both
	// the top and the bottom of the splash screen.
	SplashScreenPaddingY = 1

	// LogoGap is the number of extra rows reserved as a gap when the splash
	// is taller than 35 rows; shorter screens use no gap.
	LogoGap = 6
)
46
// OnboardingCompleteMsg is sent when onboarding is complete: after a model
// has been chosen, or after the project-initialization prompt was answered.
type OnboardingCompleteMsg struct{}
49
// splashCmp is the concrete Splash implementation returned by New.
type splashCmp struct {
	// width and height are the dimensions last passed to SetSize.
	width, height int
	keyMap        KeyMap
	// logoRendered caches the rendered logo; it is rebuilt when the width
	// changes.
	logoRendered string

	// State
	isOnboarding     bool // model selection is in progress
	needsProjectInit bool // the project-initialization prompt is pending
	needsAPIKey      bool // the API key input is shown for selectedModel
	selectedNo       bool // the "Nope" button of the init prompt is highlighted

	// listHeight is the height handed to modelList, computed in SetSize.
	listHeight  int
	modelList   *models.ModelListComponent
	apiKeyInput *models.APIKeyInput
	// selectedModel is the model awaiting an API key; nil otherwise.
	selectedModel *models.ModelOption
}
66
67func New() Splash {
68 keyMap := DefaultKeyMap()
69 listKeyMap := list.DefaultKeyMap()
70 listKeyMap.Down.SetEnabled(false)
71 listKeyMap.Up.SetEnabled(false)
72 listKeyMap.HalfPageDown.SetEnabled(false)
73 listKeyMap.HalfPageUp.SetEnabled(false)
74 listKeyMap.Home.SetEnabled(false)
75 listKeyMap.End.SetEnabled(false)
76 listKeyMap.DownOneItem = keyMap.Next
77 listKeyMap.UpOneItem = keyMap.Previous
78
79 t := styles.CurrentTheme()
80 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
81 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
82 apiKeyInput := models.NewAPIKeyInput()
83
84 return &splashCmp{
85 width: 0,
86 height: 0,
87 keyMap: keyMap,
88 logoRendered: "",
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 selectedNo: false,
92 }
93}
94
95func (s *splashCmp) SetOnboarding(onboarding bool) {
96 s.isOnboarding = onboarding
97 if onboarding {
98 providers, err := config.Providers()
99 if err != nil {
100 return
101 }
102 filteredProviders := []provider.Provider{}
103 simpleProviders := []string{
104 "anthropic",
105 "openai",
106 "gemini",
107 "xai",
108 "openrouter",
109 }
110 for _, p := range providers {
111 if slices.Contains(simpleProviders, string(p.ID)) {
112 filteredProviders = append(filteredProviders, p)
113 }
114 }
115 s.modelList.SetProviders(filteredProviders)
116 }
117}
118
// SetProjectInit records whether the project-initialization prompt still
// needs to be shown.
func (s *splashCmp) SetProjectInit(needsInit bool) {
	s.needsProjectInit = needsInit
}
122
// GetSize returns the width and height last recorded by SetSize.
func (s *splashCmp) GetSize() (int, int) {
	return s.width, s.height
}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130 return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135 s.height = height
136 if width != s.width {
137 s.width = width
138 s.logoRendered = s.logoBlock()
139 }
140 // remove padding, logo height, gap, title space
141 s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142 listWidth := min(60, width)
143 return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148 switch msg := msg.(type) {
149 case tea.WindowSizeMsg:
150 return s, s.SetSize(msg.Width, msg.Height)
151 case tea.KeyPressMsg:
152 switch {
153 case key.Matches(msg, s.keyMap.Back):
154 if s.needsAPIKey {
155 // Go back to model selection
156 s.needsAPIKey = false
157 s.selectedModel = nil
158 return s, nil
159 }
160 case key.Matches(msg, s.keyMap.Select):
161 if s.isOnboarding && !s.needsAPIKey {
162 modelInx := s.modelList.SelectedIndex()
163 items := s.modelList.Items()
164 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166 cmd := s.setPreferredModel(selectedItem)
167 s.isOnboarding = false
168 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169 } else {
170 // Provider not configured, show API key input
171 s.needsAPIKey = true
172 s.selectedModel = &selectedItem
173 s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174 return s, nil
175 }
176 } else if s.needsAPIKey {
177 // Handle API key submission
178 apiKey := s.apiKeyInput.Value()
179 if apiKey != "" {
180 return s, s.saveAPIKeyAndContinue(apiKey)
181 }
182 } else if s.needsProjectInit {
183 return s, s.initializeProject()
184 }
185 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186 if s.needsProjectInit {
187 s.selectedNo = !s.selectedNo
188 return s, nil
189 }
190 case key.Matches(msg, s.keyMap.Yes):
191 if s.needsProjectInit {
192 return s, s.initializeProject()
193 }
194 case key.Matches(msg, s.keyMap.No):
195 if s.needsProjectInit {
196 s.needsProjectInit = false
197 return s, util.CmdHandler(OnboardingCompleteMsg{})
198 }
199 default:
200 if s.needsAPIKey {
201 u, cmd := s.apiKeyInput.Update(msg)
202 s.apiKeyInput = u.(*models.APIKeyInput)
203 return s, cmd
204 } else if s.isOnboarding {
205 u, cmd := s.modelList.Update(msg)
206 s.modelList = u
207 return s, cmd
208 }
209 }
210 case tea.PasteMsg:
211 if s.needsAPIKey {
212 u, cmd := s.apiKeyInput.Update(msg)
213 s.apiKeyInput = u.(*models.APIKeyInput)
214 return s, cmd
215 } else if s.isOnboarding {
216 var cmd tea.Cmd
217 s.modelList, cmd = s.modelList.Update(msg)
218 return s, cmd
219 }
220 }
221 return s, nil
222}
223
224func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
225 if s.selectedModel == nil {
226 return util.ReportError(fmt.Errorf("no model selected"))
227 }
228
229 cfg := config.Get()
230 err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
231 if err != nil {
232 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
233 }
234
235 // Reset API key state and continue with model selection
236 s.needsAPIKey = false
237 cmd := s.setPreferredModel(*s.selectedModel)
238 s.isOnboarding = false
239 s.selectedModel = nil
240
241 return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
242}
243
244func (s *splashCmp) initializeProject() tea.Cmd {
245 s.needsProjectInit = false
246
247 if err := config.MarkProjectInitialized(); err != nil {
248 return util.ReportError(err)
249 }
250 var cmds []tea.Cmd
251
252 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253 if !s.selectedNo {
254 cmds = append(cmds,
255 util.CmdHandler(chat.SessionClearedMsg{}),
256 util.CmdHandler(chat.SendMsg{
257 Text: prompt.Initialize(),
258 }),
259 )
260 }
261 return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265 cfg := config.Get()
266 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267 if model == nil {
268 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269 }
270
271 selectedModel := config.SelectedModel{
272 Model: selectedItem.Model.ID,
273 Provider: string(selectedItem.Provider.ID),
274 ReasoningEffort: model.DefaultReasoningEffort,
275 MaxTokens: model.DefaultMaxTokens,
276 }
277
278 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279 if err != nil {
280 return util.ReportError(err)
281 }
282
283 // Now lets automatically setup the small model
284 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285 if err != nil {
286 return util.ReportError(err)
287 }
288 if knownProvider == nil {
289 // for local provider we just use the same model
290 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291 if err != nil {
292 return util.ReportError(err)
293 }
294 } else {
295 smallModel := knownProvider.DefaultSmallModelID
296 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297 // should never happen
298 if model == nil {
299 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300 if err != nil {
301 return util.ReportError(err)
302 }
303 return nil
304 }
305 smallSelectedModel := config.SelectedModel{
306 Model: smallModel,
307 Provider: string(selectedItem.Provider.ID),
308 ReasoningEffort: model.DefaultReasoningEffort,
309 MaxTokens: model.DefaultMaxTokens,
310 }
311 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312 if err != nil {
313 return util.ReportError(err)
314 }
315 }
316 return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320 providers, err := config.Providers()
321 if err != nil {
322 return nil, err
323 }
324 for _, p := range providers {
325 if p.ID == providerID {
326 return &p, nil
327 }
328 }
329 return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333 cfg := config.Get()
334 if _, ok := cfg.Providers[providerID]; ok {
335 return true
336 }
337 return false
338}
339
340func (s *splashCmp) View() string {
341 t := styles.CurrentTheme()
342 var content string
343 if s.needsAPIKey {
344 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
345 apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
346 apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
347 lipgloss.JoinVertical(
348 lipgloss.Left,
349 apiKeyView,
350 ),
351 )
352 content = lipgloss.JoinVertical(
353 lipgloss.Left,
354 s.logoRendered,
355 apiKeySelector,
356 )
357 } else if s.isOnboarding {
358 modelListView := s.modelList.View()
359 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
361 lipgloss.JoinVertical(
362 lipgloss.Left,
363 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
364 "",
365 modelListView,
366 ),
367 )
368 content = lipgloss.JoinVertical(
369 lipgloss.Left,
370 s.logoRendered,
371 modelSelector,
372 )
373 } else if s.needsProjectInit {
374 titleStyle := t.S().Base.Foreground(t.FgBase)
375 bodyStyle := t.S().Base.Foreground(t.FgMuted)
376 shortcutStyle := t.S().Base.Foreground(t.Success)
377
378 initText := lipgloss.JoinVertical(
379 lipgloss.Left,
380 titleStyle.Render("Would you like to initialize this project?"),
381 "",
382 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
383 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
384 "",
385 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
386 "",
387 bodyStyle.Render("Would you like to initialize now?"),
388 )
389
390 yesButton := core.SelectableButton(core.ButtonOpts{
391 Text: "Yep!",
392 UnderlineIndex: 0,
393 Selected: !s.selectedNo,
394 })
395
396 noButton := core.SelectableButton(core.ButtonOpts{
397 Text: "Nope",
398 UnderlineIndex: 0,
399 Selected: s.selectedNo,
400 })
401
402 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
403 infoSection := s.infoSection()
404
405 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
406
407 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
408 lipgloss.JoinVertical(
409 lipgloss.Left,
410 initText,
411 "",
412 buttons,
413 ),
414 )
415
416 content = lipgloss.JoinVertical(
417 lipgloss.Left,
418 s.logoRendered,
419 infoSection,
420 initContent,
421 )
422 } else {
423 parts := []string{
424 s.logoRendered,
425 s.infoSection(),
426 }
427 content = lipgloss.JoinVertical(lipgloss.Left, parts...)
428 }
429
430 return t.S().Base.
431 Width(s.width).
432 Height(s.height).
433 PaddingTop(SplashScreenPaddingY).
434 PaddingBottom(SplashScreenPaddingY).
435 Render(content)
436}
437
438func (s *splashCmp) Cursor() *tea.Cursor {
439 if s.needsAPIKey {
440 cursor := s.apiKeyInput.Cursor()
441 if cursor != nil {
442 return s.moveCursor(cursor)
443 }
444 } else if s.isOnboarding {
445 cursor := s.modelList.Cursor()
446 if cursor != nil {
447 return s.moveCursor(cursor)
448 }
449 } else {
450 return nil
451 }
452 return nil
453}
454
455func (s *splashCmp) infoSection() string {
456 t := styles.CurrentTheme()
457 return t.S().Base.PaddingLeft(2).Render(
458 lipgloss.JoinVertical(
459 lipgloss.Left,
460 s.cwd(),
461 "",
462 lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
463 "",
464 ),
465 )
466}
467
468func (s *splashCmp) logoBlock() string {
469 t := styles.CurrentTheme()
470 return t.S().Base.Padding(0, 2).Width(s.width).Render(
471 logo.Render(version.Version, false, logo.Opts{
472 FieldColor: t.Primary,
473 TitleColorA: t.Secondary,
474 TitleColorB: t.Primary,
475 CharmColor: t.Secondary,
476 VersionColor: t.Primary,
477 Width: s.width - 4,
478 }),
479 )
480}
481
482func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
483 if cursor == nil {
484 return nil
485 }
486 // Calculate the correct Y offset based on current state
487 logoHeight := lipgloss.Height(s.logoRendered)
488 if s.needsAPIKey {
489 infoSectionHeight := lipgloss.Height(s.infoSection())
490 baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
491 remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
492 offset := baseOffset + remainingHeight
493 cursor.Y += offset
494 cursor.X = cursor.X + 1
495 } else if s.isOnboarding {
496 offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
497 cursor.Y += offset
498 cursor.X = cursor.X + 1
499 }
500
501 return cursor
502}
503
504func (s *splashCmp) logoGap() int {
505 if s.height > 35 {
506 return LogoGap
507 }
508 return 0
509}
510
511// Bindings implements SplashPage.
512func (s *splashCmp) Bindings() []key.Binding {
513 if s.needsAPIKey {
514 return []key.Binding{
515 s.keyMap.Select,
516 s.keyMap.Back,
517 }
518 } else if s.isOnboarding {
519 return []key.Binding{
520 s.keyMap.Select,
521 s.keyMap.Next,
522 s.keyMap.Previous,
523 }
524 } else if s.needsProjectInit {
525 return []key.Binding{
526 s.keyMap.Select,
527 s.keyMap.Yes,
528 s.keyMap.No,
529 s.keyMap.Tab,
530 s.keyMap.LeftRight,
531 }
532 }
533 return []key.Binding{}
534}
535
536func (s *splashCmp) getMaxInfoWidth() int {
537 return min(s.width-2, 40) // 2 for left padding
538}
539
540func (s *splashCmp) cwd() string {
541 cwd := config.Get().WorkingDir()
542 t := styles.CurrentTheme()
543 homeDir, err := os.UserHomeDir()
544 if err == nil && cwd != homeDir {
545 cwd = strings.ReplaceAll(cwd, homeDir, "~")
546 }
547 maxWidth := s.getMaxInfoWidth()
548 return t.S().Muted.Width(maxWidth).Render(cwd)
549}
550
551func LSPList(maxWidth int) []string {
552 t := styles.CurrentTheme()
553 lspList := []string{}
554 lsp := config.Get().LSP.Sorted()
555 if len(lsp) == 0 {
556 return []string{t.S().Base.Foreground(t.Border).Render("None")}
557 }
558 for _, l := range lsp {
559 iconColor := t.Success
560 if l.LSP.Disabled {
561 iconColor = t.FgMuted
562 }
563 lspList = append(lspList,
564 core.Status(
565 core.StatusOpts{
566 IconColor: iconColor,
567 Title: l.Name,
568 Description: l.LSP.Command,
569 },
570 maxWidth,
571 ),
572 )
573 }
574 return lspList
575}
576
577func (s *splashCmp) lspBlock() string {
578 t := styles.CurrentTheme()
579 maxWidth := s.getMaxInfoWidth() / 2
580 section := t.S().Subtle.Render("LSPs")
581 lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
582 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
583 lipgloss.JoinVertical(
584 lipgloss.Left,
585 lspList...,
586 ),
587 )
588}
589
590func MCPList(maxWidth int) []string {
591 t := styles.CurrentTheme()
592 mcpList := []string{}
593 mcps := config.Get().MCP.Sorted()
594 if len(mcps) == 0 {
595 return []string{t.S().Base.Foreground(t.Border).Render("None")}
596 }
597 for _, l := range mcps {
598 iconColor := t.Success
599 if l.MCP.Disabled {
600 iconColor = t.FgMuted
601 }
602 mcpList = append(mcpList,
603 core.Status(
604 core.StatusOpts{
605 IconColor: iconColor,
606 Title: l.Name,
607 Description: l.MCP.Command,
608 },
609 maxWidth,
610 ),
611 )
612 }
613 return mcpList
614}
615
616func (s *splashCmp) mcpBlock() string {
617 t := styles.CurrentTheme()
618 maxWidth := s.getMaxInfoWidth() / 2
619 section := t.S().Subtle.Render("MCPs")
620 mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
621 return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
622 lipgloss.JoinVertical(
623 lipgloss.Left,
624 mcpList...,
625 ),
626 )
627}
628
// IsShowingAPIKey reports whether the splash is currently asking for an API
// key for the selected model's provider.
func (s *splashCmp) IsShowingAPIKey() bool {
	return s.needsAPIKey
}