From 3d9565e3c33374d6912f78caca56b15bcd966e62 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 21 Sep 2023 17:18:47 -0400 Subject: [PATCH] feat(ssh): add pty windows implementation --- server/ssh/pty/pty.go | 3 +- server/ssh/pty/pty_windows.go | 527 +++++++++++++++++++++++++++++ server/ssh/pty/zsyscall_windows.go | 75 ++++ 3 files changed, 604 insertions(+), 1 deletion(-) create mode 100644 server/ssh/pty/pty_windows.go create mode 100644 server/ssh/pty/zsyscall_windows.go diff --git a/server/ssh/pty/pty.go b/server/ssh/pty/pty.go index cb5195d8675bda3d7fb72d11b8dcb53748e3f89d..5567e8e133f0e09ada66cd8c170cad05847f781c 100644 --- a/server/ssh/pty/pty.go +++ b/server/ssh/pty/pty.go @@ -40,6 +40,7 @@ type Pty interface { Resize(rows, cols int) error // Control access to the underlying file descriptor in a blocking manner. + // Not implemented on Windows. Control(f func(fd uintptr)) error } @@ -80,7 +81,7 @@ type Cmd struct { ProcessState *os.ProcessState // Cancel is called when the command is canceled. - Cancel func() + Cancel func() error } // Start starts the specified command attached to the pseudo-terminal. diff --git a/server/ssh/pty/pty_windows.go b/server/ssh/pty/pty_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..adcf918f0c4855a6019c2bed92d36d05a934b827 --- /dev/null +++ b/server/ssh/pty/pty_windows.go @@ -0,0 +1,527 @@ +//go:build windows +// +build windows + +package pty + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "sync" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive +) + +var ( + errClosedConPty = errors.New("pseudo console is closed") + errNotStarted = errors.New("process not started") +) + +// Install this from github.com/Microsoft/go-winio +// go install github.com/Microsoft/go-winio/tools/mkwinsyscall@latest +//go:generate mkwinsyscall -output zsyscall_windows.go ./*.go + +// https://github.com/microsoft/hcsshim/blob/main/internal/conpty/conpty.go +type conPty struct { + handle windows.Handle + inPipe, outPipe *os.File + mtx sync.RWMutex +} + +var _ Pty = &conPty{} + +func newPty() (Pty, error) { + ptyIn, inPipeOurs, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) + } + + outPipeOurs, ptyOut, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) + } + + var hpc windows.Handle + coord := windows.Coord{X: 80, Y: 25} + err = createPseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc) + if err != nil { + return nil, fmt.Errorf("failed to create pseudo console: %w", err) + } + + if err := ptyOut.Close(); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) + } + if err := ptyIn.Close(); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) + } + + return &conPty{ + handle: hpc, + inPipe: inPipeOurs, + outPipe: outPipeOurs, + }, nil +} + +// Close implements Pty. +func (p *conPty) Close() error { + p.mtx.Lock() + defer p.mtx.Unlock() + + closePseudoConsole(p.handle) + return errors.Join(p.inPipe.Close(), p.outPipe.Close()) +} + +// Command implements Pty. +func (p *conPty) Command(name string, args ...string) *Cmd { + return p.CommandContext(nil, name, args...) +} + +// CommandContext implements Pty. +func (p *conPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd { + return &Cmd{ + ctx: ctx, + pty: p, + Path: name, + Args: append([]string{name}, args...), + } +} + +// Control implements Pty. +func (*conPty) Control(func(fd uintptr)) error { + return nil +} + +// Name implements Pty. +func (*conPty) Name() string { + return "windows-pty" +} + +// Read implements Pty. +func (p *conPty) Read(b []byte) (n int, err error) { + return p.outPipe.Read(b) +} + +// Resize implements Pty. +func (p *conPty) Resize(rows int, cols int) error { + if err := resizePseudoConsole(p.handle, windows.Coord{X: int16(cols), Y: int16(rows)}); err != nil { + return fmt.Errorf("failed to resize pseudo console: %w", err) + } + return nil +} + +// Write implements Pty. +func (p *conPty) Write(b []byte) (n int, err error) { + return p.inPipe.Write(b) +} + +// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with +// CreateProcess. +func (p *conPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error { + p.mtx.RLock() + defer p.mtx.RUnlock() + + if p.handle == 0 { + return errClosedConPty + } + + if err := attrList.Update( + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, + unsafe.Pointer(p.handle), + unsafe.Sizeof(p.handle), + ); err != nil { + return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err) + } + + return nil +} + +// createPseudoConsole creates a windows pseudo console. +func createPseudoConsole(size windows.Coord, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) error { + // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand. + return _createPseudoConsole(*((*uint32)(unsafe.Pointer(&size))), hInput, hOutput, dwFlags, hpcon) +} + +// resizePseudoConsole resizes the internal buffers of the pseudo console to the width and height specified in `size`. +func resizePseudoConsole(hpcon windows.Handle, size windows.Coord) error { + // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand. + return _resizePseudoConsole(hpcon, *((*uint32)(unsafe.Pointer(&size)))) +} + +//sys _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) = kernel32.CreatePseudoConsole +//sys _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) = kernel32.ResizePseudoConsole +//sys closePseudoConsole(hpc windows.Handle) = kernel32.ClosePseudoConsole + +type conPtySys struct { + attrs *windows.ProcThreadAttributeListContainer + done chan error + cmdErr error +} + +func (c *Cmd) start() error { + pty, ok := c.pty.(*conPty) + if !ok { + return ErrInvalidCommand + } + + argv0 := c.Path + if len(c.Dir) != 0 { + // Windows CreateProcess looks for argv0 relative to the current + // directory, and, only once the new process is started, it does + // Chdir(attr.Dir). We are adjusting for that difference here by + // making argv0 absolute. + var err error + argv0, err = joinExeDirAndFName(c.Dir, c.Path) + if err != nil { + return err + } + } + + argv0p, err := windows.UTF16PtrFromString(argv0) + if err != nil { + return err + } + + var cmdline string + if c.SysProcAttr.CmdLine != "" { + cmdline = c.SysProcAttr.CmdLine + } else { + cmdline = windows.ComposeCommandLine(c.Args) + } + argvp, err := windows.UTF16PtrFromString(cmdline) + if err != nil { + return err + } + + var dirp *uint16 + if len(c.Dir) != 0 { + dirp, err = windows.UTF16PtrFromString(c.Dir) + if err != nil { + return err + } + } + + if c.Env == nil { + c.Env, err = execEnvDefault(c.SysProcAttr) + if err != nil { + return err + } + } + + siEx := new(windows.StartupInfoEx) + siEx.Flags = windows.STARTF_USESTDHANDLES + pi := new(windows.ProcessInformation) + + // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. + flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags + + // Allocate an attribute list that's large enough to do the operations we care about + // 2. Pseudo console setup if one was requested. + // Therefore we need a list of size 3. + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return fmt.Errorf("failed to initialize process thread attribute list: %w", err) + } + + c.sys = &conPtySys{ + attrs: attrs, + done: make(chan error, 1), + } + + if err := pty.updateProcThreadAttribute(attrs); err != nil { + return err + } + + var zeroSec windows.SecurityAttributes + pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ProcessAttributes != nil { + pSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ProcessAttributes.Length, + InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle, + } + } + tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ThreadAttributes != nil { + tSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ThreadAttributes.Length, + InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle, + } + } + + siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall + siEx.Cb = uint32(unsafe.Sizeof(*siEx)) + if c.SysProcAttr.Token != 0 { + err = windows.CreateProcessAsUser( + windows.Token(c.SysProcAttr.Token), + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } else { + err = windows.CreateProcess( + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } + if err != nil { + return fmt.Errorf("failed to create process: %w", err) + } + // Don't need the thread handle for anything. + defer func() { + _ = windows.CloseHandle(pi.Thread) + }() + + // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a + // process has been launched. + c.Process, err = os.FindProcess(int(pi.ProcessId)) + if err != nil { + // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the + // object. + if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil { + return fmt.Errorf("failed to terminate process after process not found: %w", tErr) + } + return fmt.Errorf("failed to find process after starting: %w", err) + } + + if c.ctx != nil { + if c.Cancel == nil { + c.Cancel = func() error { + if c.Process == nil { + return errNotStarted + } + return c.Process.Kill() + } + } + go c.waitOnContext() + } + + return nil +} + +func (c *Cmd) waitOnContext() { + sys := c.sys.(*conPtySys) + select { + case <-c.ctx.Done(): + _ = c.Cancel() + sys.cmdErr = c.ctx.Err() + case err := <-sys.done: + sys.cmdErr = err + } +} + +func (c *Cmd) wait() (retErr error) { + if c.Process == nil { + return errNotStarted + } + if c.ProcessState != nil { + return errors.New("process already waited on") + } + defer func() { + sys := c.sys.(*conPtySys) + sys.attrs.Delete() + sys.done <- nil + if retErr == nil { + retErr = sys.cmdErr + } + }() + c.ProcessState, retErr = c.Process.Wait() + if retErr != nil { + return retErr + } + return +} + +// +// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities +// found in the go stdlib. +// + +func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { + if sys == nil || sys.Token == 0 { + return syscall.Environ(), nil + } + + var block *uint16 + err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) + if err != nil { + return nil, err + } + + defer windows.DestroyEnvironmentBlock(block) + blockp := uintptr(unsafe.Pointer(block)) + + for { + // find NUL terminator + end := unsafe.Pointer(blockp) + for *(*uint16)(end) != 0 { + end = unsafe.Pointer(uintptr(end) + 2) + } + + n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 + if n == 0 { + // environment block ends with empty string + break + } + + entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] + env = append(env, string(utf16.Decode(entry))) + blockp += 2 * (uintptr(len(entry)) + 1) + } + return +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func normalizeDir(dir string) (name string, err error) { + ndir, err := syscall.FullPath(dir) + if err != nil { + return "", err + } + if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { + // dir cannot have \\server\share\path form + return "", syscall.EINVAL + } + return ndir, nil +} + +func volToUpper(ch int) int { + if 'a' <= ch && ch <= 'z' { + ch += 'A' - 'a' + } + return ch +} + +func joinExeDirAndFName(dir, p string) (name string, err error) { + if len(p) == 0 { + return "", syscall.EINVAL + } + if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { + // \\server\share\path form + return p, nil + } + if len(p) > 1 && p[1] == ':' { + // has drive letter + if len(p) == 2 { + return "", syscall.EINVAL + } + if isSlash(p[2]) { + return p, nil + } else { + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if volToUpper(int(p[0])) == volToUpper(int(d[0])) { + return syscall.FullPath(d + "\\" + p[2:]) + } else { + return syscall.FullPath(p) + } + } + } else { + // no drive letter + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if isSlash(p[0]) { + return windows.FullPath(d[:2] + p) + } else { + return windows.FullPath(d + "\\" + p) + } + } +} + +// createEnvBlock converts an array of environment strings into +// the representation required by CreateProcess: a sequence of NUL +// terminated strings followed by a nil. +// Last bytes are two UCS-2 NULs, or four NUL bytes. +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length++ + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/server/ssh/pty/zsyscall_windows.go b/server/ssh/pty/zsyscall_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..d2a766a1c3a353543c7c676c6a4a5a6d8eb23c6c --- /dev/null +++ b/server/ssh/pty/zsyscall_windows.go @@ -0,0 +1,75 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package pty + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole") + procCreatePseudoConsole = modkernel32.NewProc("CreatePseudoConsole") + procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") +) + +func closePseudoConsole(hpc windows.Handle) { + syscall.Syscall(procClosePseudoConsole.Addr(), 1, uintptr(hpc), 0, 0) + return +} + +func _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) { + r0, _, _ := syscall.Syscall6(procCreatePseudoConsole.Addr(), 5, uintptr(size), uintptr(hInput), uintptr(hOutput), uintptr(dwFlags), uintptr(unsafe.Pointer(hpcon)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) { + r0, _, _ := syscall.Syscall(procResizePseudoConsole.Addr(), 2, uintptr(hPc), uintptr(size), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +}