diff --git a/go.mod b/go.mod index f1efa6f666e49a4ceb9dba0d4743afae634feb9d..3d57c6b24dd05eed8c4c30981de84cdd28ec4e62 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/charmbracelet/keygen v0.4.3 github.com/charmbracelet/log v0.2.4 github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09 + github.com/creack/pty v1.1.18 github.com/go-jose/go-jose/v3 v3.0.0 github.com/gobwas/glob v0.2.3 github.com/gogs/git-module v1.8.3 diff --git a/go.sum b/go.sum index b5b8ac15bfe533ce76bf9189c44a6a55abfdb01b..1b0302d2031d2144b36523f344c3ec840e1e3f80 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVL github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/server/ssh/pty/pty.go b/server/ssh/pty/pty.go new file mode 100644 index 0000000000000000000000000000000000000000..cb5195d8675bda3d7fb72d11b8dcb53748e3f89d --- /dev/null +++ b/server/ssh/pty/pty.go @@ -0,0 +1,102 @@ +package pty + +import ( + "context" + "errors" + "io" + "os" + "syscall" +) + +var ( + // ErrInvalidCommand is returned when the command is invalid. + ErrInvalidCommand = errors.New("pty: invalid command") +) + +// New returns a new pseudo-terminal. +func New() (Pty, error) { + return newPty() +} + +// Pty is a pseudo-terminal interface. +type Pty interface { + io.ReadWriteCloser + + // Name returns the name of the pseudo-terminal. + // On Windows, this will always be "windows-pty". + // On Unix, this will return the name of the slave end of the + // pseudo-terminal TTY. + Name() string + + // Command returns a command that can be used to start a process + // attached to the pseudo-terminal. + Command(name string, args ...string) *Cmd + + // CommandContext returns a command that can be used to start a process + // attached to the pseudo-terminal. + CommandContext(ctx context.Context, name string, args ...string) *Cmd + + // Resize resizes the pseudo-terminal. + Resize(rows, cols int) error + + // Control access to the underlying file descriptor in a blocking manner. + Control(f func(fd uintptr)) error +} + +// Cmd is a command that can be started attached to a pseudo-terminal. +// This is similar to the API of exec.Cmd. The main difference is that +// the command is started attached to a pseudo-terminal. +// This is required as we cannot use exec.Cmd directly on Windows due to +// limitation of starting a process attached to a pseudo-terminal. +// See: https://github.com/golang/go/issues/62708 +type Cmd struct { + ctx context.Context + pty Pty + sys interface{} + + // Path is the path of the command to run. + Path string + + // Args holds command line arguments, including the command as Args[0]. + Args []string + + // Env specifies the environment of the process. + // If Env is nil, the new process uses the current process's environment. + Env []string + + // Dir specifies the working directory of the command. + // If Dir is the empty string, the current directory is used. + Dir string + + // SysProcAttr holds optional, operating system-specific attributes. + SysProcAttr *syscall.SysProcAttr + + // Process is the underlying process, once started. + Process *os.Process + + // ProcessState contains information about an exited process. + // If the process was started successfully, Wait or Run will populate this + // field when the command completes. + ProcessState *os.ProcessState + + // Cancel is called when the command is canceled. + Cancel func() +} + +// Start starts the specified command attached to the pseudo-terminal. +func (c *Cmd) Start() error { + return c.start() +} + +// Wait waits for the command to exit. +func (c *Cmd) Wait() error { + return c.wait() +} + +// Run runs the command and waits for it to complete. +func (c *Cmd) Run() error { + if err := c.Start(); err != nil { + return err + } + return c.Wait() +} diff --git a/server/ssh/pty/pty_other.go b/server/ssh/pty/pty_other.go new file mode 100644 index 0000000000000000000000000000000000000000..3a9e23d50c8430eaee6a6e1816f300adbe023a53 --- /dev/null +++ b/server/ssh/pty/pty_other.go @@ -0,0 +1,138 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "context" + "errors" + "os" + "os/exec" + + "github.com/creack/pty" + "golang.org/x/sys/unix" +) + +type unixPty struct { + master, slave *os.File + closed bool +} + +var _ Pty = &unixPty{} + +// Close implements Pty. +func (p *unixPty) Close() error { + if p.closed { + return nil + } + defer func() { + p.closed = true + }() + return errors.Join(p.master.Close(), p.slave.Close()) +} + +// Command implements Pty. +func (p *unixPty) Command(name string, args ...string) *Cmd { + return p.CommandContext(nil, name, args...) // nolint:staticcheck +} + +// CommandContext implements Pty. +func (p *unixPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd { + cmd := exec.Command(name, args...) + if ctx != nil { + cmd = exec.CommandContext(ctx, name, args...) + } + c := &Cmd{ + ctx: ctx, + pty: p, + sys: cmd, + Path: name, + Args: append([]string{name}, args...), + } + return c +} + +// Name implements Pty. +func (p *unixPty) Name() string { + return p.slave.Name() +} + +// Read implements Pty. +func (p *unixPty) Read(b []byte) (n int, err error) { + return p.master.Read(b) +} + +func (p *unixPty) Control(f func(fd uintptr)) error { + conn, err := p.master.SyscallConn() + if err != nil { + return err + } + return conn.Control(f) +} + +// Resize implements Pty. +func (p *unixPty) Resize(rows int, cols int) error { + var ctrlErr error + if err := p.Control(func(fd uintptr) { + ctrlErr = unix.IoctlSetWinsize(int(fd), unix.TIOCSWINSZ, &unix.Winsize{ + Row: uint16(rows), + Col: uint16(cols), + }) + }); err != nil { + return err + } + + return ctrlErr +} + +// Write implements Pty. +func (p *unixPty) Write(b []byte) (n int, err error) { + return p.master.Write(b) +} + +func newPty() (Pty, error) { + master, slave, err := pty.Open() + if err != nil { + return nil, err + } + + return &unixPty{ + master: master, + slave: slave, + }, nil +} + +func (c *Cmd) start() error { + cmd, ok := c.sys.(*exec.Cmd) + if !ok { + return ErrInvalidCommand + } + pty, ok := c.pty.(*unixPty) + if !ok { + return ErrInvalidCommand + } + + cmd.Stdin = pty.slave + cmd.Stdout = pty.slave + cmd.Stderr = pty.slave + cmd.SysProcAttr = &unix.SysProcAttr{ + Setsid: true, + Setctty: true, + } + if err := cmd.Start(); err != nil { + return err + } + + c.Process = cmd.Process + return nil +} + +func (c *Cmd) wait() error { + cmd, ok := c.sys.(*exec.Cmd) + if !ok { + return ErrInvalidCommand + } + err := cmd.Wait() + c.ProcessState = cmd.ProcessState + return err +}