feat(ssh): add pty to handle pty allocation

Ayman Bagabas created

Change summary

go.mod                      |   1 
go.sum                      |   2 
server/ssh/pty/pty.go       | 102 ++++++++++++++++++++++++++++
server/ssh/pty/pty_other.go | 138 +++++++++++++++++++++++++++++++++++++++
4 files changed, 243 insertions(+)

Detailed changes

go.mod 🔗

@@ -25,6 +25,7 @@ require (
 	github.com/charmbracelet/keygen v0.4.3
 	github.com/charmbracelet/log v0.2.4
 	github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09
+	github.com/creack/pty v1.1.18
 	github.com/go-jose/go-jose/v3 v3.0.0
 	github.com/gobwas/glob v0.2.3
 	github.com/gogs/git-module v1.8.3

go.sum 🔗

@@ -42,6 +42,8 @@ github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVL
 github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
 github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
+github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

server/ssh/pty/pty.go 🔗

@@ -0,0 +1,102 @@
+package pty
+
+import (
+	"context"
+	"errors"
+	"io"
+	"os"
+	"syscall"
+)
+
+var (
+	// ErrInvalidCommand is returned when the command is invalid.
+	ErrInvalidCommand = errors.New("pty: invalid command")
+)
+
+// New returns a new pseudo-terminal.
+func New() (Pty, error) {
+	return newPty()
+}
+
+// Pty is a pseudo-terminal interface.
+type Pty interface {
+	io.ReadWriteCloser
+
+	// Name returns the name of the pseudo-terminal.
+	// On Windows, this will always be "windows-pty".
+	// On Unix, this will return the name of the slave end of the
+	// pseudo-terminal TTY.
+	Name() string
+
+	// Command returns a command that can be used to start a process
+	// attached to the pseudo-terminal.
+	Command(name string, args ...string) *Cmd
+
+	// CommandContext returns a command that can be used to start a process
+	// attached to the pseudo-terminal.
+	CommandContext(ctx context.Context, name string, args ...string) *Cmd
+
+	// Resize resizes the pseudo-terminal.
+	Resize(rows, cols int) error
+
+	// Control access to the underlying file descriptor in a blocking manner.
+	Control(f func(fd uintptr)) error
+}
+
+// Cmd is a command that can be started attached to a pseudo-terminal.
+// This is similar to the API of exec.Cmd. The main difference is that
+// the command is started attached to a pseudo-terminal.
+// This is required as we cannot use exec.Cmd directly on Windows due to
+// limitation of starting a process attached to a pseudo-terminal.
+// See: https://github.com/golang/go/issues/62708
+type Cmd struct {
+	ctx context.Context
+	pty Pty
+	sys interface{}
+
+	// Path is the path of the command to run.
+	Path string
+
+	// Args holds command line arguments, including the command as Args[0].
+	Args []string
+
+	// Env specifies the environment of the process.
+	// If Env is nil, the new process uses the current process's environment.
+	Env []string
+
+	// Dir specifies the working directory of the command.
+	// If Dir is the empty string, the current directory is used.
+	Dir string
+
+	// SysProcAttr holds optional, operating system-specific attributes.
+	SysProcAttr *syscall.SysProcAttr
+
+	// Process is the underlying process, once started.
+	Process *os.Process
+
+	// ProcessState contains information about an exited process.
+	// If the process was started successfully, Wait or Run will populate this
+	// field when the command completes.
+	ProcessState *os.ProcessState
+
+	// Cancel is called when the command is canceled.
+	Cancel func()
+}
+
+// Start starts the specified command attached to the pseudo-terminal.
+func (c *Cmd) Start() error {
+	return c.start()
+}
+
+// Wait waits for the command to exit.
+func (c *Cmd) Wait() error {
+	return c.wait()
+}
+
+// Run runs the command and waits for it to complete.
+func (c *Cmd) Run() error {
+	if err := c.Start(); err != nil {
+		return err
+	}
+	return c.Wait()
+}

server/ssh/pty/pty_other.go 🔗

@@ -0,0 +1,138 @@
+//go:build !windows
+// +build !windows
+
+package pty
+
+import (
+	"context"
+	"errors"
+	"os"
+	"os/exec"
+
+	"github.com/creack/pty"
+	"golang.org/x/sys/unix"
+)
+
+type unixPty struct {
+	master, slave *os.File
+	closed        bool
+}
+
+var _ Pty = &unixPty{}
+
+// Close implements Pty.
+func (p *unixPty) Close() error {
+	if p.closed {
+		return nil
+	}
+	defer func() {
+		p.closed = true
+	}()
+	return errors.Join(p.master.Close(), p.slave.Close())
+}
+
+// Command implements Pty.
+func (p *unixPty) Command(name string, args ...string) *Cmd {
+	return p.CommandContext(nil, name, args...) // nolint:staticcheck
+}
+
+// CommandContext implements Pty.
+func (p *unixPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd {
+	cmd := exec.Command(name, args...)
+	if ctx != nil {
+		cmd = exec.CommandContext(ctx, name, args...)
+	}
+	c := &Cmd{
+		ctx:  ctx,
+		pty:  p,
+		sys:  cmd,
+		Path: name,
+		Args: append([]string{name}, args...),
+	}
+	return c
+}
+
+// Name implements Pty.
+func (p *unixPty) Name() string {
+	return p.slave.Name()
+}
+
+// Read implements Pty.
+func (p *unixPty) Read(b []byte) (n int, err error) {
+	return p.master.Read(b)
+}
+
+func (p *unixPty) Control(f func(fd uintptr)) error {
+	conn, err := p.master.SyscallConn()
+	if err != nil {
+		return err
+	}
+	return conn.Control(f)
+}
+
+// Resize implements Pty.
+func (p *unixPty) Resize(rows int, cols int) error {
+	var ctrlErr error
+	if err := p.Control(func(fd uintptr) {
+		ctrlErr = unix.IoctlSetWinsize(int(fd), unix.TIOCSWINSZ, &unix.Winsize{
+			Row: uint16(rows),
+			Col: uint16(cols),
+		})
+	}); err != nil {
+		return err
+	}
+
+	return ctrlErr
+}
+
+// Write implements Pty.
+func (p *unixPty) Write(b []byte) (n int, err error) {
+	return p.master.Write(b)
+}
+
+func newPty() (Pty, error) {
+	master, slave, err := pty.Open()
+	if err != nil {
+		return nil, err
+	}
+
+	return &unixPty{
+		master: master,
+		slave:  slave,
+	}, nil
+}
+
+func (c *Cmd) start() error {
+	cmd, ok := c.sys.(*exec.Cmd)
+	if !ok {
+		return ErrInvalidCommand
+	}
+	pty, ok := c.pty.(*unixPty)
+	if !ok {
+		return ErrInvalidCommand
+	}
+
+	cmd.Stdin = pty.slave
+	cmd.Stdout = pty.slave
+	cmd.Stderr = pty.slave
+	cmd.SysProcAttr = &unix.SysProcAttr{
+		Setsid:  true,
+		Setctty: true,
+	}
+	if err := cmd.Start(); err != nil {
+		return err
+	}
+
+	c.Process = cmd.Process
+	return nil
+}
+
+func (c *Cmd) wait() error {
+	cmd, ok := c.sys.(*exec.Cmd)
+	if !ok {
+		return ErrInvalidCommand
+	}
+	err := cmd.Wait()
+	c.ProcessState = cmd.ProcessState
+	return err
+}