@@ -17,7 +17,7 @@ func main() {
if err != nil {
panic(err)
}
- s, err := NewServer(cfg.Port, cfg.KeyPath, tui.SessionHandler)
+ s, err := NewServer(cfg.Port, cfg.KeyPath, LoggingMiddleware(), BubbleTeaMiddleware(tui.SessionHandler))
if err != nil {
panic(err)
}
@@ -12,22 +12,44 @@ import (
gossh "golang.org/x/crypto/ssh"
)
-type SessionHandler func(ssh.Session) (tea.Model, error)
+type Middleware func(ssh.Handler) ssh.Handler
-type Server struct {
- server *ssh.Server
- key gossh.PublicKey
- handler SessionHandler
+func LoggingMiddleware() Middleware {
+ return func(sh ssh.Handler) ssh.Handler {
+ return func(s ssh.Session) {
+ hpk := s.PublicKey() != nil
+ log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
+ sh(s)
+ log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
+ }
+ }
}
-func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error) {
- s := &Server{
- server: &ssh.Server{},
- handler: handler,
+func BubbleTeaMiddleware(bth func(ssh.Session) tea.Model) Middleware {
+ return func(sh ssh.Handler) ssh.Handler {
+ return func(s ssh.Session) {
+ m := bth(s)
+ if m != nil {
+ p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
+ err := p.Start()
+ if err != nil {
+ log.Printf("%s error %v: %s\n", s.RemoteAddr().String(), s.Command(), err)
+ }
+ }
+ sh(s)
+ }
}
+}
+
+type Server struct {
+ server *ssh.Server
+ key gossh.PublicKey
+}
+
+func NewServer(port int, keyPath string, mw ...Middleware) (*Server, error) {
+ s := &Server{server: &ssh.Server{}}
s.server.Version = "OpenSSH_7.6p1"
s.server.Addr = fmt.Sprintf(":%d", port)
- s.server.Handler = s.sessionHandler
s.server.PasswordHandler = s.passHandler
s.server.PublicKeyHandler = s.authHandler
kps := strings.Split(keyPath, string(filepath.Separator))
@@ -42,28 +64,15 @@ func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error
if err != nil {
return nil, err
}
+ h := func(s ssh.Session) {}
+ for _, m := range mw {
+ h = m(h)
+ }
+ s.server.Handler = h
return s, nil
}
func (srv *Server) sessionHandler(s ssh.Session) {
- hpk := s.PublicKey() != nil
- log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
- m, err := srv.handler(s)
- if err != nil {
- log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
- s.Exit(1)
- return
- }
- if m != nil {
- p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
- err = p.Start()
- if err != nil {
- log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
- s.Exit(1)
- return
- }
- }
- log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
}
func (srv *Server) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
@@ -24,12 +24,12 @@ func (e errMsg) Error() string {
return e.err.Error()
}
-func SessionHandler(s ssh.Session) (tea.Model, error) {
+func SessionHandler(s ssh.Session) tea.Model {
pty, changes, active := s.Pty()
if !active {
- return nil, fmt.Errorf("you need to do this from a terminal with PTY support")
+ return nil
}
- return NewModel(pty.Window.Width, pty.Window.Height, changes), nil
+ return NewModel(pty.Window.Width, pty.Window.Height, changes)
}
type Model struct {