fix(ui): dialog: show provider name for recent models

Ayman Bagabas created

Change summary

internal/ui/dialog/models.go      |  9 ++++-
internal/ui/dialog/models_item.go | 37 ++++++++++++++-------
internal/ui/dialog/models_list.go | 54 +++++++++++++++-----------------
3 files changed, 57 insertions(+), 43 deletions(-)

Detailed changes

internal/ui/dialog/models.go 🔗

@@ -314,7 +314,7 @@ func (m *Models) setProviderItems() error {
 
 			group := NewModelGroup(t, name, true)
 			for _, model := range p.Models {
-				item := NewModelItem(t, provider, model)
+				item := NewModelItem(t, provider, model, false)
 				group.AppendItems(item)
 				itemsMap[item.ID()] = item
 				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -368,7 +368,7 @@ func (m *Models) setProviderItems() error {
 
 		group := NewModelGroup(t, name, providerConfigured)
 		for _, model := range displayProvider.Models {
-			item := NewModelItem(t, provider, model)
+			item := NewModelItem(t, provider, model, false)
 			group.AppendItems(item)
 			itemsMap[item.ID()] = item
 			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
@@ -390,6 +390,10 @@ func (m *Models) setProviderItems() error {
 				continue
 			}
 
+			// Show provider for recent items
+			item = NewModelItem(t, item.prov, item.model, true)
+			item.showProvider = true
+
 			validRecentItems = append(validRecentItems, recent)
 			recentGroup.AppendItems(item)
 			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
@@ -412,6 +416,7 @@ func (m *Models) setProviderItems() error {
 	// Set model groups in the list.
 	m.list.SetGroups(groups...)
 	m.list.SetSelectedItem(selectedItemID)
+
 	// Update placeholder based on model type
 	if m.modelType == ModelTypeLarge {
 		m.input.Placeholder = largeModelInputPlaceholder

internal/ui/dialog/models_item.go 🔗

@@ -20,9 +20,10 @@ type ModelGroup struct {
 // NewModelGroup creates a new ModelGroup.
 func NewModelGroup(t *styles.Styles, title string, configured bool, items ...*ModelItem) ModelGroup {
 	return ModelGroup{
-		Title: title,
-		Items: items,
-		t:     t,
+		Title:      title,
+		Items:      items,
+		configured: configured,
+		t:          t,
 	}
 }
 
@@ -51,21 +52,23 @@ type ModelItem struct {
 	prov  catwalk.Provider
 	model catwalk.Model
 
-	cache   map[int]string
-	t       *styles.Styles
-	m       fuzzy.Match
-	focused bool
+	cache        map[int]string
+	t            *styles.Styles
+	m            fuzzy.Match
+	focused      bool
+	showProvider bool
 }
 
 var _ ListItem = &ModelItem{}
 
 // NewModelItem creates a new ModelItem.
-func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model) *ModelItem {
+func NewModelItem(t *styles.Styles, prov catwalk.Provider, model catwalk.Model, showProvider bool) *ModelItem {
 	return &ModelItem{
-		prov:  prov,
-		model: model,
-		t:     t,
-		cache: make(map[int]string),
+		prov:         prov,
+		model:        model,
+		t:            t,
+		cache:        make(map[int]string),
+		showProvider: showProvider,
 	}
 }
 
@@ -81,15 +84,23 @@ func (m *ModelItem) ID() string {
 
 // Render implements ListItem.
 func (m *ModelItem) Render(width int) string {
-	return renderItem(m.t, m.model.Name, "", m.focused, width, m.cache, &m.m)
+	var providerInfo string
+	if m.showProvider {
+		providerInfo = string(m.prov.Name)
+	}
+	return renderItem(m.t, m.model.Name, providerInfo, m.focused, width, m.cache, &m.m)
 }
 
 // SetFocused implements ListItem.
 func (m *ModelItem) SetFocused(focused bool) {
+	if m.focused != focused {
+		m.cache = nil
+	}
 	m.focused = focused
 }
 
 // SetMatch implements ListItem.
 func (m *ModelItem) SetMatch(fm fuzzy.Match) {
+	m.cache = nil
 	m.m = fm
 }

internal/ui/dialog/models_list.go 🔗

@@ -12,7 +12,6 @@ import (
 type ModelsList struct {
 	*list.List
 	groups []ModelGroup
-	items  []list.Item
 	query  string
 	t      *styles.Styles
 }
@@ -30,6 +29,16 @@ func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
 // SetGroups sets the model groups and updates the list items.
 func (f *ModelsList) SetGroups(groups ...ModelGroup) {
 	f.groups = groups
+	items := []list.Item{}
+	for _, g := range f.groups {
+		items = append(items, &g)
+		for _, item := range g.Items {
+			items = append(items, item)
+		}
+		// Add a space separator after each provider section
+		items = append(items, list.NewSpacerItem(1))
+	}
+	f.List.SetItems(items...)
 }
 
 // SetFilter sets the filter query and updates the list items.
@@ -39,6 +48,11 @@ func (f *ModelsList) SetFilter(q string) {
 
 // SetSelectedItem sets the selected item in the list by item ID.
 func (f *ModelsList) SetSelectedItem(itemID string) {
+	if itemID == "" {
+		f.SetSelected(0)
+		return
+	}
+
 	count := 0
 	for _, g := range f.groups {
 		for _, item := range g.Items {
@@ -51,26 +65,6 @@ func (f *ModelsList) SetSelectedItem(itemID string) {
 	}
 }
 
-// SelectNext selects the next selectable item in the list.
-func (f *ModelsList) SelectNext() bool {
-	for f.List.SelectNext() {
-		if _, ok := f.List.SelectedItem().(*ModelItem); ok {
-			return true
-		}
-	}
-	return false
-}
-
-// SelectPrev selects the previous selectable item in the list.
-func (f *ModelsList) SelectPrev() bool {
-	for f.List.SelectPrev() {
-		if _, ok := f.List.SelectedItem().(*ModelItem); ok {
-			return true
-		}
-	}
-	return false
-}
-
 // VisibleItems returns the visible items after filtering.
 func (f *ModelsList) VisibleItems() []list.Item {
 	if len(f.query) == 0 {
@@ -110,10 +104,11 @@ func (f *ModelsList) VisibleItems() []list.Item {
 	visitedGroups := map[int]bool{}
 
 	// Reconstruct groups with matched items
-	for _, match := range matches {
-		item := filterableItems[match.Index]
-		// Find which group this item belongs to
-		for gi, g := range f.groups {
+	// Find which group this item belongs to
+	for gi, g := range f.groups {
+		addedCount := 0
+		for _, match := range matches {
+			item := filterableItems[match.Index]
 			if slices.Contains(groupItems[gi], item.(*ModelItem)) {
 				if !visitedGroups[gi] {
 					// Add section header
@@ -125,11 +120,14 @@ func (f *ModelsList) VisibleItems() []list.Item {
 					ms.SetMatch(match)
 					item = ms.(list.FilterableItem)
 				}
-				// Add a space separator after each provider section
-				items = append(items, item, list.NewSpacerItem(1))
-				break
+				items = append(items, item)
+				addedCount++
 			}
 		}
+		if addedCount > 0 {
+			// Add a space separator after each provider section
+			items = append(items, list.NewSpacerItem(1))
+		}
 	}
 
 	return items