1package model
2
3import (
4 "reflect"
5 "testing"
6 "unsafe"
7
8 "charm.land/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/app"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/charmbracelet/crush/internal/ui/common"
13 "github.com/stretchr/testify/require"
14)
15
// TestCurrentModelSupportsImages checks UI.currentModelSupportsImages against
// a nil config, a config without a coder agent, a coder agent whose selected
// model cannot be resolved, and a resolved model both with and without image
// support.
func TestCurrentModelSupportsImages(t *testing.T) {
	t.Parallel()

	// newResolvableConfig returns a config whose coder agent selects the large
	// model, which resolves to "test-model" served by "test-provider".
	// supportsImages controls that model's SupportsImages flag.
	newResolvableConfig := func(supportsImages bool) *config.Config {
		providers := csync.NewMap[string, config.ProviderConfig]()
		providers.Set("test-provider", config.ProviderConfig{
			ID: "test-provider",
			Models: []catwalk.Model{
				{ID: "test-model", SupportsImages: supportsImages},
			},
		})

		return &config.Config{
			Models: map[config.SelectedModelType]config.SelectedModel{
				config.SelectedModelTypeLarge: {
					Provider: "test-provider",
					Model:    "test-model",
				},
			},
			Providers: providers,
			Agents: map[string]config.Agent{
				config.AgentCoder: {Model: config.SelectedModelTypeLarge},
			},
		}
	}

	t.Run("returns false when config is nil", func(t *testing.T) {
		t.Parallel()

		ui := newTestUIWithConfig(t, nil)
		require.False(t, ui.currentModelSupportsImages())
	})

	t.Run("returns false when coder agent is missing", func(t *testing.T) {
		t.Parallel()

		cfg := &config.Config{
			Providers: csync.NewMap[string, config.ProviderConfig](),
			Agents:    map[string]config.Agent{},
		}
		ui := newTestUIWithConfig(t, cfg)
		require.False(t, ui.currentModelSupportsImages())
	})

	t.Run("returns false when model is not found", func(t *testing.T) {
		t.Parallel()

		// The coder agent points at the large model type, but no selected
		// model is configured for that type.
		cfg := &config.Config{
			Providers: csync.NewMap[string, config.ProviderConfig](),
			Agents: map[string]config.Agent{
				config.AgentCoder: {Model: config.SelectedModelTypeLarge},
			},
		}
		ui := newTestUIWithConfig(t, cfg)
		require.False(t, ui.currentModelSupportsImages())
	})

	t.Run("returns false when current model does not support images", func(t *testing.T) {
		t.Parallel()

		// Everything resolves; only the capability flag is off. This guards
		// against an implementation that returns true for any found model.
		ui := newTestUIWithConfig(t, newResolvableConfig(false))
		require.False(t, ui.currentModelSupportsImages())
	})

	t.Run("returns true when current model supports images", func(t *testing.T) {
		t.Parallel()

		ui := newTestUIWithConfig(t, newResolvableConfig(true))
		require.True(t, ui.currentModelSupportsImages())
	})
}
78
// newTestUIWithConfig returns a UI whose common.App is backed by a
// config.ConfigStore holding cfg (which may be nil). Both the store's and the
// app's "config" fields are unexported, so they are injected via reflection.
func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI {
	t.Helper()

	cfgStore := new(config.ConfigStore)
	setUnexportedField(t, cfgStore, "config", cfg)

	application := new(app.App)
	setUnexportedField(t, application, "config", cfgStore)

	shared := &common.Common{App: application}
	return &UI{com: shared}
}
94
// setUnexportedField assigns value to the field called name on the struct
// that target points to, even when that field is unexported. It fails the
// test if target is not a non-nil pointer to a struct, if the field does not
// exist, or if value's type is not assignable to the field's type. An untyped
// nil value resets the field to its zero value.
func setUnexportedField(t *testing.T, target any, name string, value any) {
	t.Helper()

	ptr := reflect.ValueOf(target)
	require.Equal(t, reflect.Pointer, ptr.Kind(), "target must be a pointer to a struct")
	require.False(t, ptr.IsNil(), "target must not be a nil pointer")

	strct := ptr.Elem()
	// FieldByName panics on non-struct kinds; fail the test clearly instead.
	require.Equal(t, reflect.Struct, strct.Kind(), "target must point to a struct")

	field := strct.FieldByName(name)
	require.Truef(t, field.IsValid(), "field %q not found", name)

	// reflect refuses to Set an unexported field reached via FieldByName, so
	// rebuild a settable Value that aliases the same memory.
	writable := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()

	if value == nil {
		// An untyped nil has no reflect.Value (ValueOf returns the zero Value,
		// which Set would panic on), so store the field's zero value instead.
		writable.Set(reflect.Zero(field.Type()))
		return
	}

	newVal := reflect.ValueOf(value)
	require.Truef(t, newVal.Type().AssignableTo(field.Type()),
		"cannot assign value of type %s to field %q of type %s",
		newVal.Type(), name, field.Type())
	writable.Set(newVal)
}