1package proto_test
2
3import (
4 "encoding/json"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/agent/tools"
8 "github.com/charmbracelet/crush/internal/proto"
9 "github.com/stretchr/testify/require"
10)
11
12// TestPermissionRequestParamsTypeAssertable guards the permission
13// dialog's type assertions across the client/server boundary. The TUI
14// asserts PermissionRequest.Params to tools.*PermissionsParams; when
15// the request round-trips over the SSE wire (server → client), the
16// decoded value must be the same Go type, otherwise the dialog
17// renders empty content.
18func TestPermissionRequestParamsTypeAssertable(t *testing.T) {
19 t.Parallel()
20
21 tests := []struct {
22 name string
23 toolName string
24 params any
25 assert func(t *testing.T, got any)
26 }{
27 {
28 name: "bash",
29 toolName: tools.BashToolName,
30 params: tools.BashPermissionsParams{
31 Description: "list files",
32 Command: "ls -la",
33 WorkingDir: "/tmp",
34 RunInBackground: false,
35 },
36 assert: func(t *testing.T, got any) {
37 v, ok := got.(tools.BashPermissionsParams)
38 require.True(t, ok, "params must decode as tools.BashPermissionsParams, got %T", got)
39 require.Equal(t, "list files", v.Description)
40 require.Equal(t, "ls -la", v.Command)
41 require.Equal(t, "/tmp", v.WorkingDir)
42 },
43 },
44 {
45 name: "edit",
46 toolName: tools.EditToolName,
47 params: tools.EditPermissionsParams{
48 FilePath: "/tmp/x.go",
49 OldContent: "old",
50 NewContent: "new",
51 },
52 assert: func(t *testing.T, got any) {
53 v, ok := got.(tools.EditPermissionsParams)
54 require.True(t, ok, "params must decode as tools.EditPermissionsParams, got %T", got)
55 require.Equal(t, "/tmp/x.go", v.FilePath)
56 require.Equal(t, "old", v.OldContent)
57 require.Equal(t, "new", v.NewContent)
58 },
59 },
60 {
61 name: "write",
62 toolName: tools.WriteToolName,
63 params: tools.WritePermissionsParams{
64 FilePath: "/tmp/x.go",
65 NewContent: "new",
66 },
67 assert: func(t *testing.T, got any) {
68 v, ok := got.(tools.WritePermissionsParams)
69 require.True(t, ok, "params must decode as tools.WritePermissionsParams, got %T", got)
70 require.Equal(t, "/tmp/x.go", v.FilePath)
71 require.Equal(t, "new", v.NewContent)
72 },
73 },
74 {
75 name: "multiedit",
76 toolName: tools.MultiEditToolName,
77 params: tools.MultiEditPermissionsParams{
78 FilePath: "/tmp/x.go",
79 OldContent: "old",
80 NewContent: "new",
81 },
82 assert: func(t *testing.T, got any) {
83 v, ok := got.(tools.MultiEditPermissionsParams)
84 require.True(t, ok, "params must decode as tools.MultiEditPermissionsParams, got %T", got)
85 require.Equal(t, "/tmp/x.go", v.FilePath)
86 },
87 },
88 {
89 name: "ls",
90 toolName: tools.LSToolName,
91 params: tools.LSPermissionsParams{
92 Path: "/tmp",
93 Ignore: []string{".git"},
94 Depth: 2,
95 },
96 assert: func(t *testing.T, got any) {
97 v, ok := got.(tools.LSPermissionsParams)
98 require.True(t, ok, "params must decode as tools.LSPermissionsParams, got %T", got)
99 require.Equal(t, "/tmp", v.Path)
100 require.Equal(t, []string{".git"}, v.Ignore)
101 require.Equal(t, 2, v.Depth)
102 },
103 },
104 {
105 name: "view",
106 toolName: tools.ViewToolName,
107 params: tools.ViewPermissionsParams{
108 FilePath: "/tmp/x.go",
109 Offset: 10,
110 Limit: 100,
111 },
112 assert: func(t *testing.T, got any) {
113 v, ok := got.(tools.ViewPermissionsParams)
114 require.True(t, ok, "params must decode as tools.ViewPermissionsParams, got %T", got)
115 require.Equal(t, "/tmp/x.go", v.FilePath)
116 },
117 },
118 {
119 name: "fetch",
120 toolName: tools.FetchToolName,
121 params: tools.FetchPermissionsParams{
122 URL: "https://example.com",
123 Format: "text",
124 },
125 assert: func(t *testing.T, got any) {
126 v, ok := got.(tools.FetchPermissionsParams)
127 require.True(t, ok, "params must decode as tools.FetchPermissionsParams, got %T", got)
128 require.Equal(t, "https://example.com", v.URL)
129 },
130 },
131 {
132 name: "download",
133 toolName: tools.DownloadToolName,
134 params: tools.DownloadPermissionsParams{
135 URL: "https://example.com/x.zip",
136 FilePath: "/tmp/x.zip",
137 Timeout: 30,
138 },
139 assert: func(t *testing.T, got any) {
140 v, ok := got.(tools.DownloadPermissionsParams)
141 require.True(t, ok, "params must decode as tools.DownloadPermissionsParams, got %T", got)
142 require.Equal(t, "https://example.com/x.zip", v.URL)
143 require.Equal(t, "/tmp/x.zip", v.FilePath)
144 },
145 },
146 {
147 name: "agentic_fetch",
148 toolName: tools.AgenticFetchToolName,
149 params: tools.AgenticFetchPermissionsParams{
150 URL: "https://example.com",
151 Prompt: "summarize this page",
152 },
153 assert: func(t *testing.T, got any) {
154 v, ok := got.(tools.AgenticFetchPermissionsParams)
155 require.True(t, ok, "params must decode as tools.AgenticFetchPermissionsParams, got %T", got)
156 require.Equal(t, "https://example.com", v.URL)
157 require.Equal(t, "summarize this page", v.Prompt)
158 },
159 },
160 }
161
162 for _, tc := range tests {
163 t.Run(tc.name, func(t *testing.T) {
164 t.Parallel()
165
166 // Build a server-side request with the tool's concrete
167 // params type, marshal to JSON (the wire path), then
168 // decode back through proto.PermissionRequest.
169 outbound := proto.PermissionRequest{
170 ID: "perm-1",
171 SessionID: "sess-1",
172 ToolCallID: "call-1",
173 ToolName: tc.toolName,
174 Path: "/tmp",
175 Params: tc.params,
176 }
177 data, err := json.Marshal(outbound)
178 require.NoError(t, err)
179
180 var inbound proto.PermissionRequest
181 require.NoError(t, json.Unmarshal(data, &inbound))
182
183 tc.assert(t, inbound.Params)
184 })
185 }
186}