Use middleware for session handling

Toby Padilla created

Change summary

main.go      |  2 
server.go    | 65 ++++++++++++++++++++++++++++++-----------------------
tui/model.go |  6 ++--
3 files changed, 41 insertions(+), 32 deletions(-)

Detailed changes

main.go 🔗

@@ -17,7 +17,7 @@ func main() {
 	if err != nil {
 		panic(err)
 	}
-	s, err := NewServer(cfg.Port, cfg.KeyPath, tui.SessionHandler)
+	s, err := NewServer(cfg.Port, cfg.KeyPath, LoggingMiddleware(), BubbleTeaMiddleware(tui.SessionHandler))
 	if err != nil {
 		panic(err)
 	}

server.go 🔗

@@ -12,22 +12,44 @@ import (
 	gossh "golang.org/x/crypto/ssh"
 )
 
-type SessionHandler func(ssh.Session) (tea.Model, error)
+type Middleware func(ssh.Handler) ssh.Handler
 
-type Server struct {
-	server  *ssh.Server
-	key     gossh.PublicKey
-	handler SessionHandler
+func LoggingMiddleware() Middleware {
+	return func(sh ssh.Handler) ssh.Handler {
+		return func(s ssh.Session) {
+			hpk := s.PublicKey() != nil
+			log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
+			sh(s)
+			log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
+		}
+	}
 }
 
-func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error) {
-	s := &Server{
-		server:  &ssh.Server{},
-		handler: handler,
+func BubbleTeaMiddleware(bth func(ssh.Session) tea.Model) Middleware {
+	return func(sh ssh.Handler) ssh.Handler {
+		return func(s ssh.Session) {
+			m := bth(s)
+			if m != nil {
+				p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
+				err := p.Start()
+				if err != nil {
+					log.Printf("%s error %v: %s\n", s.RemoteAddr().String(), s.Command(), err)
+				}
+			}
+			sh(s)
+		}
 	}
+}
+
+type Server struct {
+	server *ssh.Server
+	key    gossh.PublicKey
+}
+
+func NewServer(port int, keyPath string, mw ...Middleware) (*Server, error) {
+	s := &Server{server: &ssh.Server{}}
 	s.server.Version = "OpenSSH_7.6p1"
 	s.server.Addr = fmt.Sprintf(":%d", port)
-	s.server.Handler = s.sessionHandler
 	s.server.PasswordHandler = s.passHandler
 	s.server.PublicKeyHandler = s.authHandler
 	kps := strings.Split(keyPath, string(filepath.Separator))
@@ -42,28 +64,15 @@ func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error
 	if err != nil {
 		return nil, err
 	}
+	h := func(s ssh.Session) {}
+	for _, m := range mw {
+		h = m(h)
+	}
+	s.server.Handler = h
 	return s, nil
 }
 
 func (srv *Server) sessionHandler(s ssh.Session) {
-	hpk := s.PublicKey() != nil
-	log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
-	m, err := srv.handler(s)
-	if err != nil {
-		log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
-		s.Exit(1)
-		return
-	}
-	if m != nil {
-		p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
-		err = p.Start()
-		if err != nil {
-			log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
-			s.Exit(1)
-			return
-		}
-	}
-	log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
 }
 
 func (srv *Server) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {

tui/model.go 🔗

@@ -24,12 +24,12 @@ func (e errMsg) Error() string {
 	return e.err.Error()
 }
 
-func SessionHandler(s ssh.Session) (tea.Model, error) {
+func SessionHandler(s ssh.Session) tea.Model {
 	pty, changes, active := s.Pty()
 	if !active {
-		return nil, fmt.Errorf("you need to do this from a terminal with PTY support")
+		return nil
 	}
-	return NewModel(pty.Window.Width, pty.Window.Height, changes), nil
+	return NewModel(pty.Window.Width, pty.Window.Height, changes)
 }
 
 type Model struct {