shell.go

 1package ssh
 2
 3import (
 4	"fmt"
 5	"os"
 6	"os/exec"
 7	"strings"
 8
 9	"github.com/charmbracelet/log"
10	"github.com/charmbracelet/ssh"
11	"github.com/charmbracelet/wish"
12)
13
14// ShellMiddleware is a middleware for the SSH shell.
15func ShellMiddleware(sh ssh.Handler) ssh.Handler {
16	softBin, err := os.Executable()
17	if err != nil {
18		// TODO: handle this better
19		panic(err)
20	}
21
22	return func(s ssh.Session) {
23		ctx := s.Context()
24		logger := log.FromContext(ctx).WithPrefix("ssh")
25
26		args := s.Command()
27		ppty, winch, isInteractive := s.Pty()
28
29		envs := s.Environ()
30		if len(args) > 0 {
31			envs = append(envs, "SSH_ORIGINAL_COMMAND="+strings.Join(args, " "))
32		}
33
34		var cmd interface {
35			Run() error
36		}
37		cmdArgs := []string{"shell", "-c", fmt.Sprintf("'%s'", strings.Join(args, " "))}
38		if isInteractive && ppty.Pty != nil {
39			ppty.Pty.Resize(ppty.Window.Width, ppty.Window.Height)
40			go func() {
41				for win := range winch {
42					log.Printf("resizing to %d x %d", win.Width, win.Height)
43					ppty.Pty.Resize(win.Width, win.Height)
44				}
45			}()
46
47			c := ppty.Pty.CommandContext(ctx, softBin, cmdArgs...)
48			c.Env = append(envs, "PATH="+os.Getenv("PATH"))
49			cmd = c
50		} else {
51			c := exec.CommandContext(ctx, softBin, cmdArgs...)
52			c.Env = append(envs, "PATH="+os.Getenv("PATH"))
53			cmd = c
54		}
55
56		if err := cmd.Run(); err != nil {
57			logger.Errorf("error running command: %s", err)
58			wish.Fatal(s, "internal server error")
59			return
60		}
61
62		sh(s)
63	}
64}