1package ssh
2
3import (
4 "fmt"
5 "os"
6 "os/exec"
7 "strings"
8
9 "github.com/charmbracelet/log"
10 "github.com/charmbracelet/ssh"
11 "github.com/charmbracelet/wish"
12)
13
14// ShellMiddleware is a middleware for the SSH shell.
15func ShellMiddleware(sh ssh.Handler) ssh.Handler {
16 softBin, err := os.Executable()
17 if err != nil {
18 // TODO: handle this better
19 panic(err)
20 }
21
22 return func(s ssh.Session) {
23 ctx := s.Context()
24 logger := log.FromContext(ctx).WithPrefix("ssh")
25
26 args := s.Command()
27 ppty, winch, isInteractive := s.Pty()
28
29 envs := s.Environ()
30 if len(args) > 0 {
31 envs = append(envs, "SSH_ORIGINAL_COMMAND="+strings.Join(args, " "))
32 }
33
34 var cmd interface {
35 Run() error
36 }
37 cmdArgs := []string{"shell", "-c", fmt.Sprintf("'%s'", strings.Join(args, " "))}
38 if isInteractive && ppty.Pty != nil {
39 ppty.Pty.Resize(ppty.Window.Width, ppty.Window.Height)
40 go func() {
41 for win := range winch {
42 log.Printf("resizing to %d x %d", win.Width, win.Height)
43 ppty.Pty.Resize(win.Width, win.Height)
44 }
45 }()
46
47 c := ppty.Pty.CommandContext(ctx, softBin, cmdArgs...)
48 c.Env = append(envs, "PATH="+os.Getenv("PATH"))
49 cmd = c
50 } else {
51 c := exec.CommandContext(ctx, softBin, cmdArgs...)
52 c.Env = append(envs, "PATH="+os.Getenv("PATH"))
53 cmd = c
54 }
55
56 if err := cmd.Run(); err != nil {
57 logger.Errorf("error running command: %s", err)
58 wish.Fatal(s, "internal server error")
59 return
60 }
61
62 sh(s)
63 }
64}