1package splash
2
3import (
4 "fmt"
5 "slices"
6
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/charmbracelet/crush/internal/tui/components/chat"
12 "github.com/charmbracelet/crush/internal/tui/components/completions"
13 "github.com/charmbracelet/crush/internal/tui/components/core"
14 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
15 "github.com/charmbracelet/crush/internal/tui/components/core/list"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
17 "github.com/charmbracelet/crush/internal/tui/components/logo"
18 "github.com/charmbracelet/crush/internal/tui/styles"
19 "github.com/charmbracelet/crush/internal/tui/util"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/charmbracelet/lipgloss/v2"
22)
23
24type Splash interface {
25 util.Model
26 layout.Sizeable
27 layout.Help
28 // SetOnboarding controls whether the splash shows model selection UI
29 SetOnboarding(bool)
30 // SetProjectInit controls whether the splash shows project initialization prompt
31 SetProjectInit(bool)
32}
33
// Padding applied around the splash content: X on the left and right, Y on
// the top and bottom.
const (
	// SplashScreenPaddingX is the horizontal padding of the splash screen.
	SplashScreenPaddingX = 2
	// SplashScreenPaddingY is the vertical padding of the splash screen.
	SplashScreenPaddingY = 1
)
38
// OnboardingCompleteMsg is emitted by the splash when it finishes one of its
// steps: a model was chosen during onboarding, or the project-initialization
// prompt was answered.
type OnboardingCompleteMsg struct{}
41
42type splashCmp struct {
43 width, height int
44 keyMap KeyMap
45 logoRendered string
46
47 // State
48 isOnboarding bool
49 needsProjectInit bool
50 selectedNo bool
51
52 modelList *models.ModelListComponent
53 cursorRow, cursorCol int
54}
55
56func New() Splash {
57 keyMap := DefaultKeyMap()
58 listKeyMap := list.DefaultKeyMap()
59 listKeyMap.Down.SetEnabled(false)
60 listKeyMap.Up.SetEnabled(false)
61 listKeyMap.HalfPageDown.SetEnabled(false)
62 listKeyMap.HalfPageUp.SetEnabled(false)
63 listKeyMap.Home.SetEnabled(false)
64 listKeyMap.End.SetEnabled(false)
65 listKeyMap.DownOneItem = keyMap.Next
66 listKeyMap.UpOneItem = keyMap.Previous
67
68 t := styles.CurrentTheme()
69 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
70 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
71 return &splashCmp{
72 width: 0,
73 height: 0,
74 keyMap: keyMap,
75 logoRendered: "",
76 modelList: modelList,
77 selectedNo: false,
78 }
79}
80
81func (s *splashCmp) SetOnboarding(onboarding bool) {
82 s.isOnboarding = onboarding
83 if onboarding {
84 providers, err := config.Providers()
85 if err != nil {
86 return
87 }
88 filteredProviders := []provider.Provider{}
89 simpleProviders := []string{
90 "anthropic",
91 "openai",
92 "gemini",
93 "xai",
94 "openrouter",
95 }
96 for _, p := range providers {
97 if slices.Contains(simpleProviders, string(p.ID)) {
98 filteredProviders = append(filteredProviders, p)
99 }
100 }
101 s.modelList.SetProviders(filteredProviders)
102 }
103}
104
105func (s *splashCmp) SetProjectInit(needsInit bool) {
106 s.needsProjectInit = needsInit
107}
108
109// GetSize implements SplashPage.
110func (s *splashCmp) GetSize() (int, int) {
111 return s.width, s.height
112}
113
114// Init implements SplashPage.
115func (s *splashCmp) Init() tea.Cmd {
116 return s.modelList.Init()
117}
118
119// SetSize implements SplashPage.
120func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
121 s.width = width
122 s.height = height
123 s.logoRendered = s.logoBlock()
124 listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
125 listWidth := min(60, width-(SplashScreenPaddingX*2))
126
127 // Calculate the cursor position based on the height and logo size
128 s.cursorRow = height - listHeigh
129 return s.modelList.SetSize(listWidth, listHeigh)
130}
131
132// Update implements SplashPage.
133func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
134 switch msg := msg.(type) {
135 case tea.WindowSizeMsg:
136 return s, s.SetSize(msg.Width, msg.Height)
137 case tea.KeyPressMsg:
138 switch {
139 case key.Matches(msg, s.keyMap.Select):
140 if s.isOnboarding {
141 modelInx := s.modelList.SelectedIndex()
142 items := s.modelList.Items()
143 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
144 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
145 cmd := s.setPreferredModel(selectedItem)
146 s.isOnboarding = false
147 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
148 }
149 } else if s.needsProjectInit {
150 return s, s.initializeProject()
151 }
152 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
153 if s.needsProjectInit {
154 s.selectedNo = !s.selectedNo
155 return s, nil
156 }
157 case key.Matches(msg, s.keyMap.Yes):
158 if s.needsProjectInit {
159 return s, s.initializeProject()
160 }
161 case key.Matches(msg, s.keyMap.No):
162 if s.needsProjectInit {
163 s.needsProjectInit = false
164 return s, util.CmdHandler(OnboardingCompleteMsg{})
165 }
166 default:
167 if s.isOnboarding {
168 u, cmd := s.modelList.Update(msg)
169 s.modelList = u
170 return s, cmd
171 }
172 }
173 }
174 return s, nil
175}
176
177func (s *splashCmp) initializeProject() tea.Cmd {
178 s.needsProjectInit = false
179 prompt := `Please analyze this codebase and create a CRUSH.md file containing:
1801. Build/lint/test commands - especially for running a single test
1812. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
182
183The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
184If there's already a CRUSH.md, improve it.
185If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
186Add the .crush directory to the .gitignore file if it's not already there.`
187
188 if err := config.MarkProjectInitialized(); err != nil {
189 return util.ReportError(err)
190 }
191 var cmds []tea.Cmd
192
193 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
194 if !s.selectedNo {
195 cmds = append(cmds,
196 util.CmdHandler(chat.SessionClearedMsg{}),
197 util.CmdHandler(chat.SendMsg{
198 Text: prompt,
199 }),
200 )
201 }
202 return tea.Sequence(cmds...)
203}
204
205func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
206 cfg := config.Get()
207 model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
208 if model == nil {
209 return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
210 }
211
212 selectedModel := config.SelectedModel{
213 Model: selectedItem.Model.ID,
214 Provider: string(selectedItem.Provider.ID),
215 ReasoningEffort: model.DefaultReasoningEffort,
216 MaxTokens: model.DefaultMaxTokens,
217 }
218
219 err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
220 if err != nil {
221 return util.ReportError(err)
222 }
223
224 // Now lets automatically setup the small model
225 knownProvider, err := s.getProvider(selectedItem.Provider.ID)
226 if err != nil {
227 return util.ReportError(err)
228 }
229 if knownProvider == nil {
230 // for local provider we just use the same model
231 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
232 if err != nil {
233 return util.ReportError(err)
234 }
235 } else {
236 smallModel := knownProvider.DefaultSmallModelID
237 model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
238 // should never happen
239 if model == nil {
240 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
241 if err != nil {
242 return util.ReportError(err)
243 }
244 return nil
245 }
246 smallSelectedModel := config.SelectedModel{
247 Model: smallModel,
248 Provider: string(selectedItem.Provider.ID),
249 ReasoningEffort: model.DefaultReasoningEffort,
250 MaxTokens: model.DefaultMaxTokens,
251 }
252 err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
253 if err != nil {
254 return util.ReportError(err)
255 }
256 }
257 return nil
258}
259
260func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
261 providers, err := config.Providers()
262 if err != nil {
263 return nil, err
264 }
265 for _, p := range providers {
266 if p.ID == providerID {
267 return &p, nil
268 }
269 }
270 return nil, nil
271}
272
273func (s *splashCmp) isProviderConfigured(providerID string) bool {
274 cfg := config.Get()
275 if _, ok := cfg.Providers[providerID]; ok {
276 return true
277 }
278 return false
279}
280
281// View implements SplashPage.
282func (s *splashCmp) View() tea.View {
283 t := styles.CurrentTheme()
284 var cursor *tea.Cursor
285
286 var content string
287 if s.isOnboarding {
288 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
289 modelListView := s.modelList.View()
290 cursor = s.moveCursor(modelListView.Cursor())
291 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
292 lipgloss.JoinVertical(
293 lipgloss.Left,
294 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
295 "",
296 modelListView.String(),
297 ),
298 )
299 content = lipgloss.JoinVertical(
300 lipgloss.Left,
301 s.logoRendered,
302 modelSelector,
303 )
304 } else if s.needsProjectInit {
305 t := styles.CurrentTheme()
306
307 titleStyle := t.S().Base.Foreground(t.FgBase)
308 bodyStyle := t.S().Base.Foreground(t.FgMuted)
309 shortcutStyle := t.S().Base.Foreground(t.Success)
310
311 initText := lipgloss.JoinVertical(
312 lipgloss.Left,
313 titleStyle.Render("Would you like to initialize this project?"),
314 "",
315 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
316 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
317 "",
318 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
319 "",
320 bodyStyle.Render("Would you like to initialize now?"),
321 )
322
323 yesButton := core.SelectableButton(core.ButtonOpts{
324 Text: "Yep!",
325 UnderlineIndex: 0,
326 Selected: !s.selectedNo,
327 })
328
329 noButton := core.SelectableButton(core.ButtonOpts{
330 Text: "Nope",
331 UnderlineIndex: 0,
332 Selected: s.selectedNo,
333 })
334
335 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
336
337 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
338
339 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
340 lipgloss.JoinVertical(
341 lipgloss.Left,
342 initText,
343 "",
344 buttons,
345 ),
346 )
347
348 content = lipgloss.JoinVertical(
349 lipgloss.Left,
350 s.logoRendered,
351 initContent,
352 )
353 } else {
354 content = s.logoRendered
355 }
356
357 view := tea.NewView(
358 t.S().Base.
359 Width(s.width).
360 Height(s.height).
361 PaddingTop(SplashScreenPaddingY).
362 PaddingLeft(SplashScreenPaddingX).
363 PaddingRight(SplashScreenPaddingX).
364 PaddingBottom(SplashScreenPaddingY).
365 Render(content),
366 )
367
368 view.SetCursor(cursor)
369 return view
370}
371
372func (s *splashCmp) logoBlock() string {
373 t := styles.CurrentTheme()
374 const padding = 2
375 return logo.Render(version.Version, false, logo.Opts{
376 FieldColor: t.Primary,
377 TitleColorA: t.Secondary,
378 TitleColorB: t.Primary,
379 CharmColor: t.Secondary,
380 VersionColor: t.Primary,
381 Width: s.width - (SplashScreenPaddingX * 2),
382 })
383}
384
385func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
386 if cursor == nil {
387 return nil
388 }
389 offset := m.cursorRow
390 cursor.Y += offset
391 cursor.X = cursor.X + 3 // 3 for padding
392 return cursor
393}
394
395// Bindings implements SplashPage.
396func (s *splashCmp) Bindings() []key.Binding {
397 if s.isOnboarding {
398 return []key.Binding{
399 s.keyMap.Select,
400 s.keyMap.Next,
401 s.keyMap.Previous,
402 }
403 } else if s.needsProjectInit {
404 return []key.Binding{
405 s.keyMap.Select,
406 s.keyMap.Yes,
407 s.keyMap.No,
408 s.keyMap.Tab,
409 s.keyMap.LeftRight,
410 }
411 }
412 return []key.Binding{}
413}