pty_windows.go

  1//go:build windows
  2// +build windows
  3
  4package pty
  5
  6import (
  7	"context"
  8	"errors"
  9	"fmt"
 10	"os"
 11	"strings"
 12	"sync"
 13	"syscall"
 14	"unicode/utf16"
 15	"unsafe"
 16
 17	"golang.org/x/sys/windows"
 18)
 19
 20const (
 21	_PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive
 22)
 23
 24var (
 25	errClosedConPty = errors.New("pseudo console is closed")
 26	errNotStarted   = errors.New("process not started")
 27)
 28
 29// Install this from github.com/Microsoft/go-winio
 30// go install github.com/Microsoft/go-winio/tools/mkwinsyscall@latest
 31//go:generate mkwinsyscall -output zsyscall_windows.go ./*.go
 32
 33// https://github.com/microsoft/hcsshim/blob/main/internal/conpty/conpty.go
 34type conPty struct {
 35	handle          windows.Handle
 36	inPipe, outPipe *os.File
 37	mtx             sync.RWMutex
 38}
 39
 40var _ Pty = &conPty{}
 41
 42func newPty() (Pty, error) {
 43	ptyIn, inPipeOurs, err := os.Pipe()
 44	if err != nil {
 45		return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
 46	}
 47
 48	outPipeOurs, ptyOut, err := os.Pipe()
 49	if err != nil {
 50		return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
 51	}
 52
 53	var hpc windows.Handle
 54	coord := windows.Coord{X: 80, Y: 25}
 55	err = createPseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc)
 56	if err != nil {
 57		return nil, fmt.Errorf("failed to create pseudo console: %w", err)
 58	}
 59
 60	if err := ptyOut.Close(); err != nil {
 61		return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
 62	}
 63	if err := ptyIn.Close(); err != nil {
 64		return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
 65	}
 66
 67	return &conPty{
 68		handle:  hpc,
 69		inPipe:  inPipeOurs,
 70		outPipe: outPipeOurs,
 71	}, nil
 72}
 73
 74// Close implements Pty.
 75func (p *conPty) Close() error {
 76	p.mtx.Lock()
 77	defer p.mtx.Unlock()
 78
 79	closePseudoConsole(p.handle)
 80	return errors.Join(p.inPipe.Close(), p.outPipe.Close())
 81}
 82
 83// Command implements Pty.
 84func (p *conPty) Command(name string, args ...string) *Cmd {
 85	return p.CommandContext(nil, name, args...)
 86}
 87
 88// CommandContext implements Pty.
 89func (p *conPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd {
 90	return &Cmd{
 91		ctx:  ctx,
 92		pty:  p,
 93		Path: name,
 94		Args: append([]string{name}, args...),
 95	}
 96}
 97
 98// Control implements Pty.
 99func (*conPty) Control(func(fd uintptr)) error {
100	return nil
101}
102
103// Name implements Pty.
104func (*conPty) Name() string {
105	return "windows-pty"
106}
107
108// Read implements Pty.
109func (p *conPty) Read(b []byte) (n int, err error) {
110	return p.outPipe.Read(b)
111}
112
113// Resize implements Pty.
114func (p *conPty) Resize(rows int, cols int) error {
115	if err := resizePseudoConsole(p.handle, windows.Coord{X: int16(cols), Y: int16(rows)}); err != nil {
116		return fmt.Errorf("failed to resize pseudo console: %w", err)
117	}
118	return nil
119}
120
121// Write implements Pty.
122func (p *conPty) Write(b []byte) (n int, err error) {
123	return p.inPipe.Write(b)
124}
125
126// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with
127// CreateProcess.
128func (p *conPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error {
129	p.mtx.RLock()
130	defer p.mtx.RUnlock()
131
132	if p.handle == 0 {
133		return errClosedConPty
134	}
135
136	if err := attrList.Update(
137		_PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
138		unsafe.Pointer(p.handle),
139		unsafe.Sizeof(p.handle),
140	); err != nil {
141		return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err)
142	}
143
144	return nil
145}
146
147// createPseudoConsole creates a windows pseudo console.
148func createPseudoConsole(size windows.Coord, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) error {
149	// We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
150	return _createPseudoConsole(*((*uint32)(unsafe.Pointer(&size))), hInput, hOutput, dwFlags, hpcon)
151}
152
153// resizePseudoConsole resizes the internal buffers of the pseudo console to the width and height specified in `size`.
154func resizePseudoConsole(hpcon windows.Handle, size windows.Coord) error {
155	// We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
156	return _resizePseudoConsole(hpcon, *((*uint32)(unsafe.Pointer(&size))))
157}
158
159//sys _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) = kernel32.CreatePseudoConsole
160//sys _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) = kernel32.ResizePseudoConsole
161//sys closePseudoConsole(hpc windows.Handle) = kernel32.ClosePseudoConsole
162
163type conPtySys struct {
164	attrs  *windows.ProcThreadAttributeListContainer
165	done   chan error
166	cmdErr error
167}
168
169func (c *Cmd) start() error {
170	pty, ok := c.pty.(*conPty)
171	if !ok {
172		return ErrInvalidCommand
173	}
174
175	argv0 := c.Path
176	if len(c.Dir) != 0 {
177		// Windows CreateProcess looks for argv0 relative to the current
178		// directory, and, only once the new process is started, it does
179		// Chdir(attr.Dir). We are adjusting for that difference here by
180		// making argv0 absolute.
181		var err error
182		argv0, err = joinExeDirAndFName(c.Dir, c.Path)
183		if err != nil {
184			return err
185		}
186	}
187
188	argv0p, err := windows.UTF16PtrFromString(argv0)
189	if err != nil {
190		return err
191	}
192
193	var cmdline string
194	if c.SysProcAttr.CmdLine != "" {
195		cmdline = c.SysProcAttr.CmdLine
196	} else {
197		cmdline = windows.ComposeCommandLine(c.Args)
198	}
199	argvp, err := windows.UTF16PtrFromString(cmdline)
200	if err != nil {
201		return err
202	}
203
204	var dirp *uint16
205	if len(c.Dir) != 0 {
206		dirp, err = windows.UTF16PtrFromString(c.Dir)
207		if err != nil {
208			return err
209		}
210	}
211
212	if c.Env == nil {
213		c.Env, err = execEnvDefault(c.SysProcAttr)
214		if err != nil {
215			return err
216		}
217	}
218
219	siEx := new(windows.StartupInfoEx)
220	siEx.Flags = windows.STARTF_USESTDHANDLES
221	pi := new(windows.ProcessInformation)
222
223	// Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field.
224	flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags
225
226	// Allocate an attribute list that's large enough to do the operations we care about
227	// 2. Pseudo console setup if one was requested.
228	// Therefore we need a list of size 3.
229	attrs, err := windows.NewProcThreadAttributeList(1)
230	if err != nil {
231		return fmt.Errorf("failed to initialize process thread attribute list: %w", err)
232	}
233
234	c.sys = &conPtySys{
235		attrs: attrs,
236		done:  make(chan error, 1),
237	}
238
239	if err := pty.updateProcThreadAttribute(attrs); err != nil {
240		return err
241	}
242
243	var zeroSec windows.SecurityAttributes
244	pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
245	if c.SysProcAttr.ProcessAttributes != nil {
246		pSec = &windows.SecurityAttributes{
247			Length:        c.SysProcAttr.ProcessAttributes.Length,
248			InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle,
249		}
250	}
251	tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
252	if c.SysProcAttr.ThreadAttributes != nil {
253		tSec = &windows.SecurityAttributes{
254			Length:        c.SysProcAttr.ThreadAttributes.Length,
255			InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle,
256		}
257	}
258
259	siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall
260	siEx.Cb = uint32(unsafe.Sizeof(*siEx))
261	if c.SysProcAttr.Token != 0 {
262		err = windows.CreateProcessAsUser(
263			windows.Token(c.SysProcAttr.Token),
264			argv0p,
265			argvp,
266			pSec,
267			tSec,
268			false,
269			flags,
270			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
271			dirp,
272			&siEx.StartupInfo,
273			pi,
274		)
275	} else {
276		err = windows.CreateProcess(
277			argv0p,
278			argvp,
279			pSec,
280			tSec,
281			false,
282			flags,
283			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
284			dirp,
285			&siEx.StartupInfo,
286			pi,
287		)
288	}
289	if err != nil {
290		return fmt.Errorf("failed to create process: %w", err)
291	}
292	// Don't need the thread handle for anything.
293	defer func() {
294		_ = windows.CloseHandle(pi.Thread)
295	}()
296
297	// Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a
298	// process has been launched.
299	c.Process, err = os.FindProcess(int(pi.ProcessId))
300	if err != nil {
301		// If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the
302		// object.
303		if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil {
304			return fmt.Errorf("failed to terminate process after process not found: %w", tErr)
305		}
306		return fmt.Errorf("failed to find process after starting: %w", err)
307	}
308
309	if c.ctx != nil {
310		if c.Cancel == nil {
311			c.Cancel = func() error {
312				if c.Process == nil {
313					return errNotStarted
314				}
315				return c.Process.Kill()
316			}
317		}
318		go c.waitOnContext()
319	}
320
321	return nil
322}
323
324func (c *Cmd) waitOnContext() {
325	sys := c.sys.(*conPtySys)
326	select {
327	case <-c.ctx.Done():
328		_ = c.Cancel()
329		sys.cmdErr = c.ctx.Err()
330	case err := <-sys.done:
331		sys.cmdErr = err
332	}
333}
334
335func (c *Cmd) wait() (retErr error) {
336	if c.Process == nil {
337		return errNotStarted
338	}
339	if c.ProcessState != nil {
340		return errors.New("process already waited on")
341	}
342	defer func() {
343		sys := c.sys.(*conPtySys)
344		sys.attrs.Delete()
345		sys.done <- nil
346		if retErr == nil {
347			retErr = sys.cmdErr
348		}
349	}()
350	c.ProcessState, retErr = c.Process.Wait()
351	if retErr != nil {
352		return retErr
353	}
354	return
355}
356
357//
358// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities
359// found in the go stdlib.
360//
361
362func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) {
363	if sys == nil || sys.Token == 0 {
364		return syscall.Environ(), nil
365	}
366
367	var block *uint16
368	err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false)
369	if err != nil {
370		return nil, err
371	}
372
373	defer windows.DestroyEnvironmentBlock(block)
374	blockp := uintptr(unsafe.Pointer(block))
375
376	for {
377		// find NUL terminator
378		end := unsafe.Pointer(blockp)
379		for *(*uint16)(end) != 0 {
380			end = unsafe.Pointer(uintptr(end) + 2)
381		}
382
383		n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2
384		if n == 0 {
385			// environment block ends with empty string
386			break
387		}
388
389		entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n]
390		env = append(env, string(utf16.Decode(entry)))
391		blockp += 2 * (uintptr(len(entry)) + 1)
392	}
393	return
394}
395
// isSlash reports whether c is a Windows path separator (backslash or slash).
func isSlash(c uint8) bool {
	switch c {
	case '\\', '/':
		return true
	}
	return false
}
399
400func normalizeDir(dir string) (name string, err error) {
401	ndir, err := syscall.FullPath(dir)
402	if err != nil {
403		return "", err
404	}
405	if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) {
406		// dir cannot have \\server\share\path form
407		return "", syscall.EINVAL
408	}
409	return ndir, nil
410}
411
// volToUpper upper-cases an ASCII letter and returns any other value
// unchanged. It is used to compare drive letters case-insensitively.
func volToUpper(ch int) int {
	if ch < 'a' || ch > 'z' {
		return ch
	}
	return ch - ('a' - 'A')
}
418
419func joinExeDirAndFName(dir, p string) (name string, err error) {
420	if len(p) == 0 {
421		return "", syscall.EINVAL
422	}
423	if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
424		// \\server\share\path form
425		return p, nil
426	}
427	if len(p) > 1 && p[1] == ':' {
428		// has drive letter
429		if len(p) == 2 {
430			return "", syscall.EINVAL
431		}
432		if isSlash(p[2]) {
433			return p, nil
434		} else {
435			d, err := normalizeDir(dir)
436			if err != nil {
437				return "", err
438			}
439			if volToUpper(int(p[0])) == volToUpper(int(d[0])) {
440				return syscall.FullPath(d + "\\" + p[2:])
441			} else {
442				return syscall.FullPath(p)
443			}
444		}
445	} else {
446		// no drive letter
447		d, err := normalizeDir(dir)
448		if err != nil {
449			return "", err
450		}
451		if isSlash(p[0]) {
452			return windows.FullPath(d[:2] + p)
453		} else {
454			return windows.FullPath(d + "\\" + p)
455		}
456	}
457}
458
// createEnvBlock converts environment strings into the layout CreateProcess
// expects with CREATE_UNICODE_ENVIRONMENT. Each entry becomes a
// NUL-terminated UTF-16 string, and the list ends with one extra NUL. An empty
// list therefore becomes two UTF-16 NULs.
func createEnvBlock(envv []string) *uint16 {
	// Joining with NUL and appending two NULs terminates the last entry and
	// then the block. It also degrades to "\x00\x00" for an empty list.
	block := strings.Join(envv, "\x00") + "\x00\x00"
	return &utf16.Encode([]rune(block))[0]
}
485
// dedupEnvCase removes duplicate keys from env. When a key repeats, the later
// "key=value" replaces the earlier one in the earlier one's position. Entries
// without "=" are kept as is. If caseInsensitive is true, keys are compared
// case-insensitively.
//
// Windows keeps per-drive working directories in entries such as
// "=C:=C:\dir", whose key starts with "=". The separator is therefore searched
// for after that leading "=". Otherwise every such entry would share the empty
// key and all but one would be dropped.
func dedupEnvCase(caseInsensitive bool, env []string) []string {
	out := make([]string, 0, len(env))
	saw := make(map[string]int, len(env)) // key => index into out
	for _, kv := range env {
		eq := strings.Index(kv, "=")
		if eq == 0 {
			// Leading "=" belongs to the key (e.g. "=C:"); find the real separator.
			eq = strings.Index(kv[1:], "=") + 1
		}
		if eq < 0 {
			out = append(out, kv)
			continue
		}
		k := kv[:eq]
		if caseInsensitive {
			k = strings.ToLower(k)
		}
		if dupIdx, isDup := saw[k]; isDup {
			out[dupIdx] = kv
			continue
		}
		saw[k] = len(out)
		out = append(out, kv)
	}
	return out
}
510
// addCriticalEnv ensures env contains SYSTEMROOT, which Windows programs
// almost always need. If a SYSTEMROOT entry (compared case-insensitively) is
// already present, env is returned untouched. Otherwise the current process's
// value is appended.
func addCriticalEnv(env []string) []string {
	hasSystemRoot := false
	for _, kv := range env {
		key, _, found := strings.Cut(kv, "=")
		if found && strings.EqualFold(key, "SYSTEMROOT") {
			hasSystemRoot = true
			break
		}
	}
	if hasSystemRoot {
		return env
	}
	return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}