diff --git a/cmd/soft/serve.go b/cmd/soft/serve.go index 58c1f43c699f4dc6b3e311287ed504a496bae1f8..8207fd0d9cce1a2e33244a6db7db4163598b99ba 100644 --- a/cmd/soft/serve.go +++ b/cmd/soft/serve.go @@ -36,7 +36,6 @@ var ( signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) <-done - log.Printf("Stopping SSH server on %s:%d", cfg.BindAddr, cfg.Port) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if err := s.Shutdown(ctx); err != nil { diff --git a/examples/setuid/main.go b/examples/setuid/main.go index 0c58fbb8347581193de163564c52eaf50a5aff2a..3138ebff1275f4f1a6b48649b093ebc53af144fa 100644 --- a/examples/setuid/main.go +++ b/examples/setuid/main.go @@ -52,7 +52,7 @@ func main() { log.Printf("Starting SSH server on %s:%d", cfg.BindAddr, cfg.Port) go func() { - if err := s.Serve(ls); err != nil { + if err := s.SSHServer.Serve(ls); err != nil { log.Fatalln(err) } }() diff --git a/git/command.go b/git/command.go new file mode 100644 index 0000000000000000000000000000000000000000..eb4f0d17ac5af0c7313cd200244a2bac36c04c13 --- /dev/null +++ b/git/command.go @@ -0,0 +1,11 @@ +package git + +import "github.com/gogs/git-module" + +// RunInDirOptions are options for RunInDir. +type RunInDirOptions = git.RunInDirOptions + +// NewCommand creates a new git command. +func NewCommand(args ...string) *git.Command { + return git.NewCommand(args...) +} diff --git a/go.mod b/go.mod index 937af69c0343abfc324206f1e22f197ecf6088b0..5a0c129a0a3d36b3ec2481baa581c74340c23d5d 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 github.com/spf13/cobra v1.6.1 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c gopkg.in/yaml.v3 v3.0.1 ) diff --git a/server/config/config.go b/server/config/config.go index dc60e2dad1f2cd802a98d1ca3894c6db4bc0d8ba..577e205a7c70be254c0a72c75b8aca48f2321222 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -16,14 +16,19 @@ type Callbacks interface { // Config is the configuration for Soft Serve. type Config struct { - BindAddr string `env:"SOFT_SERVE_BIND_ADDRESS" envDefault:""` - Host string `env:"SOFT_SERVE_HOST" envDefault:"localhost"` - Port int `env:"SOFT_SERVE_PORT" envDefault:"23231"` - KeyPath string `env:"SOFT_SERVE_KEY_PATH"` - RepoPath string `env:"SOFT_SERVE_REPO_PATH" envDefault:".repos"` - InitialAdminKeys []string `env:"SOFT_SERVE_INITIAL_ADMIN_KEY" envSeparator:"\n"` - Callbacks Callbacks - ErrorLog *log.Logger + BindAddr string `env:"SOFT_SERVE_BIND_ADDRESS" envDefault:""` + Host string `env:"SOFT_SERVE_HOST" envDefault:"localhost"` + Port int `env:"SOFT_SERVE_PORT" envDefault:"23231"` + GitPort int `env:"SOFT_SERVE_GIT_PORT" envDefault:"9418"` + GitMaxTimeout int `env:"SOFT_SERVE_GIT_MAX_TIMEOUT" envDefault:"300"` + // MaxReadTimeout is the maximum time a client can take to send a request. + GitMaxReadTimeout int `env:"SOFT_SERVE_GIT_MAX_READ_TIMEOUT" envDefault:"3"` + GitMaxConnections int `env:"SOFT_SERVE_GIT_MAX_CONNECTIONS" envDefault:"32"` + KeyPath string `env:"SOFT_SERVE_KEY_PATH"` + RepoPath string `env:"SOFT_SERVE_REPO_PATH" envDefault:".repos"` + InitialAdminKeys []string `env:"SOFT_SERVE_INITIAL_ADMIN_KEY" envSeparator:"\n"` + Callbacks Callbacks + ErrorLog *log.Logger } // DefaultConfig returns a Config with the values populated with the defaults diff --git a/server/git/auth.go b/server/git/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..4cd894e5ca12f2ae579fa0e962ecbc3839922aae --- /dev/null +++ b/server/git/auth.go @@ -0,0 +1,46 @@ +package git + +import "github.com/gliderlabs/ssh" + +// AccessLevel is the level of access allowed to a repo. +type AccessLevel int + +const ( + // NoAccess does not allow access to the repo. + NoAccess AccessLevel = iota + + // ReadOnlyAccess allows read-only access to the repo. + ReadOnlyAccess + + // ReadWriteAccess allows read and write access to the repo. + ReadWriteAccess + + // AdminAccess allows read, write, and admin access to the repo. + AdminAccess +) + +// String implements the Stringer interface for AccessLevel. +func (a AccessLevel) String() string { + switch a { + case NoAccess: + return "no-access" + case ReadOnlyAccess: + return "read-only" + case ReadWriteAccess: + return "read-write" + case AdminAccess: + return "admin-access" + default: + return "" + } +} + +// Hooks is an interface that allows for custom authorization +// implementations and post push/fetch notifications. Prior to git access, +// AuthRepo will be called with the ssh.Session public key and the repo name. +// Implementers return the appropriate AccessLevel. +type Hooks interface { + AuthRepo(string, ssh.PublicKey) AccessLevel + Push(string, ssh.PublicKey) + Fetch(string, ssh.PublicKey) +} diff --git a/server/git/daemon/daemon.go b/server/git/daemon/daemon.go new file mode 100644 index 0000000000000000000000000000000000000000..ee5fc1aa127fe7c345a355f88d51f9b1168d4760 --- /dev/null +++ b/server/git/daemon/daemon.go @@ -0,0 +1,194 @@ +package daemon + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "net" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/git" + "github.com/go-git/go-git/v5/plumbing/format/pktline" +) + +// ErrServerClosed indicates that the server has been closed. +var ErrServerClosed = errors.New("git: Server closed") + +// Daemon represents a Git daemon. +type Daemon struct { + auth git.Hooks + listener net.Listener + addr string + exit chan struct{} + conns map[net.Conn]struct{} + cfg *config.Config + wg sync.WaitGroup +} + +// NewDaemon returns a new Git daemon. +func NewDaemon(cfg *config.Config, auth git.Hooks) (*Daemon, error) { + addr := fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.GitPort) + d := &Daemon{ + addr: addr, + auth: auth, + exit: make(chan struct{}), + cfg: cfg, + conns: make(map[net.Conn]struct{}), + } + listener, err := net.Listen("tcp", d.addr) + if err != nil { + return nil, err + } + d.listener = listener + d.wg.Add(1) + return d, nil +} + +// Start starts the Git TCP daemon. +func (d *Daemon) Start() error { + // set up channel on which to send accepted connections + listen := make(chan net.Conn, d.cfg.GitMaxConnections) + go d.acceptConnection(d.listener, listen) + + // loop work cycle with accept connections or interrupt + // by system signal + for { + log.Printf("listener len %d cap %d", len(listen), cap(listen)) + select { + case conn := <-listen: + d.wg.Add(1) + go func() { + d.handleClient(conn) + d.wg.Done() + }() + case <-d.exit: + if err := d.Close(); err != nil { + return err + } + return ErrServerClosed + } + } +} + +func fatal(c net.Conn, err error) { + git.WritePktline(c, err) + if err := c.Close(); err != nil { + log.Printf("git: error closing connection: %v", err) + } +} + +// acceptConnection accepts connections on the listener. +func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn) { + defer d.wg.Done() + for { + conn, err := listener.Accept() + if err != nil { + select { + case <-d.exit: + log.Printf("git: listener closed") + return + default: + log.Printf("git: error accepting connection: %v", err) + continue + } + } + listen <- conn + } +} + +// handleClient handles a git protocol client. +func (d *Daemon) handleClient(c net.Conn) { + d.conns[c] = struct{}{} + defer delete(d.conns, c) + + // Close connection if there are too many open connections. + if len(d.conns) >= d.cfg.GitMaxConnections { + log.Printf("git: max connections reached, closing %s", c.RemoteAddr()) + fatal(c, git.ErrMaxConns) + return + } + + // Set connection timeout. + if err := c.SetDeadline(time.Now().Add(time.Duration(d.cfg.GitMaxTimeout) * time.Second)); err != nil { + log.Printf("git: error setting deadline: %v", err) + fatal(c, git.ErrSystemMalfunction) + return + } + + readc := make(chan struct{}, 1) + go func() { + select { + case <-time.After(time.Duration(d.cfg.GitMaxReadTimeout) * time.Second): + log.Printf("git: read timeout from %s", c.RemoteAddr()) + fatal(c, git.ErrMaxTimeout) + case <-readc: + } + }() + + s := pktline.NewScanner(c) + if !s.Scan() { + if err := s.Err(); err != nil { + log.Printf("git: error scanning pktline: %v", err) + fatal(c, git.ErrSystemMalfunction) + } + return + } + readc <- struct{}{} + + line := s.Bytes() + split := bytes.SplitN(line, []byte{' '}, 2) + if len(split) != 2 { + return + } + + var repo string + cmd := string(split[0]) + opts := bytes.Split(split[1], []byte{'\x00'}) + if len(opts) == 0 { + return + } + repo = filepath.Clean(string(opts[0])) + + log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo) + defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo) + repo = strings.TrimPrefix(repo, "/") + auth := d.auth.AuthRepo(strings.TrimSuffix(repo, ".git"), nil) + if auth < git.ReadOnlyAccess { + fatal(c, git.ErrNotAuthed) + return + } + // git bare repositories should end in ".git" + // https://git-scm.com/docs/gitrepository-layout + if !strings.HasSuffix(repo, ".git") { + repo += ".git" + } + + err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath, repo) + if err == git.ErrInvalidRepo { + trimmed := strings.TrimSuffix(repo, ".git") + log.Printf("git: invalid repo %q trying again %q", repo, trimmed) + err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath, trimmed) + } + if err != nil { + fatal(c, err) + return + } +} + +// Close closes the underlying listener. +func (d *Daemon) Close() error { + return d.listener.Close() +} + +// Shutdown gracefully shuts down the daemon. +func (d *Daemon) Shutdown(_ context.Context) error { + close(d.exit) + d.wg.Wait() + return nil +} diff --git a/server/git/daemon/daemon_test.go b/server/git/daemon/daemon_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2e5d137328e0c4105760b50d89be284038046991 --- /dev/null +++ b/server/git/daemon/daemon_test.go @@ -0,0 +1,87 @@ +package daemon + +import ( + "bytes" + "context" + "io" + "log" + "net" + "os" + "testing" + + appCfg "github.com/charmbracelet/soft-serve/config" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/git" + "github.com/go-git/go-git/v5/plumbing/format/pktline" +) + +var testDaemon *Daemon + +func TestMain(m *testing.M) { + cfg := config.DefaultConfig() + // Reduce the max connections to 3 so we can test the timeout. + cfg.GitMaxConnections = 3 + // Reduce the max timeout to 100 second so we can test the timeout. + cfg.GitMaxTimeout = 100 + // Reduce the max read timeout to 1 second so we can test the timeout. + cfg.GitMaxReadTimeout = 1 + ac, err := appCfg.NewConfig(cfg) + if err != nil { + log.Fatal(err) + } + d, err := NewDaemon(cfg, ac) + if err != nil { + log.Fatal(err) + } + testDaemon = d + go func() { + if err := d.Start(); err != ErrServerClosed { + log.Fatal(err) + } + }() + defer d.Shutdown(context.Background()) + os.Exit(m.Run()) +} + +func TestMaxReadTimeout(t *testing.T) { + c, err := net.Dial("tcp", testDaemon.addr) + if err != nil { + t.Fatal(err) + } + out, err := readPktline(c) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if out != git.ErrMaxTimeout.Error() { + t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout) + } +} + +func TestInvalidRepo(t *testing.T) { + c, err := net.Dial("tcp", testDaemon.addr) + if err != nil { + t.Fatal(err) + } + if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + out, err := readPktline(c) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if out != git.ErrInvalidRepo.Error() { + t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo) + } +} + +func readPktline(c net.Conn) (string, error) { + buf, err := io.ReadAll(c) + if err != nil { + return "", err + } + pktout := pktline.NewScanner(bytes.NewReader(buf)) + if !pktout.Scan() { + return "", pktout.Err() + } + return string(pktout.Bytes()), nil +} diff --git a/server/git/error.go b/server/git/error.go new file mode 100644 index 0000000000000000000000000000000000000000..2d9dfafb8af951d248fb6b549e49d5adb2def32e --- /dev/null +++ b/server/git/error.go @@ -0,0 +1,18 @@ +package git + +import "errors" + +// ErrNotAuthed represents unauthorized access. +var ErrNotAuthed = errors.New("you are not authorized to do this") + +// ErrSystemMalfunction represents a general system error returned to clients. +var ErrSystemMalfunction = errors.New("something went wrong") + +// ErrInvalidRepo represents an attempt to access a non-existent repo. +var ErrInvalidRepo = errors.New("invalid repo") + +// ErrMaxConns represents a maximum connection limit being reached. +var ErrMaxConns = errors.New("too many connections, try again later") + +// ErrMaxTimeout is returned when the maximum read timeout is exceeded. +var ErrMaxTimeout = errors.New("git: max timeout reached") diff --git a/server/git/pack.go b/server/git/pack.go new file mode 100644 index 0000000000000000000000000000000000000000..88eaf17001794bf1b82b6aa5a65fca62b55f43b0 --- /dev/null +++ b/server/git/pack.go @@ -0,0 +1,131 @@ +package git + +import ( + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/charmbracelet/soft-serve/git" + "github.com/go-git/go-git/v5/plumbing/format/pktline" +) + +// GitPack runs the git pack protocol against the provided repo. +func GitPack(out io.Writer, in io.Reader, er io.Writer, gitCmd string, repoDir string, repo string) error { + cmd := strings.TrimPrefix(gitCmd, "git-") + rp := filepath.Join(repoDir, repo) + switch gitCmd { + case "git-upload-archive", "git-upload-pack": + exists, err := fileExists(rp) + if !exists { + return ErrInvalidRepo + } + if err != nil { + return err + } + return RunGit(out, in, er, "", cmd, rp) + case "git-receive-pack": + err := ensureRepo(repoDir, repo) + if err != nil { + return err + } + err = RunGit(out, in, er, "", cmd, rp) + if err != nil { + return err + } + err = ensureDefaultBranch(out, in, er, rp) + if err != nil { + return err + } + // Needed for git dumb http server + return RunGit(out, in, er, rp, "update-server-info") + default: + return fmt.Errorf("unknown git command: %s", gitCmd) + } +} + +// RunGit runs a git command in the given repo. +func RunGit(out io.Writer, in io.Reader, err io.Writer, dir string, args ...string) error { + c := git.NewCommand(args...) + return c.RunInDirWithOptions(dir, git.RunInDirOptions{ + Stdout: out, + Stdin: in, + Stderr: err, + }) +} + +// WritePktline encodes and writes a pktline to the given writer. +func WritePktline(w io.Writer, v ...interface{}) { + msg := fmt.Sprint(v...) + pkt := pktline.NewEncoder(w) + if err := pkt.EncodeString(msg); err != nil { + log.Printf("git: error writing pkt-line message: %s", err) + } + if err := pkt.Flush(); err != nil { + log.Printf("git: error flushing pkt-line message: %s", err) + } +} + +func fileExists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return true, err +} + +func ensureRepo(dir string, repo string) error { + exists, err := fileExists(dir) + if err != nil { + return err + } + if !exists { + err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700)) + if err != nil { + return err + } + } + rp := filepath.Join(dir, repo) + exists, err = fileExists(rp) + if err != nil { + return err + } + if !exists { + _, err := git.Init(rp, true) + if err != nil { + return err + } + } + return nil +} + +func ensureDefaultBranch(out io.Writer, in io.Reader, er io.Writer, repoPath string) error { + r, err := git.Open(repoPath) + if err != nil { + return err + } + brs, err := r.Branches() + if err != nil { + return err + } + if len(brs) == 0 { + return fmt.Errorf("no branches found") + } + // Rename the default branch to the first branch available + _, err = r.HEAD() + if err == git.ErrReferenceNotExist { + err = RunGit(out, in, er, repoPath, "branch", "-M", brs[0]) + if err != nil { + return err + } + } + if err != nil && err != git.ErrReferenceNotExist { + return err + } + return nil +} diff --git a/server/git/ssh.go b/server/git/ssh.go deleted file mode 100644 index b3165b7e4c1f24c0bdf1307c309ea4f5a9b455e3..0000000000000000000000000000000000000000 --- a/server/git/ssh.go +++ /dev/null @@ -1,250 +0,0 @@ -package git - -import ( - "errors" - "fmt" - "log" - "os" - "path/filepath" - "strings" - - "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/wish" - "github.com/gliderlabs/ssh" - g "github.com/gogs/git-module" -) - -// ErrNotAuthed represents unauthorized access. -var ErrNotAuthed = errors.New("you are not authorized to do this") - -// ErrSystemMalfunction represents a general system error returned to clients. -var ErrSystemMalfunction = errors.New("something went wrong") - -// ErrInvalidRepo represents an attempt to access a non-existent repo. -var ErrInvalidRepo = errors.New("invalid repo") - -// AccessLevel is the level of access allowed to a repo. -type AccessLevel int - -const ( - // NoAccess does not allow access to the repo. - NoAccess AccessLevel = iota - - // ReadOnlyAccess allows read-only access to the repo. - ReadOnlyAccess - - // ReadWriteAccess allows read and write access to the repo. - ReadWriteAccess - - // AdminAccess allows read, write, and admin access to the repo. - AdminAccess -) - -// String implements the Stringer interface for AccessLevel. -func (a AccessLevel) String() string { - switch a { - case NoAccess: - return "no-access" - case ReadOnlyAccess: - return "read-only" - case ReadWriteAccess: - return "read-write" - case AdminAccess: - return "admin-access" - default: - return "" - } -} - -// Hooks is an interface that allows for custom authorization -// implementations and post push/fetch notifications. Prior to git access, -// AuthRepo will be called with the ssh.Session public key and the repo name. -// Implementers return the appropriate AccessLevel. -type Hooks interface { - AuthRepo(string, ssh.PublicKey) AccessLevel - Push(string, ssh.PublicKey) - Fetch(string, ssh.PublicKey) -} - -// Middleware adds Git server functionality to the ssh.Server. Repos are stored -// in the specified repo directory. The provided Hooks implementation will be -// checked for access on a per repo basis for a ssh.Session public key. -// Hooks.Push and Hooks.Fetch will be called on successful completion of -// their commands. -func Middleware(repoDir string, gh Hooks) wish.Middleware { - return func(sh ssh.Handler) ssh.Handler { - return func(s ssh.Session) { - func() { - cmd := s.Command() - if len(cmd) == 2 && strings.HasPrefix(cmd[0], "git") { - gc := cmd[0] - // repo should be in the form of "repo.git" - repo := strings.TrimPrefix(cmd[1], "/") - repo = filepath.Clean(repo) - if strings.Contains(repo, "/") { - log.Printf("invalid repo: %s", repo) - Fatal(s, fmt.Errorf("%s: %s", ErrInvalidRepo, "user repos not supported")) - return - } - pk := s.PublicKey() - access := gh.AuthRepo(strings.TrimSuffix(repo, ".git"), pk) - // git bare repositories should end in ".git" - // https://git-scm.com/docs/gitrepository-layout - if !strings.HasSuffix(repo, ".git") { - repo += ".git" - } - switch gc { - case "git-receive-pack": - switch access { - case ReadWriteAccess, AdminAccess: - err := gitPack(s, gc, repoDir, repo) - if err != nil { - Fatal(s, ErrSystemMalfunction) - } else { - gh.Push(repo, pk) - } - default: - Fatal(s, ErrNotAuthed) - } - return - case "git-upload-archive", "git-upload-pack": - switch access { - case ReadOnlyAccess, ReadWriteAccess, AdminAccess: - // try to upload .git first, then - err := gitPack(s, gc, repoDir, repo) - if err != nil { - err = gitPack(s, gc, repoDir, strings.TrimSuffix(repo, ".git")) - } - switch err { - case ErrInvalidRepo: - Fatal(s, ErrInvalidRepo) - case nil: - gh.Fetch(repo, pk) - default: - log.Printf("unknown git error: %s", err) - Fatal(s, ErrSystemMalfunction) - } - default: - Fatal(s, ErrNotAuthed) - } - return - } - } - }() - sh(s) - } - } -} - -func gitPack(s ssh.Session, gitCmd string, repoDir string, repo string) error { - cmd := strings.TrimPrefix(gitCmd, "git-") - rp := filepath.Join(repoDir, repo) - switch gitCmd { - case "git-upload-archive", "git-upload-pack": - exists, err := fileExists(rp) - if !exists { - return ErrInvalidRepo - } - if err != nil { - return err - } - return runGit(s, "", cmd, rp) - case "git-receive-pack": - err := ensureRepo(repoDir, repo) - if err != nil { - return err - } - err = runGit(s, "", cmd, rp) - if err != nil { - return err - } - err = ensureDefaultBranch(s, rp) - if err != nil { - return err - } - // Needed for git dumb http server - return runGit(s, rp, "update-server-info") - default: - return fmt.Errorf("unknown git command: %s", gitCmd) - } -} - -func fileExists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return true, err -} - -// Fatal prints to the session's STDOUT as a git response and exit 1. -func Fatal(s ssh.Session, v ...interface{}) { - msg := fmt.Sprint(v...) - // hex length includes 4 byte length prefix and ending newline - pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg) - _, _ = wish.WriteString(s, pktLine) - s.Exit(1) // nolint: errcheck -} - -func ensureRepo(dir string, repo string) error { - exists, err := fileExists(dir) - if err != nil { - return err - } - if !exists { - err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700)) - if err != nil { - return err - } - } - rp := filepath.Join(dir, repo) - exists, err = fileExists(rp) - if err != nil { - return err - } - if !exists { - _, err := git.Init(rp, true) - if err != nil { - return err - } - } - return nil -} - -func runGit(s ssh.Session, dir string, args ...string) error { - c := g.NewCommand(args...) - return c.RunInDirWithOptions(dir, g.RunInDirOptions{ - Stdout: s, - Stdin: s, - Stderr: s.Stderr(), - }) -} - -func ensureDefaultBranch(s ssh.Session, repoPath string) error { - r, err := git.Open(repoPath) - if err != nil { - return err - } - brs, err := r.Branches() - if err != nil { - return err - } - if len(brs) == 0 { - return fmt.Errorf("no branches found") - } - // Rename the default branch to the first branch available - _, err = r.HEAD() - if err == git.ErrReferenceNotExist { - err = runGit(s, repoPath, "branch", "-M", brs[0]) - if err != nil { - return err - } - } - if err != nil && err != git.ErrReferenceNotExist { - return err - } - return nil -} diff --git a/server/git/ssh/ssh.go b/server/git/ssh/ssh.go new file mode 100644 index 0000000000000000000000000000000000000000..c1f3c56a9e2d5b1bde36cb7efe4e205b9cb5fdda --- /dev/null +++ b/server/git/ssh/ssh.go @@ -0,0 +1,88 @@ +package ssh + +import ( + "fmt" + "log" + "path/filepath" + "strings" + + "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/wish" + "github.com/gliderlabs/ssh" +) + +// Middleware adds Git server functionality to the ssh.Server. Repos are stored +// in the specified repo directory. The provided Hooks implementation will be +// checked for access on a per repo basis for a ssh.Session public key. +// Hooks.Push and Hooks.Fetch will be called on successful completion of +// their commands. +func Middleware(repoDir string, gh git.Hooks) wish.Middleware { + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + func() { + cmd := s.Command() + if len(cmd) == 2 && strings.HasPrefix(cmd[0], "git") { + gc := cmd[0] + // repo should be in the form of "repo.git" + repo := strings.TrimPrefix(cmd[1], "/") + repo = filepath.Clean(repo) + if strings.Contains(repo, "/") { + log.Printf("invalid repo: %s", repo) + Fatal(s, fmt.Errorf("%s: %s", git.ErrInvalidRepo, "user repos not supported")) + return + } + pk := s.PublicKey() + access := gh.AuthRepo(strings.TrimSuffix(repo, ".git"), pk) + // git bare repositories should end in ".git" + // https://git-scm.com/docs/gitrepository-layout + if !strings.HasSuffix(repo, ".git") { + repo += ".git" + } + switch gc { + case "git-receive-pack": + switch access { + case git.ReadWriteAccess, git.AdminAccess: + err := git.GitPack(s, s, s.Stderr(), gc, repoDir, repo) + if err != nil { + Fatal(s, git.ErrSystemMalfunction) + } else { + gh.Push(repo, pk) + } + default: + Fatal(s, git.ErrNotAuthed) + } + return + case "git-upload-archive", "git-upload-pack": + switch access { + case git.ReadOnlyAccess, git.ReadWriteAccess, git.AdminAccess: + // try to upload .git first, then + err := git.GitPack(s, s, s.Stderr(), gc, repoDir, repo) + if err != nil { + err = git.GitPack(s, s, s.Stderr(), gc, repoDir, strings.TrimSuffix(repo, ".git")) + } + switch err { + case git.ErrInvalidRepo: + Fatal(s, git.ErrInvalidRepo) + case nil: + gh.Fetch(repo, pk) + default: + log.Printf("unknown git error: %s", err) + Fatal(s, git.ErrSystemMalfunction) + } + default: + Fatal(s, git.ErrNotAuthed) + } + return + } + } + }() + sh(s) + } + } +} + +// Fatal prints to the session's STDOUT as a git response and exit 1. +func Fatal(s ssh.Session, v ...interface{}) { + git.WritePktline(s, v...) + s.Exit(1) // nolint: errcheck +} diff --git a/server/git/ssh_test.go b/server/git/ssh/ssh_test.go similarity index 92% rename from server/git/ssh_test.go rename to server/git/ssh/ssh_test.go index 487e94995b500633f8161b0afdf6eb393ee7c949..88f7914a5ba03532870d8015aace60c5341c7a50 100644 --- a/server/git/ssh_test.go +++ b/server/git/ssh/ssh_test.go @@ -1,4 +1,4 @@ -package git +package ssh import ( "fmt" @@ -11,6 +11,7 @@ import ( "testing" "github.com/charmbracelet/keygen" + "github.com/charmbracelet/soft-serve/server/git" "github.com/charmbracelet/wish" "github.com/gliderlabs/ssh" ) @@ -27,13 +28,13 @@ func TestGitMiddleware(t *testing.T) { pushes: []action{}, fetches: []action{}, access: []accessDetails{ - {pubkey, "repo1", AdminAccess}, - {pubkey, "repo2", AdminAccess}, - {pubkey, "repo3", AdminAccess}, - {pubkey, "repo4", AdminAccess}, - {pubkey, "repo5", NoAccess}, - {pubkey, "repo6", ReadOnlyAccess}, - {pubkey, "repo7", AdminAccess}, + {pubkey, "repo1", git.AdminAccess}, + {pubkey, "repo2", git.AdminAccess}, + {pubkey, "repo3", git.AdminAccess}, + {pubkey, "repo4", git.AdminAccess}, + {pubkey, "repo5", git.NoAccess}, + {pubkey, "repo6", git.ReadOnlyAccess}, + {pubkey, "repo7", git.AdminAccess}, }, } srv, err := wish.NewServer( @@ -179,7 +180,7 @@ func createKeyPair(t *testing.T) (ssh.PublicKey, string) { type accessDetails struct { key ssh.PublicKey repo string - level AccessLevel + level git.AccessLevel } type action struct { @@ -194,13 +195,13 @@ type testHooks struct { access []accessDetails } -func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) AccessLevel { +func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) git.AccessLevel { for _, dets := range h.access { if dets.repo == repo && ssh.KeysEqual(key, dets.key) { return dets.level } } - return NoAccess + return git.NoAccess } func (h *testHooks) Push(repo string, key ssh.PublicKey) { diff --git a/server/server.go b/server/server.go index d79782e110e7293eacccbc86bb9c62f038f1fd6d..bcbfcc55e643622f3da3e07aef6464e891c335d0 100644 --- a/server/server.go +++ b/server/server.go @@ -4,23 +4,25 @@ import ( "context" "fmt" "log" - "net" appCfg "github.com/charmbracelet/soft-serve/config" cm "github.com/charmbracelet/soft-serve/server/cmd" "github.com/charmbracelet/soft-serve/server/config" - gm "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/git/daemon" + gm "github.com/charmbracelet/soft-serve/server/git/ssh" "github.com/charmbracelet/wish" bm "github.com/charmbracelet/wish/bubbletea" lm "github.com/charmbracelet/wish/logging" rm "github.com/charmbracelet/wish/recover" "github.com/gliderlabs/ssh" "github.com/muesli/termenv" + "golang.org/x/sync/errgroup" ) // Server is the Soft Serve server. type Server struct { SSHServer *ssh.Server + GitServer *daemon.Daemon Config *config.Config config *appCfg.Config } @@ -58,40 +60,63 @@ func NewServer(cfg *config.Config) *Server { if err != nil { log.Fatalln(err) } + d, err := daemon.NewDaemon(cfg, ac) + if err != nil { + log.Fatalln(err) + } return &Server{ SSHServer: s, + GitServer: d, Config: cfg, config: ac, } } // Reload reloads the server configuration. -func (srv *Server) Reload() error { - return srv.config.Reload() +func (s *Server) Reload() error { + return s.config.Reload() } // Start starts the SSH server. -func (srv *Server) Start() error { - if err := srv.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed { - return err - } - return nil -} - -// Serve serves the SSH server using the provided listener. -func (srv *Server) Serve(l net.Listener) error { - if err := srv.SSHServer.Serve(l); err != ssh.ErrServerClosed { - return err - } - return nil +func (s *Server) Start() error { + var errg errgroup.Group + errg.Go(func() error { + log.Printf("Starting Git server on %s:%d", s.Config.BindAddr, s.Config.GitPort) + if err := s.GitServer.Start(); err != daemon.ErrServerClosed { + return err + } + return nil + }) + errg.Go(func() error { + log.Printf("Starting SSH server on %s:%d", s.Config.BindAddr, s.Config.Port) + if err := s.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed { + return err + } + return nil + }) + return errg.Wait() } // Shutdown lets the server gracefully shutdown. -func (srv *Server) Shutdown(ctx context.Context) error { - return srv.SSHServer.Shutdown(ctx) +func (s *Server) Shutdown(ctx context.Context) error { + var errg errgroup.Group + errg.Go(func() error { + return s.SSHServer.Shutdown(ctx) + }) + errg.Go(func() error { + return s.GitServer.Shutdown(ctx) + }) + return errg.Wait() } // Close closes the SSH server. -func (srv *Server) Close() error { - return srv.SSHServer.Close() +func (s *Server) Close() error { + var errg errgroup.Group + errg.Go(func() error { + return s.SSHServer.Close() + }) + errg.Go(func() error { + return s.GitServer.Close() + }) + return errg.Wait() }