1package server
2
3import (
4 "context"
5 "encoding/base64"
6 "net/http"
7 "net/http/httptest"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/coder/websocket"
13 "github.com/coder/websocket/wsjson"
14)
15
16func TestExecTerminal_SimpleCommand(t *testing.T) {
17 h := NewTestHarness(t)
18 defer h.cleanup()
19
20 mux := http.NewServeMux()
21 h.server.RegisterRoutes(mux)
22 server := httptest.NewServer(mux)
23 defer server.Close()
24
25 // Convert http to ws URL
26 wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=echo+hello"
27
28 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
29 defer cancel()
30
31 conn, _, err := websocket.Dial(ctx, wsURL, nil)
32 if err != nil {
33 t.Fatalf("Failed to dial websocket: %v", err)
34 }
35 defer conn.Close(websocket.StatusNormalClosure, "test done")
36
37 // Send init message
38 initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
39 if err := wsjson.Write(ctx, conn, initMsg); err != nil {
40 t.Fatalf("Failed to write init message: %v", err)
41 }
42
43 // Read messages until connection closes (server closes after sending exit)
44 var output strings.Builder
45 var exitCode int = -1
46
47 for {
48 var msg ExecMessage
49 err := wsjson.Read(ctx, conn, &msg)
50 if err != nil {
51 // Connection closed - this is expected after exit message
52 break
53 }
54
55 switch msg.Type {
56 case "output":
57 data, err := base64.StdEncoding.DecodeString(msg.Data)
58 if err == nil {
59 output.Write(data)
60 }
61 case "exit":
62 if msg.Data == "0" {
63 exitCode = 0
64 } else {
65 exitCode = 1
66 }
67 // Don't break here - continue reading until connection is closed
68 // to ensure we've received all output
69 case "error":
70 t.Fatalf("Received error: %s", msg.Data)
71 }
72 }
73
74 if exitCode != 0 {
75 t.Errorf("Expected exit code 0, got %d", exitCode)
76 }
77
78 if !strings.Contains(output.String(), "hello") {
79 t.Errorf("Expected output to contain 'hello', got: %q", output.String())
80 }
81}
82
83func TestExecTerminal_FailingCommand(t *testing.T) {
84 h := NewTestHarness(t)
85 defer h.cleanup()
86
87 mux := http.NewServeMux()
88 h.server.RegisterRoutes(mux)
89 server := httptest.NewServer(mux)
90 defer server.Close()
91
92 wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=exit+42"
93
94 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
95 defer cancel()
96
97 conn, _, err := websocket.Dial(ctx, wsURL, nil)
98 if err != nil {
99 t.Fatalf("Failed to dial websocket: %v", err)
100 }
101 defer conn.Close(websocket.StatusNormalClosure, "test done")
102
103 // Send init message
104 initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
105 if err := wsjson.Write(ctx, conn, initMsg); err != nil {
106 t.Fatalf("Failed to write init message: %v", err)
107 }
108
109 // Read messages until we get exit
110 var exitCode string
111
112 for {
113 var msg ExecMessage
114 err := wsjson.Read(ctx, conn, &msg)
115 if err != nil {
116 break
117 }
118
119 if msg.Type == "exit" {
120 exitCode = msg.Data
121 }
122 }
123
124 if exitCode != "42" {
125 t.Errorf("Expected exit code 42, got %q", exitCode)
126 }
127}
128
129func TestExecTerminal_MissingCmd(t *testing.T) {
130 h := NewTestHarness(t)
131 defer h.cleanup()
132
133 mux := http.NewServeMux()
134 h.server.RegisterRoutes(mux)
135 server := httptest.NewServer(mux)
136 defer server.Close()
137
138 // Try without cmd parameter
139 wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws"
140
141 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
142 defer cancel()
143
144 _, resp, err := websocket.Dial(ctx, wsURL, nil)
145 if err == nil {
146 t.Fatal("Expected error for missing cmd parameter")
147 }
148
149 if resp != nil && resp.StatusCode != 400 {
150 t.Errorf("Expected status 400, got %d", resp.StatusCode)
151 }
152}
153
154func TestExecTerminal_WorkingDirectory(t *testing.T) {
155 h := NewTestHarness(t)
156 defer h.cleanup()
157
158 mux := http.NewServeMux()
159 h.server.RegisterRoutes(mux)
160 server := httptest.NewServer(mux)
161 defer server.Close()
162
163 wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=pwd&cwd=/tmp"
164
165 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
166 defer cancel()
167
168 conn, _, err := websocket.Dial(ctx, wsURL, nil)
169 if err != nil {
170 t.Fatalf("Failed to dial websocket: %v", err)
171 }
172 defer conn.Close(websocket.StatusNormalClosure, "test done")
173
174 // Send init message
175 initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
176 if err := wsjson.Write(ctx, conn, initMsg); err != nil {
177 t.Fatalf("Failed to write init message: %v", err)
178 }
179
180 // Read messages
181 var output strings.Builder
182
183 for {
184 var msg ExecMessage
185 err := wsjson.Read(ctx, conn, &msg)
186 if err != nil {
187 break
188 }
189
190 if msg.Type == "output" {
191 data, _ := base64.StdEncoding.DecodeString(msg.Data)
192 output.Write(data)
193 }
194 }
195
196 if !strings.Contains(output.String(), "/tmp") {
197 t.Errorf("Expected output to contain '/tmp', got: %q", output.String())
198 }
199}
200
201func TestExecTerminal_Input(t *testing.T) {
202 h := NewTestHarness(t)
203 defer h.cleanup()
204
205 mux := http.NewServeMux()
206 h.server.RegisterRoutes(mux)
207 server := httptest.NewServer(mux)
208 defer server.Close()
209
210 // Use cat which echoes input
211 wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=cat"
212
213 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
214 defer cancel()
215
216 conn, _, err := websocket.Dial(ctx, wsURL, nil)
217 if err != nil {
218 t.Fatalf("Failed to dial websocket: %v", err)
219 }
220 defer conn.Close(websocket.StatusNormalClosure, "test done")
221
222 // Send init message
223 initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
224 if err := wsjson.Write(ctx, conn, initMsg); err != nil {
225 t.Fatalf("Failed to write init message: %v", err)
226 }
227
228 // Send some input followed by EOF (Ctrl-D)
229 inputMsg := ExecMessage{Type: "input", Data: "test input\n"}
230 if err := wsjson.Write(ctx, conn, inputMsg); err != nil {
231 t.Fatalf("Failed to write input message: %v", err)
232 }
233
234 // Send EOF
235 eofMsg := ExecMessage{Type: "input", Data: "\x04"} // Ctrl-D
236 if err := wsjson.Write(ctx, conn, eofMsg); err != nil {
237 t.Fatalf("Failed to write EOF message: %v", err)
238 }
239
240 // Read messages
241 var output strings.Builder
242 var gotExit bool
243
244 for i := 0; i < 20; i++ { // Limit iterations to avoid infinite loop
245 var msg ExecMessage
246 err := wsjson.Read(ctx, conn, &msg)
247 if err != nil {
248 break
249 }
250
251 switch msg.Type {
252 case "output":
253 data, _ := base64.StdEncoding.DecodeString(msg.Data)
254 output.Write(data)
255 case "exit":
256 gotExit = true
257 }
258
259 if gotExit {
260 break
261 }
262 }
263
264 if !strings.Contains(output.String(), "test input") {
265 t.Errorf("Expected output to contain 'test input', got: %q", output.String())
266 }
267}