feat(ssh): add pty windows implementation

Ayman Bagabas created

Change summary

server/ssh/pty/pty.go              |   3 
server/ssh/pty/pty_windows.go      | 527 ++++++++++++++++++++++++++++++++
server/ssh/pty/zsyscall_windows.go |  75 ++++
3 files changed, 604 insertions(+), 1 deletion(-)

Detailed changes

server/ssh/pty/pty.go 🔗

@@ -40,6 +40,7 @@ type Pty interface {
 	Resize(rows, cols int) error
 
 	// Control access to the underlying file descriptor in a blocking manner.
+	// Not implemented on Windows.
 	Control(f func(fd uintptr)) error
 }
 
@@ -80,7 +81,7 @@ type Cmd struct {
 	ProcessState *os.ProcessState
 
 	// Cancel is called when the command is canceled.
-	Cancel func()
+	Cancel func() error
 }
 
 // Start starts the specified command attached to the pseudo-terminal.

server/ssh/pty/pty_windows.go 🔗

@@ -0,0 +1,527 @@
+//go:build windows
+// +build windows
+
+package pty
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+	"syscall"
+	"unicode/utf16"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+)
+
+const (
+	_PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive
+)
+
+var (
+	errClosedConPty = errors.New("pseudo console is closed")
+	errNotStarted   = errors.New("process not started")
+)
+
+// Install this from github.com/Microsoft/go-winio
+// go install github.com/Microsoft/go-winio/tools/mkwinsyscall@latest
+//go:generate mkwinsyscall -output zsyscall_windows.go ./*.go
+
+// https://github.com/microsoft/hcsshim/blob/main/internal/conpty/conpty.go
+type conPty struct {
+	handle          windows.Handle
+	inPipe, outPipe *os.File
+	mtx             sync.RWMutex
+}
+
+var _ Pty = &conPty{}
+
+func newPty() (Pty, error) {
+	ptyIn, inPipeOurs, err := os.Pipe()
+	if err != nil {
+		return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
+	}
+
+	outPipeOurs, ptyOut, err := os.Pipe()
+	if err != nil {
+		return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err)
+	}
+
+	var hpc windows.Handle
+	coord := windows.Coord{X: 80, Y: 25}
+	err = createPseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create pseudo console: %w", err)
+	}
+
+	if err := ptyOut.Close(); err != nil {
+		return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
+	}
+	if err := ptyIn.Close(); err != nil {
+		return nil, fmt.Errorf("failed to close pseudo console handle: %w", err)
+	}
+
+	return &conPty{
+		handle:  hpc,
+		inPipe:  inPipeOurs,
+		outPipe: outPipeOurs,
+	}, nil
+}
+
+// Close implements Pty.
+func (p *conPty) Close() error {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	closePseudoConsole(p.handle)
+	return errors.Join(p.inPipe.Close(), p.outPipe.Close())
+}
+
+// Command implements Pty.
+func (p *conPty) Command(name string, args ...string) *Cmd {
+	return p.CommandContext(nil, name, args...)
+}
+
+// CommandContext implements Pty.
+func (p *conPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd {
+	return &Cmd{
+		ctx:  ctx,
+		pty:  p,
+		Path: name,
+		Args: append([]string{name}, args...),
+	}
+}
+
+// Control implements Pty.
+func (*conPty) Control(func(fd uintptr)) error {
+	return nil
+}
+
+// Name implements Pty.
+func (*conPty) Name() string {
+	return "windows-pty"
+}
+
+// Read implements Pty.
+func (p *conPty) Read(b []byte) (n int, err error) {
+	return p.outPipe.Read(b)
+}
+
+// Resize implements Pty.
+func (p *conPty) Resize(rows int, cols int) error {
+	if err := resizePseudoConsole(p.handle, windows.Coord{X: int16(cols), Y: int16(rows)}); err != nil {
+		return fmt.Errorf("failed to resize pseudo console: %w", err)
+	}
+	return nil
+}
+
+// Write implements Pty.
+func (p *conPty) Write(b []byte) (n int, err error) {
+	return p.inPipe.Write(b)
+}
+
+// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with
+// CreateProcess.
+func (p *conPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error {
+	p.mtx.RLock()
+	defer p.mtx.RUnlock()
+
+	if p.handle == 0 {
+		return errClosedConPty
+	}
+
+	if err := attrList.Update(
+		_PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
+		unsafe.Pointer(p.handle),
+		unsafe.Sizeof(p.handle),
+	); err != nil {
+		return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err)
+	}
+
+	return nil
+}
+
+// createPseudoConsole creates a windows pseudo console.
+func createPseudoConsole(size windows.Coord, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) error {
+	// We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
+	return _createPseudoConsole(*((*uint32)(unsafe.Pointer(&size))), hInput, hOutput, dwFlags, hpcon)
+}
+
+// resizePseudoConsole resizes the internal buffers of the pseudo console to the width and height specified in `size`.
+func resizePseudoConsole(hpcon windows.Handle, size windows.Coord) error {
+	// We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand.
+	return _resizePseudoConsole(hpcon, *((*uint32)(unsafe.Pointer(&size))))
+}
+
+//sys _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) = kernel32.CreatePseudoConsole
+//sys _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) = kernel32.ResizePseudoConsole
+//sys closePseudoConsole(hpc windows.Handle) = kernel32.ClosePseudoConsole
+
+type conPtySys struct {
+	attrs  *windows.ProcThreadAttributeListContainer
+	done   chan error
+	cmdErr error
+}
+
+func (c *Cmd) start() error {
+	pty, ok := c.pty.(*conPty)
+	if !ok {
+		return ErrInvalidCommand
+	}
+
+	argv0 := c.Path
+	if len(c.Dir) != 0 {
+		// Windows CreateProcess looks for argv0 relative to the current
+		// directory, and, only once the new process is started, it does
+		// Chdir(attr.Dir). We are adjusting for that difference here by
+		// making argv0 absolute.
+		var err error
+		argv0, err = joinExeDirAndFName(c.Dir, c.Path)
+		if err != nil {
+			return err
+		}
+	}
+
+	argv0p, err := windows.UTF16PtrFromString(argv0)
+	if err != nil {
+		return err
+	}
+
+	var cmdline string
+	if c.SysProcAttr.CmdLine != "" {
+		cmdline = c.SysProcAttr.CmdLine
+	} else {
+		cmdline = windows.ComposeCommandLine(c.Args)
+	}
+	argvp, err := windows.UTF16PtrFromString(cmdline)
+	if err != nil {
+		return err
+	}
+
+	var dirp *uint16
+	if len(c.Dir) != 0 {
+		dirp, err = windows.UTF16PtrFromString(c.Dir)
+		if err != nil {
+			return err
+		}
+	}
+
+	if c.Env == nil {
+		c.Env, err = execEnvDefault(c.SysProcAttr)
+		if err != nil {
+			return err
+		}
+	}
+
+	siEx := new(windows.StartupInfoEx)
+	siEx.Flags = windows.STARTF_USESTDHANDLES
+	pi := new(windows.ProcessInformation)
+
+	// Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field.
+	flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags
+
+	// Allocate an attribute list that's large enough to do the operations we care about
+	// 2. Pseudo console setup if one was requested.
+	// Therefore we need a list of size 3.
+	attrs, err := windows.NewProcThreadAttributeList(1)
+	if err != nil {
+		return fmt.Errorf("failed to initialize process thread attribute list: %w", err)
+	}
+
+	c.sys = &conPtySys{
+		attrs: attrs,
+		done:  make(chan error, 1),
+	}
+
+	if err := pty.updateProcThreadAttribute(attrs); err != nil {
+		return err
+	}
+
+	var zeroSec windows.SecurityAttributes
+	pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
+	if c.SysProcAttr.ProcessAttributes != nil {
+		pSec = &windows.SecurityAttributes{
+			Length:        c.SysProcAttr.ProcessAttributes.Length,
+			InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle,
+		}
+	}
+	tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
+	if c.SysProcAttr.ThreadAttributes != nil {
+		tSec = &windows.SecurityAttributes{
+			Length:        c.SysProcAttr.ThreadAttributes.Length,
+			InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle,
+		}
+	}
+
+	siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall
+	siEx.Cb = uint32(unsafe.Sizeof(*siEx))
+	if c.SysProcAttr.Token != 0 {
+		err = windows.CreateProcessAsUser(
+			windows.Token(c.SysProcAttr.Token),
+			argv0p,
+			argvp,
+			pSec,
+			tSec,
+			false,
+			flags,
+			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
+			dirp,
+			&siEx.StartupInfo,
+			pi,
+		)
+	} else {
+		err = windows.CreateProcess(
+			argv0p,
+			argvp,
+			pSec,
+			tSec,
+			false,
+			flags,
+			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))),
+			dirp,
+			&siEx.StartupInfo,
+			pi,
+		)
+	}
+	if err != nil {
+		return fmt.Errorf("failed to create process: %w", err)
+	}
+	// Don't need the thread handle for anything.
+	defer func() {
+		_ = windows.CloseHandle(pi.Thread)
+	}()
+
+	// Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a
+	// process has been launched.
+	c.Process, err = os.FindProcess(int(pi.ProcessId))
+	if err != nil {
+		// If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the
+		// object.
+		if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil {
+			return fmt.Errorf("failed to terminate process after process not found: %w", tErr)
+		}
+		return fmt.Errorf("failed to find process after starting: %w", err)
+	}
+
+	if c.ctx != nil {
+		if c.Cancel == nil {
+			c.Cancel = func() error {
+				if c.Process == nil {
+					return errNotStarted
+				}
+				return c.Process.Kill()
+			}
+		}
+		go c.waitOnContext()
+	}
+
+	return nil
+}
+
+func (c *Cmd) waitOnContext() {
+	sys := c.sys.(*conPtySys)
+	select {
+	case <-c.ctx.Done():
+		_ = c.Cancel()
+		sys.cmdErr = c.ctx.Err()
+	case err := <-sys.done:
+		sys.cmdErr = err
+	}
+}
+
+func (c *Cmd) wait() (retErr error) {
+	if c.Process == nil {
+		return errNotStarted
+	}
+	if c.ProcessState != nil {
+		return errors.New("process already waited on")
+	}
+	defer func() {
+		sys := c.sys.(*conPtySys)
+		sys.attrs.Delete()
+		sys.done <- nil
+		if retErr == nil {
+			retErr = sys.cmdErr
+		}
+	}()
+	c.ProcessState, retErr = c.Process.Wait()
+	if retErr != nil {
+		return retErr
+	}
+	return
+}
+
+//
+// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities
+// found in the go stdlib.
+//
+
+func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) {
+	if sys == nil || sys.Token == 0 {
+		return syscall.Environ(), nil
+	}
+
+	var block *uint16
+	err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false)
+	if err != nil {
+		return nil, err
+	}
+
+	defer windows.DestroyEnvironmentBlock(block)
+	blockp := uintptr(unsafe.Pointer(block))
+
+	for {
+		// find NUL terminator
+		end := unsafe.Pointer(blockp)
+		for *(*uint16)(end) != 0 {
+			end = unsafe.Pointer(uintptr(end) + 2)
+		}
+
+		n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2
+		if n == 0 {
+			// environment block ends with empty string
+			break
+		}
+
+		entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n]
+		env = append(env, string(utf16.Decode(entry)))
+		blockp += 2 * (uintptr(len(entry)) + 1)
+	}
+	return
+}
+
+func isSlash(c uint8) bool {
+	return c == '\\' || c == '/'
+}
+
+func normalizeDir(dir string) (name string, err error) {
+	ndir, err := syscall.FullPath(dir)
+	if err != nil {
+		return "", err
+	}
+	if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) {
+		// dir cannot have \\server\share\path form
+		return "", syscall.EINVAL
+	}
+	return ndir, nil
+}
+
+func volToUpper(ch int) int {
+	if 'a' <= ch && ch <= 'z' {
+		ch += 'A' - 'a'
+	}
+	return ch
+}
+
+func joinExeDirAndFName(dir, p string) (name string, err error) {
+	if len(p) == 0 {
+		return "", syscall.EINVAL
+	}
+	if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
+		// \\server\share\path form
+		return p, nil
+	}
+	if len(p) > 1 && p[1] == ':' {
+		// has drive letter
+		if len(p) == 2 {
+			return "", syscall.EINVAL
+		}
+		if isSlash(p[2]) {
+			return p, nil
+		} else {
+			d, err := normalizeDir(dir)
+			if err != nil {
+				return "", err
+			}
+			if volToUpper(int(p[0])) == volToUpper(int(d[0])) {
+				return syscall.FullPath(d + "\\" + p[2:])
+			} else {
+				return syscall.FullPath(p)
+			}
+		}
+	} else {
+		// no drive letter
+		d, err := normalizeDir(dir)
+		if err != nil {
+			return "", err
+		}
+		if isSlash(p[0]) {
+			return windows.FullPath(d[:2] + p)
+		} else {
+			return windows.FullPath(d + "\\" + p)
+		}
+	}
+}
+
+// createEnvBlock converts an array of environment strings into
+// the representation required by CreateProcess: a sequence of NUL
+// terminated strings followed by a nil.
+// Last bytes are two UCS-2 NULs, or four NUL bytes.
+func createEnvBlock(envv []string) *uint16 {
+	if len(envv) == 0 {
+		return &utf16.Encode([]rune("\x00\x00"))[0]
+	}
+	length := 0
+	for _, s := range envv {
+		length += len(s) + 1
+	}
+	length++
+
+	b := make([]byte, length)
+	i := 0
+	for _, s := range envv {
+		l := len(s)
+		copy(b[i:i+l], []byte(s))
+		copy(b[i+l:i+l+1], []byte{0})
+		i = i + l + 1
+	}
+	copy(b[i:i+1], []byte{0})
+
+	return &utf16.Encode([]rune(string(b)))[0]
+}
+
+// dedupEnvCase is dedupEnv with a case option for testing.
+// If caseInsensitive is true, the case of keys is ignored.
+func dedupEnvCase(caseInsensitive bool, env []string) []string {
+	out := make([]string, 0, len(env))
+	saw := make(map[string]int, len(env)) // key => index into out
+	for _, kv := range env {
+		eq := strings.Index(kv, "=")
+		if eq < 0 {
+			out = append(out, kv)
+			continue
+		}
+		k := kv[:eq]
+		if caseInsensitive {
+			k = strings.ToLower(k)
+		}
+		if dupIdx, isDup := saw[k]; isDup {
+			out[dupIdx] = kv
+			continue
+		}
+		saw[k] = len(out)
+		out = append(out, kv)
+	}
+	return out
+}
+
+// addCriticalEnv adds any critical environment variables that are required
+// (or at least almost always required) on the operating system.
+// Currently this is only used for Windows.
+func addCriticalEnv(env []string) []string {
+	for _, kv := range env {
+		eq := strings.Index(kv, "=")
+		if eq < 0 {
+			continue
+		}
+		k := kv[:eq]
+		if strings.EqualFold(k, "SYSTEMROOT") {
+			// We already have it.
+			return env
+		}
+	}
+	return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
+}

server/ssh/pty/zsyscall_windows.go 🔗

@@ -0,0 +1,75 @@
+//go:build windows
+
+// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
+
+package pty
+
+import (
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+	errnoERROR_IO_PENDING = 997
+)
+
+var (
+	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+	errERROR_EINVAL     error = syscall.EINVAL
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+	switch e {
+	case 0:
+		return errERROR_EINVAL
+	case errnoERROR_IO_PENDING:
+		return errERROR_IO_PENDING
+	}
+	// TODO: add more here, after collecting data on the common
+	// error values see on Windows. (perhaps when running
+	// all.bat?)
+	return e
+}
+
+var (
+	modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+
+	procClosePseudoConsole  = modkernel32.NewProc("ClosePseudoConsole")
+	procCreatePseudoConsole = modkernel32.NewProc("CreatePseudoConsole")
+	procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole")
+)
+
+func closePseudoConsole(hpc windows.Handle) {
+	syscall.Syscall(procClosePseudoConsole.Addr(), 1, uintptr(hpc), 0, 0)
+	return
+}
+
+func _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) {
+	r0, _, _ := syscall.Syscall6(procCreatePseudoConsole.Addr(), 5, uintptr(size), uintptr(hInput), uintptr(hOutput), uintptr(dwFlags), uintptr(unsafe.Pointer(hpcon)), 0)
+	if int32(r0) < 0 {
+		if r0&0x1fff0000 == 0x00070000 {
+			r0 &= 0xffff
+		}
+		hr = syscall.Errno(r0)
+	}
+	return
+}
+
+func _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) {
+	r0, _, _ := syscall.Syscall(procResizePseudoConsole.Addr(), 2, uintptr(hPc), uintptr(size), 0)
+	if int32(r0) < 0 {
+		if r0&0x1fff0000 == 0x00070000 {
+			r0 &= 0xffff
+		}
+		hr = syscall.Errno(r0)
+	}
+	return
+}