exec_terminal_test.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/base64"
  6	"net/http"
  7	"net/http/httptest"
  8	"strings"
  9	"testing"
 10	"time"
 11
 12	"github.com/coder/websocket"
 13	"github.com/coder/websocket/wsjson"
 14)
 15
 16func TestExecTerminal_SimpleCommand(t *testing.T) {
 17	h := NewTestHarness(t)
 18	defer h.cleanup()
 19
 20	mux := http.NewServeMux()
 21	h.server.RegisterRoutes(mux)
 22	server := httptest.NewServer(mux)
 23	defer server.Close()
 24
 25	// Convert http to ws URL
 26	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=echo+hello"
 27
 28	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 29	defer cancel()
 30
 31	conn, _, err := websocket.Dial(ctx, wsURL, nil)
 32	if err != nil {
 33		t.Fatalf("Failed to dial websocket: %v", err)
 34	}
 35	defer conn.Close(websocket.StatusNormalClosure, "test done")
 36
 37	// Send init message
 38	initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
 39	if err := wsjson.Write(ctx, conn, initMsg); err != nil {
 40		t.Fatalf("Failed to write init message: %v", err)
 41	}
 42
 43	// Read messages until connection closes (server closes after sending exit)
 44	var output strings.Builder
 45	var exitCode int = -1
 46
 47	for {
 48		var msg ExecMessage
 49		err := wsjson.Read(ctx, conn, &msg)
 50		if err != nil {
 51			// Connection closed - this is expected after exit message
 52			break
 53		}
 54
 55		switch msg.Type {
 56		case "output":
 57			data, err := base64.StdEncoding.DecodeString(msg.Data)
 58			if err == nil {
 59				output.Write(data)
 60			}
 61		case "exit":
 62			if msg.Data == "0" {
 63				exitCode = 0
 64			} else {
 65				exitCode = 1
 66			}
 67			// Don't break here - continue reading until connection is closed
 68			// to ensure we've received all output
 69		case "error":
 70			t.Fatalf("Received error: %s", msg.Data)
 71		}
 72	}
 73
 74	if exitCode != 0 {
 75		t.Errorf("Expected exit code 0, got %d", exitCode)
 76	}
 77
 78	if !strings.Contains(output.String(), "hello") {
 79		t.Errorf("Expected output to contain 'hello', got: %q", output.String())
 80	}
 81}
 82
 83func TestExecTerminal_FailingCommand(t *testing.T) {
 84	h := NewTestHarness(t)
 85	defer h.cleanup()
 86
 87	mux := http.NewServeMux()
 88	h.server.RegisterRoutes(mux)
 89	server := httptest.NewServer(mux)
 90	defer server.Close()
 91
 92	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=exit+42"
 93
 94	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 95	defer cancel()
 96
 97	conn, _, err := websocket.Dial(ctx, wsURL, nil)
 98	if err != nil {
 99		t.Fatalf("Failed to dial websocket: %v", err)
100	}
101	defer conn.Close(websocket.StatusNormalClosure, "test done")
102
103	// Send init message
104	initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
105	if err := wsjson.Write(ctx, conn, initMsg); err != nil {
106		t.Fatalf("Failed to write init message: %v", err)
107	}
108
109	// Read messages until we get exit
110	var exitCode string
111
112	for {
113		var msg ExecMessage
114		err := wsjson.Read(ctx, conn, &msg)
115		if err != nil {
116			break
117		}
118
119		if msg.Type == "exit" {
120			exitCode = msg.Data
121		}
122	}
123
124	if exitCode != "42" {
125		t.Errorf("Expected exit code 42, got %q", exitCode)
126	}
127}
128
129func TestExecTerminal_MissingCmd(t *testing.T) {
130	h := NewTestHarness(t)
131	defer h.cleanup()
132
133	mux := http.NewServeMux()
134	h.server.RegisterRoutes(mux)
135	server := httptest.NewServer(mux)
136	defer server.Close()
137
138	// Try without cmd parameter
139	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws"
140
141	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
142	defer cancel()
143
144	_, resp, err := websocket.Dial(ctx, wsURL, nil)
145	if err == nil {
146		t.Fatal("Expected error for missing cmd parameter")
147	}
148
149	if resp != nil && resp.StatusCode != 400 {
150		t.Errorf("Expected status 400, got %d", resp.StatusCode)
151	}
152}
153
154func TestExecTerminal_WorkingDirectory(t *testing.T) {
155	h := NewTestHarness(t)
156	defer h.cleanup()
157
158	mux := http.NewServeMux()
159	h.server.RegisterRoutes(mux)
160	server := httptest.NewServer(mux)
161	defer server.Close()
162
163	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=pwd&cwd=/tmp"
164
165	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
166	defer cancel()
167
168	conn, _, err := websocket.Dial(ctx, wsURL, nil)
169	if err != nil {
170		t.Fatalf("Failed to dial websocket: %v", err)
171	}
172	defer conn.Close(websocket.StatusNormalClosure, "test done")
173
174	// Send init message
175	initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
176	if err := wsjson.Write(ctx, conn, initMsg); err != nil {
177		t.Fatalf("Failed to write init message: %v", err)
178	}
179
180	// Read messages
181	var output strings.Builder
182
183	for {
184		var msg ExecMessage
185		err := wsjson.Read(ctx, conn, &msg)
186		if err != nil {
187			break
188		}
189
190		if msg.Type == "output" {
191			data, _ := base64.StdEncoding.DecodeString(msg.Data)
192			output.Write(data)
193		}
194	}
195
196	if !strings.Contains(output.String(), "/tmp") {
197		t.Errorf("Expected output to contain '/tmp', got: %q", output.String())
198	}
199}
200
201func TestExecTerminal_Input(t *testing.T) {
202	h := NewTestHarness(t)
203	defer h.cleanup()
204
205	mux := http.NewServeMux()
206	h.server.RegisterRoutes(mux)
207	server := httptest.NewServer(mux)
208	defer server.Close()
209
210	// Use cat which echoes input
211	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/exec-ws?cmd=cat"
212
213	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
214	defer cancel()
215
216	conn, _, err := websocket.Dial(ctx, wsURL, nil)
217	if err != nil {
218		t.Fatalf("Failed to dial websocket: %v", err)
219	}
220	defer conn.Close(websocket.StatusNormalClosure, "test done")
221
222	// Send init message
223	initMsg := ExecMessage{Type: "init", Cols: 80, Rows: 24}
224	if err := wsjson.Write(ctx, conn, initMsg); err != nil {
225		t.Fatalf("Failed to write init message: %v", err)
226	}
227
228	// Send some input followed by EOF (Ctrl-D)
229	inputMsg := ExecMessage{Type: "input", Data: "test input\n"}
230	if err := wsjson.Write(ctx, conn, inputMsg); err != nil {
231		t.Fatalf("Failed to write input message: %v", err)
232	}
233
234	// Send EOF
235	eofMsg := ExecMessage{Type: "input", Data: "\x04"} // Ctrl-D
236	if err := wsjson.Write(ctx, conn, eofMsg); err != nil {
237		t.Fatalf("Failed to write EOF message: %v", err)
238	}
239
240	// Read messages
241	var output strings.Builder
242	var gotExit bool
243
244	for i := 0; i < 20; i++ { // Limit iterations to avoid infinite loop
245		var msg ExecMessage
246		err := wsjson.Read(ctx, conn, &msg)
247		if err != nil {
248			break
249		}
250
251		switch msg.Type {
252		case "output":
253			data, _ := base64.StdEncoding.DecodeString(msg.Data)
254			output.Write(data)
255		case "exit":
256			gotExit = true
257		}
258
259		if gotExit {
260			break
261		}
262	}
263
264	if !strings.Contains(output.String(), "test input") {
265		t.Errorf("Expected output to contain 'test input', got: %q", output.String())
266	}
267}