permission_test.go

  1package proto_test
  2
  3import (
  4	"encoding/json"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/agent/tools"
  8	"github.com/charmbracelet/crush/internal/proto"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// TestPermissionRequestParamsTypeAssertable guards the permission
 13// dialog's type assertions across the client/server boundary. The TUI
 14// asserts PermissionRequest.Params to tools.*PermissionsParams; when
 15// the request round-trips over the SSE wire (server → client), the
 16// decoded value must be the same Go type, otherwise the dialog
 17// renders empty content.
 18func TestPermissionRequestParamsTypeAssertable(t *testing.T) {
 19	t.Parallel()
 20
 21	tests := []struct {
 22		name     string
 23		toolName string
 24		params   any
 25		assert   func(t *testing.T, got any)
 26	}{
 27		{
 28			name:     "bash",
 29			toolName: tools.BashToolName,
 30			params: tools.BashPermissionsParams{
 31				Description:     "list files",
 32				Command:         "ls -la",
 33				WorkingDir:      "/tmp",
 34				RunInBackground: false,
 35			},
 36			assert: func(t *testing.T, got any) {
 37				v, ok := got.(tools.BashPermissionsParams)
 38				require.True(t, ok, "params must decode as tools.BashPermissionsParams, got %T", got)
 39				require.Equal(t, "list files", v.Description)
 40				require.Equal(t, "ls -la", v.Command)
 41				require.Equal(t, "/tmp", v.WorkingDir)
 42			},
 43		},
 44		{
 45			name:     "edit",
 46			toolName: tools.EditToolName,
 47			params: tools.EditPermissionsParams{
 48				FilePath:   "/tmp/x.go",
 49				OldContent: "old",
 50				NewContent: "new",
 51			},
 52			assert: func(t *testing.T, got any) {
 53				v, ok := got.(tools.EditPermissionsParams)
 54				require.True(t, ok, "params must decode as tools.EditPermissionsParams, got %T", got)
 55				require.Equal(t, "/tmp/x.go", v.FilePath)
 56				require.Equal(t, "old", v.OldContent)
 57				require.Equal(t, "new", v.NewContent)
 58			},
 59		},
 60		{
 61			name:     "write",
 62			toolName: tools.WriteToolName,
 63			params: tools.WritePermissionsParams{
 64				FilePath:   "/tmp/x.go",
 65				NewContent: "new",
 66			},
 67			assert: func(t *testing.T, got any) {
 68				v, ok := got.(tools.WritePermissionsParams)
 69				require.True(t, ok, "params must decode as tools.WritePermissionsParams, got %T", got)
 70				require.Equal(t, "/tmp/x.go", v.FilePath)
 71				require.Equal(t, "new", v.NewContent)
 72			},
 73		},
 74		{
 75			name:     "multiedit",
 76			toolName: tools.MultiEditToolName,
 77			params: tools.MultiEditPermissionsParams{
 78				FilePath:   "/tmp/x.go",
 79				OldContent: "old",
 80				NewContent: "new",
 81			},
 82			assert: func(t *testing.T, got any) {
 83				v, ok := got.(tools.MultiEditPermissionsParams)
 84				require.True(t, ok, "params must decode as tools.MultiEditPermissionsParams, got %T", got)
 85				require.Equal(t, "/tmp/x.go", v.FilePath)
 86			},
 87		},
 88		{
 89			name:     "ls",
 90			toolName: tools.LSToolName,
 91			params: tools.LSPermissionsParams{
 92				Path:   "/tmp",
 93				Ignore: []string{".git"},
 94				Depth:  2,
 95			},
 96			assert: func(t *testing.T, got any) {
 97				v, ok := got.(tools.LSPermissionsParams)
 98				require.True(t, ok, "params must decode as tools.LSPermissionsParams, got %T", got)
 99				require.Equal(t, "/tmp", v.Path)
100				require.Equal(t, []string{".git"}, v.Ignore)
101				require.Equal(t, 2, v.Depth)
102			},
103		},
104		{
105			name:     "view",
106			toolName: tools.ViewToolName,
107			params: tools.ViewPermissionsParams{
108				FilePath: "/tmp/x.go",
109				Offset:   10,
110				Limit:    100,
111			},
112			assert: func(t *testing.T, got any) {
113				v, ok := got.(tools.ViewPermissionsParams)
114				require.True(t, ok, "params must decode as tools.ViewPermissionsParams, got %T", got)
115				require.Equal(t, "/tmp/x.go", v.FilePath)
116			},
117		},
118		{
119			name:     "fetch",
120			toolName: tools.FetchToolName,
121			params: tools.FetchPermissionsParams{
122				URL:    "https://example.com",
123				Format: "text",
124			},
125			assert: func(t *testing.T, got any) {
126				v, ok := got.(tools.FetchPermissionsParams)
127				require.True(t, ok, "params must decode as tools.FetchPermissionsParams, got %T", got)
128				require.Equal(t, "https://example.com", v.URL)
129			},
130		},
131		{
132			name:     "download",
133			toolName: tools.DownloadToolName,
134			params: tools.DownloadPermissionsParams{
135				URL:      "https://example.com/x.zip",
136				FilePath: "/tmp/x.zip",
137				Timeout:  30,
138			},
139			assert: func(t *testing.T, got any) {
140				v, ok := got.(tools.DownloadPermissionsParams)
141				require.True(t, ok, "params must decode as tools.DownloadPermissionsParams, got %T", got)
142				require.Equal(t, "https://example.com/x.zip", v.URL)
143				require.Equal(t, "/tmp/x.zip", v.FilePath)
144			},
145		},
146		{
147			name:     "agentic_fetch",
148			toolName: tools.AgenticFetchToolName,
149			params: tools.AgenticFetchPermissionsParams{
150				URL:    "https://example.com",
151				Prompt: "summarize this page",
152			},
153			assert: func(t *testing.T, got any) {
154				v, ok := got.(tools.AgenticFetchPermissionsParams)
155				require.True(t, ok, "params must decode as tools.AgenticFetchPermissionsParams, got %T", got)
156				require.Equal(t, "https://example.com", v.URL)
157				require.Equal(t, "summarize this page", v.Prompt)
158			},
159		},
160	}
161
162	for _, tc := range tests {
163		t.Run(tc.name, func(t *testing.T) {
164			t.Parallel()
165
166			// Build a server-side request with the tool's concrete
167			// params type, marshal to JSON (the wire path), then
168			// decode back through proto.PermissionRequest.
169			outbound := proto.PermissionRequest{
170				ID:         "perm-1",
171				SessionID:  "sess-1",
172				ToolCallID: "call-1",
173				ToolName:   tc.toolName,
174				Path:       "/tmp",
175				Params:     tc.params,
176			}
177			data, err := json.Marshal(outbound)
178			require.NoError(t, err)
179
180			var inbound proto.PermissionRequest
181			require.NoError(t, json.Unmarshal(data, &inbound))
182
183			tc.assert(t, inbound.Params)
184		})
185	}
186}