From f416931307d5b446dfe1e376d9c0ade4a28b6406 Mon Sep 17 00:00:00 2001 From: huaiyuWangh <34158348+huaiyuWangh@users.noreply.github.com> Date: Wed, 1 Apr 2026 03:24:33 +0800 Subject: [PATCH] fix: conditionally show image keybindings based on model support (#2522) --- internal/ui/model/ui.go | 53 +++++++++++------ internal/ui/model/ui_test.go | 107 +++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 16 deletions(-) create mode 100644 internal/ui/model/ui_test.go diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 72826f8494f60fed9e3ade465cafb4737fcf4da5..a33cd51760fe03b33a75fb02d56d35eb2934d0f6 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1722,11 +1722,17 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { switch { case key.Matches(msg, m.keyMap.Editor.AddImage): + if !m.currentModelSupportsImages() { + break + } if cmd := m.openFilesDialog(); cmd != nil { cmds = append(cmds, cmd) } case key.Matches(msg, m.keyMap.Editor.PasteImage): + if !m.currentModelSupportsImages() { + break + } cmds = append(cmds, m.pasteImageFromClipboard) case key.Matches(msg, m.keyMap.Editor.SendMessage): @@ -2259,15 +2265,15 @@ func (m *UI) FullHelp() [][]key.Binding { switch m.focus { case uiFocusEditor: - binds = append(binds, - []key.Binding{ - k.Editor.Newline, - k.Editor.AddImage, - k.Editor.PasteImage, - k.Editor.MentionFile, - k.Editor.OpenEditor, - }, - ) + editorBinds := []key.Binding{ + k.Editor.Newline, + k.Editor.MentionFile, + k.Editor.OpenEditor, + } + if m.currentModelSupportsImages() { + editorBinds = append(editorBinds, k.Editor.AddImage, k.Editor.PasteImage) + } + binds = append(binds, editorBinds) if hasAttachments { binds = append(binds, []key.Binding{ @@ -2309,14 +2315,16 @@ func (m *UI) FullHelp() [][]key.Binding { k.Models, k.Sessions, }, - []key.Binding{ - k.Editor.Newline, - k.Editor.AddImage, - k.Editor.PasteImage, - k.Editor.MentionFile, - k.Editor.OpenEditor, - }, ) + editorBinds := []key.Binding{ + k.Editor.Newline, + k.Editor.MentionFile, + k.Editor.OpenEditor, + } + if m.currentModelSupportsImages() { + editorBinds = append(editorBinds, k.Editor.AddImage, k.Editor.PasteImage) + } + binds = append(binds, editorBinds) if hasAttachments { binds = append(binds, []key.Binding{ @@ -2339,6 +2347,19 @@ func (m *UI) FullHelp() [][]key.Binding { return binds } +func (m *UI) currentModelSupportsImages() bool { + cfg := m.com.Config() + if cfg == nil { + return false + } + agentCfg, ok := cfg.Agents[config.AgentCoder] + if !ok { + return false + } + model := cfg.GetModelByType(agentCfg.Model) + return model != nil && model.SupportsImages +} + // toggleCompactMode toggles compact mode between uiChat and uiChatCompact states. func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode diff --git a/internal/ui/model/ui_test.go b/internal/ui/model/ui_test.go new file mode 100644 index 0000000000000000000000000000000000000000..84b216e5470619e08af404967a347a831b53bcc2 --- /dev/null +++ b/internal/ui/model/ui_test.go @@ -0,0 +1,107 @@ +package model + +import ( + "reflect" + "testing" + "unsafe" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/stretchr/testify/require" +) + +func TestCurrentModelSupportsImages(t *testing.T) { + t.Parallel() + + t.Run("returns false when config is nil", func(t *testing.T) { + t.Parallel() + + ui := newTestUIWithConfig(t, nil) + require.False(t, ui.currentModelSupportsImages()) + }) + + t.Run("returns false when coder agent is missing", func(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{}, + } + ui := newTestUIWithConfig(t, cfg) + require.False(t, ui.currentModelSupportsImages()) + }) + + t.Run("returns false when model is not found", func(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + ui := newTestUIWithConfig(t, cfg) + require.False(t, ui.currentModelSupportsImages()) + }) + + t.Run("returns true when current model supports images", func(t *testing.T) { + t.Parallel() + + providers := csync.NewMap[string, config.ProviderConfig]() + providers.Set("test-provider", config.ProviderConfig{ + ID: "test-provider", + Models: []catwalk.Model{ + {ID: "test-model", SupportsImages: true}, + }, + }) + + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Provider: "test-provider", + Model: "test-model", + }, + }, + Providers: providers, + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + + ui := newTestUIWithConfig(t, cfg) + require.True(t, ui.currentModelSupportsImages()) + }) +} + +func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { + t.Helper() + + store := &config.ConfigStore{} + setUnexportedField(t, store, "config", cfg) + + appInstance := &app.App{} + setUnexportedField(t, appInstance, "config", store) + + return &UI{ + com: &common.Common{ + App: appInstance, + }, + } +} + +func setUnexportedField(t *testing.T, target any, name string, value any) { + t.Helper() + + v := reflect.ValueOf(target) + require.Equal(t, reflect.Pointer, v.Kind()) + require.False(t, v.IsNil()) + + field := v.Elem().FieldByName(name) + require.Truef(t, field.IsValid(), "field %q not found", name) + + fieldValue := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() + fieldValue.Set(reflect.ValueOf(value)) +}