diff --git a/server/config/config.go b/server/config/config.go index 3eb7a05f4ce801473977e3c018d43cec6b02298c..bb173d092aea3731b5cbf37b6d834da5d02960c8 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -27,14 +27,15 @@ type SSHConfig struct { AllowKeyless bool `env:"ALLOW_KEYLESS" envDefault:"true"` AllowPassword bool `env:"ALLOW_PASSWORD" envDefault:"false"` Password string `env:"PASSWORD"` + MaxTimeout int `env:"MAX_TIMEOUT" envDefault:"0"` + IdleTimeout int `env:"IDLE_TIMEOUT" envDefault:"300"` } // GitConfig is the Git protocol configuration for the server. type GitConfig struct { - Port int `env:"PORT" envDefault:"9418"` - MaxTimeout int `env:"MAX_TIMEOUT" envDefault:"300"` - // MaxReadTimeout is the maximum time a client can take to send a request. - MaxReadTimeout int `env:"MAX_READ_TIMEOUT" envDefault:"3"` + Port int `env:"PORT" envDefault:"9418"` + MaxTimeout int `env:"MAX_TIMEOUT" envDefault:"0"` + IdleTimeout int `env:"IDLE_TIMEOUT" envDefault:"3"` MaxConnections int `env:"SOFT_SERVE_GIT_MAX_CONNECTIONS" envDefault:"32"` } diff --git a/server/db/fakedb/db.go b/server/db/fakedb/db.go new file mode 100644 index 0000000000000000000000000000000000000000..f481e9f0a910f80bec28b54673c01e00b53db67b --- /dev/null +++ b/server/db/fakedb/db.go @@ -0,0 +1,181 @@ +package fakedb + +import ( + "github.com/charmbracelet/soft-serve/proto" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/types" +) + +var _ db.Store = &FakeDB{} + +type FakeDB struct{} + +// Open implements db.Store +func (*FakeDB) Open(name string) (proto.RepositoryService, error) { + return nil, nil +} + +// GetConfig implements db.Store +func (*FakeDB) GetConfig() (*types.Config, error) { + return nil, nil +} + +// SetConfigAllowKeyless implements db.Store +func (*FakeDB) SetConfigAllowKeyless(bool) error { + return nil +} + +// SetConfigAnonAccess implements db.Store +func (*FakeDB) SetConfigAnonAccess(string) error { + return nil +} + +// SetConfigHost implements db.Store +func (*FakeDB) SetConfigHost(string) error { + return nil +} + +// SetConfigName implements db.Store +func (*FakeDB) SetConfigName(string) error { + return nil +} + +// SetConfigPort implements db.Store +func (*FakeDB) SetConfigPort(int) error { + return nil +} + +// AddUser implements db.Store +func (*FakeDB) AddUser(name string, login string, email string, password string, isAdmin bool) error { + return nil +} + +// CountUsers implements db.Store +func (*FakeDB) CountUsers() (int, error) { + return 0, nil +} + +// DeleteUser implements db.Store +func (*FakeDB) DeleteUser(int) error { + return nil +} + +// GetUser implements db.Store +func (*FakeDB) GetUser(int) (*types.User, error) { + return nil, nil +} + +// GetUserByEmail implements db.Store +func (*FakeDB) GetUserByEmail(string) (*types.User, error) { + return nil, nil +} + +// GetUserByLogin implements db.Store +func (*FakeDB) GetUserByLogin(string) (*types.User, error) { + return nil, nil +} + +// GetUserByPublicKey implements db.Store +func (*FakeDB) GetUserByPublicKey(string) (*types.User, error) { + return nil, nil +} + +// SetUserAdmin implements db.Store +func (*FakeDB) SetUserAdmin(*types.User, bool) error { + return nil +} + +// SetUserEmail implements db.Store +func (*FakeDB) SetUserEmail(*types.User, string) error { + return nil +} + +// SetUserLogin implements db.Store +func (*FakeDB) SetUserLogin(*types.User, string) error { + return nil +} + +// SetUserName implements db.Store +func (*FakeDB) SetUserName(*types.User, string) error { + return nil +} + +// SetUserPassword implements db.Store +func (*FakeDB) SetUserPassword(*types.User, string) error { + return nil +} + +// AddUserPublicKey implements db.Store +func (*FakeDB) AddUserPublicKey(*types.User, string) error { + return nil +} + +// DeleteUserPublicKey implements db.Store +func (*FakeDB) DeleteUserPublicKey(int) error { + return nil +} + +// GetUserPublicKeys implements db.Store +func (*FakeDB) GetUserPublicKeys(*types.User) ([]*types.PublicKey, error) { + return nil, nil +} + +// AddRepo implements db.Store +func (*FakeDB) AddRepo(name string, projectName string, description string, isPrivate bool) error { + return nil +} + +// DeleteRepo implements db.Store +func (*FakeDB) DeleteRepo(string) error { + return nil +} + +// GetRepo implements db.Store +func (*FakeDB) GetRepo(string) (*types.Repo, error) { + return nil, nil +} + +// SetRepoDescription implements db.Store +func (*FakeDB) SetRepoDescription(string, string) error { + return nil +} + +// SetRepoPrivate implements db.Store +func (*FakeDB) SetRepoPrivate(string, bool) error { + return nil +} + +// SetRepoProjectName implements db.Store +func (*FakeDB) SetRepoProjectName(string, string) error { + return nil +} + +// AddRepoCollab implements db.Store +func (*FakeDB) AddRepoCollab(string, *types.User) error { + return nil +} + +// DeleteRepoCollab implements db.Store +func (*FakeDB) DeleteRepoCollab(int, int) error { + return nil +} + +// ListRepoCollabs implements db.Store +func (*FakeDB) ListRepoCollabs(string) ([]*types.User, error) { + return nil, nil +} + +// ListRepoPublicKeys implements db.Store +func (*FakeDB) ListRepoPublicKeys(string) ([]*types.PublicKey, error) { + return nil, nil +} + +// Close implements db.Store +func (*FakeDB) Close() error { + return nil +} + +// CreateDB implements db.Store +func (*FakeDB) CreateDB() error { + return nil +} diff --git a/server/git/daemon/conn.go b/server/git/daemon/conn.go new file mode 100644 index 0000000000000000000000000000000000000000..1ab35242405bc4c2cad2673eec3091cead55213f --- /dev/null +++ b/server/git/daemon/conn.go @@ -0,0 +1,55 @@ +package daemon + +import ( + "context" + "net" + "time" +) + +type serverConn struct { + net.Conn + + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + switch { + case c.idleTimeout > 0: + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(idleDeadline) + return + } + fallthrough + default: + c.Conn.SetDeadline(c.maxDeadline) + } +} diff --git a/server/git/daemon/daemon.go b/server/git/daemon/daemon.go index 4880ff730299457ec3aa555c64bc7de9c34227bb..75d285781dee2941ce0cd91cf813b410bbb79e78 100644 --- a/server/git/daemon/daemon.go +++ b/server/git/daemon/daemon.go @@ -25,7 +25,7 @@ var ErrServerClosed = errors.New("git: Server closed") type Daemon struct { listener net.Listener addr string - exit chan struct{} + finished chan struct{} conns map[net.Conn]struct{} cfg *config.Config wg sync.WaitGroup @@ -36,24 +36,25 @@ type Daemon struct { func NewDaemon(cfg *config.Config) (*Daemon, error) { addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Git.Port) d := &Daemon{ - addr: addr, - exit: make(chan struct{}), - cfg: cfg, - conns: make(map[net.Conn]struct{}), + addr: addr, + finished: make(chan struct{}), + cfg: cfg, + conns: make(map[net.Conn]struct{}), } listener, err := net.Listen("tcp", d.addr) if err != nil { return nil, err } d.listener = listener - d.wg.Add(1) return d, nil } // Start starts the Git TCP daemon. func (d *Daemon) Start() error { + defer d.listener.Close() // set up channel on which to send accepted connections listen := make(chan net.Conn, d.cfg.Git.MaxConnections) + d.wg.Add(1) go d.acceptConnection(d.listener, listen) // loop work cycle with accept connections or interrupt @@ -66,10 +67,7 @@ func (d *Daemon) Start() error { d.handleClient(conn) d.wg.Done() }() - case <-d.exit: - if err := d.Close(); err != nil { - return err - } + case <-d.finished: return ErrServerClosed } } @@ -89,8 +87,8 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn) conn, err := listener.Accept() if err != nil { select { - case <-d.exit: - log.Printf("git: listener closed") + case <-d.finished: + log.Printf("git: %s", ErrServerClosed) return default: log.Printf("git: error accepting connection: %v", err) @@ -102,7 +100,19 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn) } // handleClient handles a git protocol client. -func (d *Daemon) handleClient(c net.Conn) { +func (d *Daemon) handleClient(conn net.Conn) { + ctx, cancel := context.WithCancel(context.Background()) + idleTimeout := time.Duration(d.cfg.Git.IdleTimeout) * time.Second + c := &serverConn{ + Conn: conn, + idleTimeout: idleTimeout, + closeCanceler: cancel, + } + if d.cfg.Git.MaxTimeout > 0 { + dur := time.Duration(d.cfg.Git.MaxTimeout) * time.Second + c.maxDeadline = time.Now().Add(dur) + } + defer c.Close() d.conns[c] = struct{}{} defer delete(d.conns, c) @@ -110,85 +120,99 @@ func (d *Daemon) handleClient(c net.Conn) { if len(d.conns) >= d.cfg.Git.MaxConnections { log.Printf("git: max connections reached, closing %s", c.RemoteAddr()) fatal(c, git.ErrMaxConns) - return - } - - // Set connection timeout. - if err := c.SetDeadline(time.Now().Add(time.Duration(d.cfg.Git.MaxTimeout) * time.Second)); err != nil { - log.Printf("git: error setting deadline: %v", err) - fatal(c, git.ErrSystemMalfunction) + c.closeCanceler() return } readc := make(chan struct{}, 1) + s := pktline.NewScanner(c) go func() { - select { - case <-time.After(time.Duration(d.cfg.Git.MaxReadTimeout) * time.Second): - log.Printf("git: read timeout from %s", c.RemoteAddr()) - fatal(c, git.ErrMaxTimeout) - case <-readc: + if !s.Scan() { + if err := s.Err(); err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + fatal(c, git.ErrTimeout) + } else { + log.Printf("git: error scanning pktline: %v", err) + fatal(c, git.ErrSystemMalfunction) + } + } + return } + readc <- struct{}{} }() - s := pktline.NewScanner(c) - if !s.Scan() { - if err := s.Err(); err != nil { - log.Printf("git: error scanning pktline: %v", err) - fatal(c, git.ErrSystemMalfunction) + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + log.Printf("git: connection context error: %v", err) } return - } - readc <- struct{}{} - - line := s.Bytes() - split := bytes.SplitN(line, []byte{' '}, 2) - if len(split) != 2 { - return - } + case <-readc: + line := s.Bytes() + split := bytes.SplitN(line, []byte{' '}, 2) + if len(split) != 2 { + fatal(c, git.ErrInvalidRequest) + return + } - var repo string - cmd := string(split[0]) - opts := bytes.Split(split[1], []byte{'\x00'}) - if len(opts) == 0 { - return - } - repo = filepath.Clean(string(opts[0])) - - log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo) - defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo) - repo = strings.TrimPrefix(repo, "/") - auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil) - if auth < proto.ReadOnlyAccess { - fatal(c, git.ErrNotAuthed) - return - } - // git bare repositories should end in ".git" - // https://git-scm.com/docs/gitrepository-layout - if !strings.HasSuffix(repo, ".git") { - repo += ".git" - } + var repo string + cmd := string(split[0]) + opts := bytes.Split(split[1], []byte{'\x00'}) + if len(opts) == 0 { + fatal(c, git.ErrInvalidRequest) + return + } + repo = filepath.Clean(string(opts[0])) + + log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo) + defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo) + repo = strings.TrimPrefix(repo, "/") + auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil) + if auth < proto.ReadOnlyAccess { + fatal(c, git.ErrNotAuthed) + return + } + // git bare repositories should end in ".git" + // https://git-scm.com/docs/gitrepository-layout + if !strings.HasSuffix(repo, ".git") { + repo += ".git" + } - err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo) - if err == git.ErrInvalidRepo { - trimmed := strings.TrimSuffix(repo, ".git") - log.Printf("git: invalid repo %q trying again %q", repo, trimmed) - err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed) - } - if err != nil { - fatal(c, err) - return + err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo) + if err == git.ErrInvalidRepo { + trimmed := strings.TrimSuffix(repo, ".git") + log.Printf("git: invalid repo %q trying again %q", repo, trimmed) + err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed) + } + if err != nil { + fatal(c, err) + return + } } } // Close closes the underlying listener. func (d *Daemon) Close() error { - d.once.Do(func() { close(d.exit) }) + d.once.Do(func() { close(d.finished) }) + for c := range d.conns { + c.Close() + delete(d.conns, c) + } return d.listener.Close() } // Shutdown gracefully shuts down the daemon. -func (d *Daemon) Shutdown(_ context.Context) error { - d.once.Do(func() { close(d.exit) }) - d.wg.Wait() - return nil +func (d *Daemon) Shutdown(ctx context.Context) error { + d.once.Do(func() { close(d.finished) }) + finished := make(chan struct{}, 1) + go func() { + d.wg.Wait() + finished <- struct{}{} + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-finished: + return d.listener.Close() + } } diff --git a/server/git/daemon/daemon_test.go b/server/git/daemon/daemon_test.go index fbfddd1e2133a2f6bc4ff09ab1ec878afae362ad..15d10c48ebdb9880b8fd160f9c0e2cd7e21b927f 100644 --- a/server/git/daemon/daemon_test.go +++ b/server/git/daemon/daemon_test.go @@ -8,7 +8,9 @@ import ( "os" "testing" + "github.com/charmbracelet/soft-serve/proto" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db/fakedb" "github.com/charmbracelet/soft-serve/server/git" "github.com/go-git/go-git/v5/plumbing/format/pktline" ) @@ -22,18 +24,20 @@ func TestMain(m *testing.M) { } defer os.RemoveAll(tmp) cfg := &config.Config{ - Host: "", - DataPath: tmp, + Host: "", + DataPath: tmp, + AnonAccess: proto.ReadOnlyAccess, Git: config.GitConfig{ + // Reduce the max read timeout to 1 second so we can test the timeout. + IdleTimeout: 3, // Reduce the max timeout to 100 second so we can test the timeout. MaxTimeout: 100, - // Reduce the max read timeout to 1 second so we can test the timeout. - MaxReadTimeout: 1, // Reduce the max connections to 3 so we can test the timeout. MaxConnections: 3, Port: 9418, }, } + cfg = cfg.WithDB(&fakedb.FakeDB{}) d, err := NewDaemon(cfg) if err != nil { log.Fatal(err) @@ -48,7 +52,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func TestMaxReadTimeout(t *testing.T) { +func TestIdleTimeout(t *testing.T) { c, err := net.Dial("tcp", testDaemon.addr) if err != nil { t.Fatal(err) @@ -57,8 +61,8 @@ func TestMaxReadTimeout(t *testing.T) { if err != nil { t.Fatalf("expected nil, got error: %v", err) } - if out != git.ErrMaxTimeout.Error() { - t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout) + if out != git.ErrTimeout.Error() { + t.Fatalf("expected %q error, got %q", git.ErrTimeout, out) } } @@ -75,7 +79,7 @@ func TestInvalidRepo(t *testing.T) { t.Fatalf("expected nil, got error: %v", err) } if out != git.ErrInvalidRepo.Error() { - t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo) + t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out) } } diff --git a/server/git/error.go b/server/git/error.go index 2d9dfafb8af951d248fb6b549e49d5adb2def32e..11dc0ea4e825e2d1f6eb91b5eb66da4ec6ad6b53 100644 --- a/server/git/error.go +++ b/server/git/error.go @@ -11,8 +11,11 @@ var ErrSystemMalfunction = errors.New("something went wrong") // ErrInvalidRepo represents an attempt to access a non-existent repo. var ErrInvalidRepo = errors.New("invalid repo") +// ErrInvalidRequest represents an invalid request. +var ErrInvalidRequest = errors.New("invalid request") + // ErrMaxConns represents a maximum connection limit being reached. var ErrMaxConns = errors.New("too many connections, try again later") -// ErrMaxTimeout is returned when the maximum read timeout is exceeded. -var ErrMaxTimeout = errors.New("git: max timeout reached") +// ErrTimeout is returned when the maximum read timeout is exceeded. +var ErrTimeout = errors.New("I/O timeout reached") diff --git a/server/server.go b/server/server.go index b241f89fd3af611083ce317c0b36eeac8d00e8b6..07844456a917c504ea44187ce6806efc1f011344 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "time" appCfg "github.com/charmbracelet/soft-serve/config" cm "github.com/charmbracelet/soft-serve/server/cmd" @@ -68,6 +69,12 @@ func NewServer(cfg *config.Config) *Server { if err != nil { log.Fatalln(err) } + if cfg.SSH.MaxTimeout > 0 { + sh.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second + } + if cfg.SSH.IdleTimeout > 0 { + sh.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second + } s.SSHServer = sh d, err := daemon.NewDaemon(cfg) if err != nil {