fix: conditionally show image keybindings based on model support (#2522)

huaiyuWangh created

Change summary

internal/ui/model/ui.go      |  53 +++++++++++++-----
internal/ui/model/ui_test.go | 107 ++++++++++++++++++++++++++++++++++++++
2 files changed, 144 insertions(+), 16 deletions(-)

Detailed changes

internal/ui/model/ui.go 🔗

@@ -1722,11 +1722,17 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 
 			switch {
 			case key.Matches(msg, m.keyMap.Editor.AddImage):
+				if !m.currentModelSupportsImages() {
+					break
+				}
 				if cmd := m.openFilesDialog(); cmd != nil {
 					cmds = append(cmds, cmd)
 				}
 
 			case key.Matches(msg, m.keyMap.Editor.PasteImage):
+				if !m.currentModelSupportsImages() {
+					break
+				}
 				cmds = append(cmds, m.pasteImageFromClipboard)
 
 			case key.Matches(msg, m.keyMap.Editor.SendMessage):
@@ -2259,15 +2265,15 @@ func (m *UI) FullHelp() [][]key.Binding {
 
 		switch m.focus {
 		case uiFocusEditor:
-			binds = append(binds,
-				[]key.Binding{
-					k.Editor.Newline,
-					k.Editor.AddImage,
-					k.Editor.PasteImage,
-					k.Editor.MentionFile,
-					k.Editor.OpenEditor,
-				},
-			)
+			editorBinds := []key.Binding{
+				k.Editor.Newline,
+				k.Editor.MentionFile,
+				k.Editor.OpenEditor,
+			}
+			if m.currentModelSupportsImages() {
+				editorBinds = append(editorBinds, k.Editor.AddImage, k.Editor.PasteImage)
+			}
+			binds = append(binds, editorBinds)
 			if hasAttachments {
 				binds = append(binds,
 					[]key.Binding{
@@ -2309,14 +2315,16 @@ func (m *UI) FullHelp() [][]key.Binding {
 					k.Models,
 					k.Sessions,
 				},
-				[]key.Binding{
-					k.Editor.Newline,
-					k.Editor.AddImage,
-					k.Editor.PasteImage,
-					k.Editor.MentionFile,
-					k.Editor.OpenEditor,
-				},
 			)
+			editorBinds := []key.Binding{
+				k.Editor.Newline,
+				k.Editor.MentionFile,
+				k.Editor.OpenEditor,
+			}
+			if m.currentModelSupportsImages() {
+				editorBinds = append(editorBinds, k.Editor.AddImage, k.Editor.PasteImage)
+			}
+			binds = append(binds, editorBinds)
 			if hasAttachments {
 				binds = append(binds,
 					[]key.Binding{
@@ -2339,6 +2347,19 @@ func (m *UI) FullHelp() [][]key.Binding {
 	return binds
 }
 
+func (m *UI) currentModelSupportsImages() bool {
+	cfg := m.com.Config()
+	if cfg == nil {
+		return false
+	}
+	agentCfg, ok := cfg.Agents[config.AgentCoder]
+	if !ok {
+		return false
+	}
+	model := cfg.GetModelByType(agentCfg.Model)
+	return model != nil && model.SupportsImages
+}
+
 // toggleCompactMode toggles compact mode between uiChat and uiChatCompact states.
 func (m *UI) toggleCompactMode() tea.Cmd {
 	m.forceCompactMode = !m.forceCompactMode

internal/ui/model/ui_test.go 🔗

@@ -0,0 +1,107 @@
+package model
+
+import (
+	"reflect"
+	"testing"
+	"unsafe"
+
+	"charm.land/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/app"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/ui/common"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCurrentModelSupportsImages(t *testing.T) {
+	t.Parallel()
+
+	t.Run("returns false when config is nil", func(t *testing.T) {
+		t.Parallel()
+
+		ui := newTestUIWithConfig(t, nil)
+		require.False(t, ui.currentModelSupportsImages())
+	})
+
+	t.Run("returns false when coder agent is missing", func(t *testing.T) {
+		t.Parallel()
+
+		cfg := &config.Config{
+			Providers: csync.NewMap[string, config.ProviderConfig](),
+			Agents:    map[string]config.Agent{},
+		}
+		ui := newTestUIWithConfig(t, cfg)
+		require.False(t, ui.currentModelSupportsImages())
+	})
+
+	t.Run("returns false when model is not found", func(t *testing.T) {
+		t.Parallel()
+
+		cfg := &config.Config{
+			Providers: csync.NewMap[string, config.ProviderConfig](),
+			Agents: map[string]config.Agent{
+				config.AgentCoder: {Model: config.SelectedModelTypeLarge},
+			},
+		}
+		ui := newTestUIWithConfig(t, cfg)
+		require.False(t, ui.currentModelSupportsImages())
+	})
+
+	t.Run("returns true when current model supports images", func(t *testing.T) {
+		t.Parallel()
+
+		providers := csync.NewMap[string, config.ProviderConfig]()
+		providers.Set("test-provider", config.ProviderConfig{
+			ID: "test-provider",
+			Models: []catwalk.Model{
+				{ID: "test-model", SupportsImages: true},
+			},
+		})
+
+		cfg := &config.Config{
+			Models: map[config.SelectedModelType]config.SelectedModel{
+				config.SelectedModelTypeLarge: {
+					Provider: "test-provider",
+					Model:    "test-model",
+				},
+			},
+			Providers: providers,
+			Agents: map[string]config.Agent{
+				config.AgentCoder: {Model: config.SelectedModelTypeLarge},
+			},
+		}
+
+		ui := newTestUIWithConfig(t, cfg)
+		require.True(t, ui.currentModelSupportsImages())
+	})
+}
+
+func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI {
+	t.Helper()
+
+	store := &config.ConfigStore{}
+	setUnexportedField(t, store, "config", cfg)
+
+	appInstance := &app.App{}
+	setUnexportedField(t, appInstance, "config", store)
+
+	return &UI{
+		com: &common.Common{
+			App: appInstance,
+		},
+	}
+}
+
+func setUnexportedField(t *testing.T, target any, name string, value any) {
+	t.Helper()
+
+	v := reflect.ValueOf(target)
+	require.Equal(t, reflect.Pointer, v.Kind())
+	require.False(t, v.IsNil())
+
+	field := v.Elem().FieldByName(name)
+	require.Truef(t, field.IsValid(), "field %q not found", name)
+
+	fieldValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
+	fieldValue.Set(reflect.ValueOf(value))
+}