1//go:build windows
2// +build windows
3
4package pty
5
6import (
7 "context"
8 "errors"
9 "fmt"
10 "os"
11 "strings"
12 "sync"
13 "syscall"
14 "unicode/utf16"
15 "unsafe"
16
17 "golang.org/x/sys/windows"
18)
19
// _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE is the UpdateProcThreadAttribute key
// used to attach a pseudo console to a new process (Windows SDK value:
// ProcThreadAttributePseudoConsole (22) with PROC_THREAD_ATTRIBUTE_INPUT set).
const _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive

var (
	// errClosedConPty is returned when an operation needs the pseudo console
	// handle but that handle is zero.
	errClosedConPty = errors.New("pseudo console is closed")

	// errNotStarted is returned when a command's process has not been started.
	errNotStarted = errors.New("process not started")
)
28
// The mkwinsyscall generator is provided by github.com/Microsoft/go-winio. Install it with:
// go install github.com/Microsoft/go-winio/tools/mkwinsyscall@latest
31//go:generate mkwinsyscall -output zsyscall_windows.go ./*.go
32
33// https://github.com/microsoft/hcsshim/blob/main/internal/conpty/conpty.go
34type conPty struct {
35 handle windows.Handle
36 inPipe, outPipe *os.File
37 mtx sync.RWMutex
38}
39
40var _ Pty = &conPty{}
41
42func newPty() (Pty, error) {
43 ptyIn, inPipeOurs, err := os.Pipe()
44 if err != nil {
45 return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
46 }
47
48 outPipeOurs, ptyOut, err := os.Pipe()
49 if err != nil {
50 return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
51 }
52
53 var hpc windows.Handle
54 coord := windows.Coord{X: 80, Y: 25}
55 err = createPseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc)
56 if err != nil {
57 return nil, fmt.Errorf("failed to create pseudo console: %w", err)
58 }
59
60 if err := ptyOut.Close(); err != nil {
61 return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
62 }
63 if err := ptyIn.Close(); err != nil {
64 return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
65 }
66
67 return &conPty{
68 handle: hpc,
69 inPipe: inPipeOurs,
70 outPipe: outPipeOurs,
71 }, nil
72}
73
74// Close implements Pty.
75func (p *conPty) Close() error {
76 p.mtx.Lock()
77 defer p.mtx.Unlock()
78
79 closePseudoConsole(p.handle)
80 return errors.Join(p.inPipe.Close(), p.outPipe.Close())
81}
82
83// Command implements Pty.
84func (p *conPty) Command(name string, args ...string) *Cmd {
85 return p.CommandContext(nil, name, args...)
86}
87
88// CommandContext implements Pty.
89func (p *conPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd {
90 return &Cmd{
91 ctx: ctx,
92 pty: p,
93 Path: name,
94 Args: append([]string{name}, args...),
95 }
96}
97
98// Control implements Pty.
99func (*conPty) Control(func(fd uintptr)) error {
100 return nil
101}
102
103// Name implements Pty.
104func (*conPty) Name() string {
105 return "windows-pty"
106}
107
108// Read implements Pty.
109func (p *conPty) Read(b []byte) (n int, err error) {
110 return p.outPipe.Read(b)
111}
112
113// Resize implements Pty.
114func (p *conPty) Resize(rows int, cols int) error {
115 if err := resizePseudoConsole(p.handle, windows.Coord{X: int16(cols), Y: int16(rows)}); err != nil {
116 return fmt.Errorf("failed to resize pseudo console: %w", err)
117 }
118 return nil
119}
120
121// Write implements Pty.
122func (p *conPty) Write(b []byte) (n int, err error) {
123 return p.inPipe.Write(b)
124}
125
126// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with
127// CreateProcess.
128func (p *conPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error {
129 p.mtx.RLock()
130 defer p.mtx.RUnlock()
131
132 if p.handle == 0 {
133 return errClosedConPty
134 }
135
136 if err := attrList.Update(
137 _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
138 unsafe.Pointer(p.handle),
139 unsafe.Sizeof(p.handle),
140 ); err != nil {
141 return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err)
142 }
143
144 return nil
145}
146
147// createPseudoConsole creates a windows pseudo console.
148func createPseudoConsole(size windows.Coord, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) error {
149 // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
150 return _createPseudoConsole(*((*uint32)(unsafe.Pointer(&size))), hInput, hOutput, dwFlags, hpcon)
151}
152
153// resizePseudoConsole resizes the internal buffers of the pseudo console to the width and height specified in `size`.
154func resizePseudoConsole(hpcon windows.Handle, size windows.Coord) error {
155 // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
156 return _resizePseudoConsole(hpcon, *((*uint32)(unsafe.Pointer(&size))))
157}
158
159//sys _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) = kernel32.CreatePseudoConsole
160//sys _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) = kernel32.ResizePseudoConsole
161//sys closePseudoConsole(hpc windows.Handle) = kernel32.ClosePseudoConsole
162
163type conPtySys struct {
164 attrs *windows.ProcThreadAttributeListContainer
165 done chan error
166 cmdErr error
167}
168
169func (c *Cmd) start() error {
170 pty, ok := c.pty.(*conPty)
171 if !ok {
172 return ErrInvalidCommand
173 }
174
175 argv0 := c.Path
176 if len(c.Dir) != 0 {
177 // Windows CreateProcess looks for argv0 relative to the current
178 // directory, and, only once the new process is started, it does
179 // Chdir(attr.Dir). We are adjusting for that difference here by
180 // making argv0 absolute.
181 var err error
182 argv0, err = joinExeDirAndFName(c.Dir, c.Path)
183 if err != nil {
184 return err
185 }
186 }
187
188 argv0p, err := windows.UTF16PtrFromString(argv0)
189 if err != nil {
190 return err
191 }
192
193 var cmdline string
194 if c.SysProcAttr.CmdLine != "" {
195 cmdline = c.SysProcAttr.CmdLine
196 } else {
197 cmdline = windows.ComposeCommandLine(c.Args)
198 }
199 argvp, err := windows.UTF16PtrFromString(cmdline)
200 if err != nil {
201 return err
202 }
203
204 var dirp *uint16
205 if len(c.Dir) != 0 {
206 dirp, err = windows.UTF16PtrFromString(c.Dir)
207 if err != nil {
208 return err
209 }
210 }
211
212 if c.Env == nil {
213 c.Env, err = execEnvDefault(c.SysProcAttr)
214 if err != nil {
215 return err
216 }
217 }
218
219 siEx := new(windows.StartupInfoEx)
220 siEx.Flags = windows.STARTF_USESTDHANDLES
221 pi := new(windows.ProcessInformation)
222
223 // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field.
224 flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags
225
226 // Allocate an attribute list that's large enough to do the operations we care about
227 // 2. Pseudo console setup if one was requested.
228 // Therefore we need a list of size 3.
229 attrs, err := windows.NewProcThreadAttributeList(1)
230 if err != nil {
231 return fmt.Errorf("failed to initialize process thread attribute list: %w", err)
232 }
233
234 c.sys = &conPtySys{
235 attrs: attrs,
236 done: make(chan error, 1),
237 }
238
239 if err := pty.updateProcThreadAttribute(attrs); err != nil {
240 return err
241 }
242
243 var zeroSec windows.SecurityAttributes
244 pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
245 if c.SysProcAttr.ProcessAttributes != nil {
246 pSec = &windows.SecurityAttributes{
247 Length: c.SysProcAttr.ProcessAttributes.Length,
248 InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle,
249 }
250 }
251 tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
252 if c.SysProcAttr.ThreadAttributes != nil {
253 tSec = &windows.SecurityAttributes{
254 Length: c.SysProcAttr.ThreadAttributes.Length,
255 InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle,
256 }
257 }
258
259 siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall
260 siEx.Cb = uint32(unsafe.Sizeof(*siEx))
261 if c.SysProcAttr.Token != 0 {
262 err = windows.CreateProcessAsUser(
263 windows.Token(c.SysProcAttr.Token),
264 argv0p,
265 argvp,
266 pSec,
267 tSec,
268 false,
269 flags,
270 createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
271 dirp,
272 &siEx.StartupInfo,
273 pi,
274 )
275 } else {
276 err = windows.CreateProcess(
277 argv0p,
278 argvp,
279 pSec,
280 tSec,
281 false,
282 flags,
283 createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
284 dirp,
285 &siEx.StartupInfo,
286 pi,
287 )
288 }
289 if err != nil {
290 return fmt.Errorf("failed to create process: %w", err)
291 }
292 // Don't need the thread handle for anything.
293 defer func() {
294 _ = windows.CloseHandle(pi.Thread)
295 }()
296
297 // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a
298 // process has been launched.
299 c.Process, err = os.FindProcess(int(pi.ProcessId))
300 if err != nil {
301 // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the
302 // object.
303 if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil {
304 return fmt.Errorf("failed to terminate process after process not found: %w", tErr)
305 }
306 return fmt.Errorf("failed to find process after starting: %w", err)
307 }
308
309 if c.ctx != nil {
310 if c.Cancel == nil {
311 c.Cancel = func() error {
312 if c.Process == nil {
313 return errNotStarted
314 }
315 return c.Process.Kill()
316 }
317 }
318 go c.waitOnContext()
319 }
320
321 return nil
322}
323
324func (c *Cmd) waitOnContext() {
325 sys := c.sys.(*conPtySys)
326 select {
327 case <-c.ctx.Done():
328 _ = c.Cancel()
329 sys.cmdErr = c.ctx.Err()
330 case err := <-sys.done:
331 sys.cmdErr = err
332 }
333}
334
335func (c *Cmd) wait() (retErr error) {
336 if c.Process == nil {
337 return errNotStarted
338 }
339 if c.ProcessState != nil {
340 return errors.New("process already waited on")
341 }
342 defer func() {
343 sys := c.sys.(*conPtySys)
344 sys.attrs.Delete()
345 sys.done <- nil
346 if retErr == nil {
347 retErr = sys.cmdErr
348 }
349 }()
350 c.ProcessState, retErr = c.Process.Wait()
351 if retErr != nil {
352 return retErr
353 }
354 return
355}
356
357//
358// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities
359// found in the go stdlib.
360//
361
362func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) {
363 if sys == nil || sys.Token == 0 {
364 return syscall.Environ(), nil
365 }
366
367 var block *uint16
368 err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false)
369 if err != nil {
370 return nil, err
371 }
372
373 defer windows.DestroyEnvironmentBlock(block)
374 blockp := uintptr(unsafe.Pointer(block))
375
376 for {
377 // find NUL terminator
378 end := unsafe.Pointer(blockp)
379 for *(*uint16)(end) != 0 {
380 end = unsafe.Pointer(uintptr(end) + 2)
381 }
382
383 n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2
384 if n == 0 {
385 // environment block ends with empty string
386 break
387 }
388
389 entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n]
390 env = append(env, string(utf16.Decode(entry)))
391 blockp += 2 * (uintptr(len(entry)) + 1)
392 }
393 return
394}
395
// isSlash reports whether c is a Windows path separator: a backslash or a
// forward slash.
func isSlash(c uint8) bool {
	switch c {
	case '\\', '/':
		return true
	}
	return false
}
399
400func normalizeDir(dir string) (name string, err error) {
401 ndir, err := syscall.FullPath(dir)
402 if err != nil {
403 return "", err
404 }
405 if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) {
406 // dir cannot have \\server\share\path form
407 return "", syscall.EINVAL
408 }
409 return ndir, nil
410}
411
// volToUpper upper-cases an ASCII lowercase letter (a drive letter). Any other
// value is returned unchanged.
func volToUpper(ch int) int {
	if ch < 'a' || ch > 'z' {
		return ch
	}
	return ch - ('a' - 'A')
}
418
419func joinExeDirAndFName(dir, p string) (name string, err error) {
420 if len(p) == 0 {
421 return "", syscall.EINVAL
422 }
423 if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
424 // \\server\share\path form
425 return p, nil
426 }
427 if len(p) > 1 && p[1] == ':' {
428 // has drive letter
429 if len(p) == 2 {
430 return "", syscall.EINVAL
431 }
432 if isSlash(p[2]) {
433 return p, nil
434 } else {
435 d, err := normalizeDir(dir)
436 if err != nil {
437 return "", err
438 }
439 if volToUpper(int(p[0])) == volToUpper(int(d[0])) {
440 return syscall.FullPath(d + "\\" + p[2:])
441 } else {
442 return syscall.FullPath(p)
443 }
444 }
445 } else {
446 // no drive letter
447 d, err := normalizeDir(dir)
448 if err != nil {
449 return "", err
450 }
451 if isSlash(p[0]) {
452 return windows.FullPath(d[:2] + p)
453 } else {
454 return windows.FullPath(d + "\\" + p)
455 }
456 }
457}
458
// createEnvBlock converts a list of environment strings into the UTF-16 block
// CreateProcess expects: each entry followed by a NUL, with one more NUL
// closing the block. An empty list yields two NULs.
func createEnvBlock(envv []string) *uint16 {
	if len(envv) == 0 {
		return &[]uint16{0, 0}[0]
	}

	// The UTF-8 byte count bounds the number of UTF-16 code units, so this
	// capacity is never exceeded.
	capacity := 1
	for _, entry := range envv {
		capacity += len(entry) + 1
	}

	block := make([]uint16, 0, capacity)
	for _, entry := range envv {
		for _, r := range entry {
			block = utf16.AppendRune(block, r)
		}
		block = append(block, 0)
	}
	block = append(block, 0)
	return &block[0]
}
485
// dedupEnvCase removes duplicate keys from env. For each key, only the last
// "key=value" entry is kept, in the position where the key first appeared.
// If caseInsensitive is true, keys that differ only in case are the same key.
// Entries without '=' are passed through unchanged.
//
// Windows environments contain per-drive working-directory entries such as
// "=C:=C:\dir". For these, the leading '=' belongs to the key, so "=C:" and
// "=D:" stay distinct instead of all collapsing onto the empty key.
func dedupEnvCase(caseInsensitive bool, env []string) []string {
	out := make([]string, 0, len(env))
	saw := make(map[string]int, len(env)) // key => index into out
	for _, kv := range env {
		eq := strings.Index(kv, "=")
		if eq == 0 {
			// Key starts with '='; the separator is the next '=' (if any).
			eq = strings.Index(kv[1:], "=") + 1
		}
		if eq < 0 {
			out = append(out, kv)
			continue
		}
		k := kv[:eq]
		if caseInsensitive {
			k = strings.ToLower(k)
		}
		if dupIdx, isDup := saw[k]; isDup {
			out[dupIdx] = kv
			continue
		}
		saw[k] = len(out)
		out = append(out, kv)
	}
	return out
}
510
// addCriticalEnv returns env with a SYSTEMROOT entry appended (taken from this
// process's environment) unless env already has a SYSTEMROOT key, compared
// case-insensitively. Entries without '=' are not treated as keys.
func addCriticalEnv(env []string) []string {
	for _, kv := range env {
		key, _, hasValue := strings.Cut(kv, "=")
		if hasValue && strings.EqualFold(key, "SYSTEMROOT") {
			// Already present: return env untouched.
			return env
		}
	}
	return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}