@@ -0,0 +1,181 @@
+package fakedb
+
+import (
+ "github.com/charmbracelet/soft-serve/proto"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/types"
+)
+
+var _ db.Store = &FakeDB{}
+
+type FakeDB struct{}
+
+// Open implements db.Store
+func (*FakeDB) Open(name string) (proto.RepositoryService, error) {
+ return nil, nil
+}
+
+// GetConfig implements db.Store
+func (*FakeDB) GetConfig() (*types.Config, error) {
+ return nil, nil
+}
+
+// SetConfigAllowKeyless implements db.Store
+func (*FakeDB) SetConfigAllowKeyless(bool) error {
+ return nil
+}
+
+// SetConfigAnonAccess implements db.Store
+func (*FakeDB) SetConfigAnonAccess(string) error {
+ return nil
+}
+
+// SetConfigHost implements db.Store
+func (*FakeDB) SetConfigHost(string) error {
+ return nil
+}
+
+// SetConfigName implements db.Store
+func (*FakeDB) SetConfigName(string) error {
+ return nil
+}
+
+// SetConfigPort implements db.Store
+func (*FakeDB) SetConfigPort(int) error {
+ return nil
+}
+
+// AddUser implements db.Store
+func (*FakeDB) AddUser(name string, login string, email string, password string, isAdmin bool) error {
+ return nil
+}
+
+// CountUsers implements db.Store
+func (*FakeDB) CountUsers() (int, error) {
+ return 0, nil
+}
+
+// DeleteUser implements db.Store
+func (*FakeDB) DeleteUser(int) error {
+ return nil
+}
+
+// GetUser implements db.Store
+func (*FakeDB) GetUser(int) (*types.User, error) {
+ return nil, nil
+}
+
+// GetUserByEmail implements db.Store
+func (*FakeDB) GetUserByEmail(string) (*types.User, error) {
+ return nil, nil
+}
+
+// GetUserByLogin implements db.Store
+func (*FakeDB) GetUserByLogin(string) (*types.User, error) {
+ return nil, nil
+}
+
+// GetUserByPublicKey implements db.Store
+func (*FakeDB) GetUserByPublicKey(string) (*types.User, error) {
+ return nil, nil
+}
+
+// SetUserAdmin implements db.Store
+func (*FakeDB) SetUserAdmin(*types.User, bool) error {
+ return nil
+}
+
+// SetUserEmail implements db.Store
+func (*FakeDB) SetUserEmail(*types.User, string) error {
+ return nil
+}
+
+// SetUserLogin implements db.Store
+func (*FakeDB) SetUserLogin(*types.User, string) error {
+ return nil
+}
+
+// SetUserName implements db.Store
+func (*FakeDB) SetUserName(*types.User, string) error {
+ return nil
+}
+
+// SetUserPassword implements db.Store
+func (*FakeDB) SetUserPassword(*types.User, string) error {
+ return nil
+}
+
+// AddUserPublicKey implements db.Store
+func (*FakeDB) AddUserPublicKey(*types.User, string) error {
+ return nil
+}
+
+// DeleteUserPublicKey implements db.Store
+func (*FakeDB) DeleteUserPublicKey(int) error {
+ return nil
+}
+
+// GetUserPublicKeys implements db.Store
+func (*FakeDB) GetUserPublicKeys(*types.User) ([]*types.PublicKey, error) {
+ return nil, nil
+}
+
+// AddRepo implements db.Store
+func (*FakeDB) AddRepo(name string, projectName string, description string, isPrivate bool) error {
+ return nil
+}
+
+// DeleteRepo implements db.Store
+func (*FakeDB) DeleteRepo(string) error {
+ return nil
+}
+
+// GetRepo implements db.Store
+func (*FakeDB) GetRepo(string) (*types.Repo, error) {
+ return nil, nil
+}
+
+// SetRepoDescription implements db.Store
+func (*FakeDB) SetRepoDescription(string, string) error {
+ return nil
+}
+
+// SetRepoPrivate implements db.Store
+func (*FakeDB) SetRepoPrivate(string, bool) error {
+ return nil
+}
+
+// SetRepoProjectName implements db.Store
+func (*FakeDB) SetRepoProjectName(string, string) error {
+ return nil
+}
+
+// AddRepoCollab implements db.Store
+func (*FakeDB) AddRepoCollab(string, *types.User) error {
+ return nil
+}
+
+// DeleteRepoCollab implements db.Store
+func (*FakeDB) DeleteRepoCollab(int, int) error {
+ return nil
+}
+
+// ListRepoCollabs implements db.Store
+func (*FakeDB) ListRepoCollabs(string) ([]*types.User, error) {
+ return nil, nil
+}
+
+// ListRepoPublicKeys implements db.Store
+func (*FakeDB) ListRepoPublicKeys(string) ([]*types.PublicKey, error) {
+ return nil, nil
+}
+
+// Close implements db.Store
+func (*FakeDB) Close() error {
+ return nil
+}
+
+// CreateDB implements db.Store
+func (*FakeDB) CreateDB() error {
+ return nil
+}
@@ -25,7 +25,7 @@ var ErrServerClosed = errors.New("git: Server closed")
type Daemon struct {
listener net.Listener
addr string
- exit chan struct{}
+ finished chan struct{}
conns map[net.Conn]struct{}
cfg *config.Config
wg sync.WaitGroup
@@ -36,24 +36,25 @@ type Daemon struct {
func NewDaemon(cfg *config.Config) (*Daemon, error) {
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Git.Port)
d := &Daemon{
- addr: addr,
- exit: make(chan struct{}),
- cfg: cfg,
- conns: make(map[net.Conn]struct{}),
+ addr: addr,
+ finished: make(chan struct{}),
+ cfg: cfg,
+ conns: make(map[net.Conn]struct{}),
}
listener, err := net.Listen("tcp", d.addr)
if err != nil {
return nil, err
}
d.listener = listener
- d.wg.Add(1)
return d, nil
}
// Start starts the Git TCP daemon.
func (d *Daemon) Start() error {
+ defer d.listener.Close()
// set up channel on which to send accepted connections
listen := make(chan net.Conn, d.cfg.Git.MaxConnections)
+ d.wg.Add(1)
go d.acceptConnection(d.listener, listen)
// loop work cycle with accept connections or interrupt
@@ -66,10 +67,7 @@ func (d *Daemon) Start() error {
d.handleClient(conn)
d.wg.Done()
}()
- case <-d.exit:
- if err := d.Close(); err != nil {
- return err
- }
+ case <-d.finished:
return ErrServerClosed
}
}
@@ -89,8 +87,8 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn)
conn, err := listener.Accept()
if err != nil {
select {
- case <-d.exit:
- log.Printf("git: listener closed")
+ case <-d.finished:
+ log.Printf("git: %s", ErrServerClosed)
return
default:
log.Printf("git: error accepting connection: %v", err)
@@ -102,7 +100,19 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn)
}
// handleClient handles a git protocol client.
-func (d *Daemon) handleClient(c net.Conn) {
+func (d *Daemon) handleClient(conn net.Conn) {
+ ctx, cancel := context.WithCancel(context.Background())
+ idleTimeout := time.Duration(d.cfg.Git.IdleTimeout) * time.Second
+ c := &serverConn{
+ Conn: conn,
+ idleTimeout: idleTimeout,
+ closeCanceler: cancel,
+ }
+ if d.cfg.Git.MaxTimeout > 0 {
+ dur := time.Duration(d.cfg.Git.MaxTimeout) * time.Second
+ c.maxDeadline = time.Now().Add(dur)
+ }
+ defer c.Close()
d.conns[c] = struct{}{}
defer delete(d.conns, c)
@@ -110,85 +120,99 @@ func (d *Daemon) handleClient(c net.Conn) {
if len(d.conns) >= d.cfg.Git.MaxConnections {
log.Printf("git: max connections reached, closing %s", c.RemoteAddr())
fatal(c, git.ErrMaxConns)
- return
- }
-
- // Set connection timeout.
- if err := c.SetDeadline(time.Now().Add(time.Duration(d.cfg.Git.MaxTimeout) * time.Second)); err != nil {
- log.Printf("git: error setting deadline: %v", err)
- fatal(c, git.ErrSystemMalfunction)
+ c.closeCanceler()
return
}
readc := make(chan struct{}, 1)
+ s := pktline.NewScanner(c)
go func() {
- select {
- case <-time.After(time.Duration(d.cfg.Git.MaxReadTimeout) * time.Second):
- log.Printf("git: read timeout from %s", c.RemoteAddr())
- fatal(c, git.ErrMaxTimeout)
- case <-readc:
+ if !s.Scan() {
+ if err := s.Err(); err != nil {
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ fatal(c, git.ErrTimeout)
+ } else {
+ log.Printf("git: error scanning pktline: %v", err)
+ fatal(c, git.ErrSystemMalfunction)
+ }
+ }
+ return
}
+ readc <- struct{}{}
}()
- s := pktline.NewScanner(c)
- if !s.Scan() {
- if err := s.Err(); err != nil {
- log.Printf("git: error scanning pktline: %v", err)
- fatal(c, git.ErrSystemMalfunction)
+ select {
+ case <-ctx.Done():
+ if err := ctx.Err(); err != nil {
+ log.Printf("git: connection context error: %v", err)
}
return
- }
- readc <- struct{}{}
-
- line := s.Bytes()
- split := bytes.SplitN(line, []byte{' '}, 2)
- if len(split) != 2 {
- return
- }
+ case <-readc:
+ line := s.Bytes()
+ split := bytes.SplitN(line, []byte{' '}, 2)
+ if len(split) != 2 {
+ fatal(c, git.ErrInvalidRequest)
+ return
+ }
- var repo string
- cmd := string(split[0])
- opts := bytes.Split(split[1], []byte{'\x00'})
- if len(opts) == 0 {
- return
- }
- repo = filepath.Clean(string(opts[0]))
-
- log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
- defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
- repo = strings.TrimPrefix(repo, "/")
- auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil)
- if auth < proto.ReadOnlyAccess {
- fatal(c, git.ErrNotAuthed)
- return
- }
- // git bare repositories should end in ".git"
- // https://git-scm.com/docs/gitrepository-layout
- if !strings.HasSuffix(repo, ".git") {
- repo += ".git"
- }
+ var repo string
+ cmd := string(split[0])
+ opts := bytes.Split(split[1], []byte{'\x00'})
+ if len(opts) == 0 {
+ fatal(c, git.ErrInvalidRequest)
+ return
+ }
+ repo = filepath.Clean(string(opts[0]))
+
+ log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
+ defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
+ repo = strings.TrimPrefix(repo, "/")
+ auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil)
+ if auth < proto.ReadOnlyAccess {
+ fatal(c, git.ErrNotAuthed)
+ return
+ }
+ // git bare repositories should end in ".git"
+ // https://git-scm.com/docs/gitrepository-layout
+ if !strings.HasSuffix(repo, ".git") {
+ repo += ".git"
+ }
- err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo)
- if err == git.ErrInvalidRepo {
- trimmed := strings.TrimSuffix(repo, ".git")
- log.Printf("git: invalid repo %q trying again %q", repo, trimmed)
- err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed)
- }
- if err != nil {
- fatal(c, err)
- return
+ err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo)
+ if err == git.ErrInvalidRepo {
+ trimmed := strings.TrimSuffix(repo, ".git")
+ log.Printf("git: invalid repo %q trying again %q", repo, trimmed)
+ err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed)
+ }
+ if err != nil {
+ fatal(c, err)
+ return
+ }
}
}
// Close closes the underlying listener.
func (d *Daemon) Close() error {
- d.once.Do(func() { close(d.exit) })
+ d.once.Do(func() { close(d.finished) })
+ for c := range d.conns {
+ c.Close()
+ delete(d.conns, c)
+ }
return d.listener.Close()
}
// Shutdown gracefully shuts down the daemon.
-func (d *Daemon) Shutdown(_ context.Context) error {
- d.once.Do(func() { close(d.exit) })
- d.wg.Wait()
- return nil
+func (d *Daemon) Shutdown(ctx context.Context) error {
+ d.once.Do(func() { close(d.finished) })
+ finished := make(chan struct{}, 1)
+ go func() {
+ d.wg.Wait()
+ finished <- struct{}{}
+ }()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-finished:
+ return d.listener.Close()
+ }
}
@@ -8,7 +8,9 @@ import (
"os"
"testing"
+ "github.com/charmbracelet/soft-serve/proto"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db/fakedb"
"github.com/charmbracelet/soft-serve/server/git"
"github.com/go-git/go-git/v5/plumbing/format/pktline"
)
@@ -22,18 +24,20 @@ func TestMain(m *testing.M) {
}
defer os.RemoveAll(tmp)
cfg := &config.Config{
- Host: "",
- DataPath: tmp,
+ Host: "",
+ DataPath: tmp,
+ AnonAccess: proto.ReadOnlyAccess,
Git: config.GitConfig{
+ // Reduce the max read timeout to 1 second so we can test the timeout.
+ IdleTimeout: 3,
// Reduce the max timeout to 100 second so we can test the timeout.
MaxTimeout: 100,
- // Reduce the max read timeout to 1 second so we can test the timeout.
- MaxReadTimeout: 1,
// Reduce the max connections to 3 so we can test the timeout.
MaxConnections: 3,
Port: 9418,
},
}
+ cfg = cfg.WithDB(&fakedb.FakeDB{})
d, err := NewDaemon(cfg)
if err != nil {
log.Fatal(err)
@@ -48,7 +52,7 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
-func TestMaxReadTimeout(t *testing.T) {
+func TestIdleTimeout(t *testing.T) {
c, err := net.Dial("tcp", testDaemon.addr)
if err != nil {
t.Fatal(err)
@@ -57,8 +61,8 @@ func TestMaxReadTimeout(t *testing.T) {
if err != nil {
t.Fatalf("expected nil, got error: %v", err)
}
- if out != git.ErrMaxTimeout.Error() {
- t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
+ if out != git.ErrTimeout.Error() {
+ t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
}
}
@@ -75,7 +79,7 @@ func TestInvalidRepo(t *testing.T) {
t.Fatalf("expected nil, got error: %v", err)
}
if out != git.ErrInvalidRepo.Error() {
- t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
+ t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
}
}