@@ -0,0 +1,247 @@
+package git
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "github.com/charmbracelet/wish"
+ "github.com/gliderlabs/ssh"
+ "github.com/go-git/go-git/v5"
+ "github.com/go-git/go-git/v5/plumbing"
+)
+
+// ErrNotAuthed represents unauthorized access.
+var ErrNotAuthed = errors.New("you are not authorized to do this")
+
+// ErrSystemMalfunction represents a general system error returned to clients.
+var ErrSystemMalfunction = errors.New("something went wrong")
+
+// ErrInvalidRepo represents an attempt to access a non-existent repo.
+var ErrInvalidRepo = errors.New("invalid repo")
+
+// AccessLevel is the level of access allowed to a repo.
+type AccessLevel int
+
+const (
+ // NoAccess does not allow access to the repo.
+ NoAccess AccessLevel = iota
+
+ // ReadOnlyAccess allows read-only access to the repo.
+ ReadOnlyAccess
+
+ // ReadWriteAccess allows read and write access to the repo.
+ ReadWriteAccess
+
+ // AdminAccess allows read, write, and admin access to the repo.
+ AdminAccess
+)
+
+// String implements the Stringer interface for AccessLevel.
+func (a AccessLevel) String() string {
+ switch a {
+ case NoAccess:
+ return "no-access"
+ case ReadOnlyAccess:
+ return "read-only"
+ case ReadWriteAccess:
+ return "read-write"
+ case AdminAccess:
+ return "admin-access"
+ default:
+ return ""
+ }
+}
+
+// Hooks is an interface that allows for custom authorization
+// implementations and post push/fetch notifications. Prior to git access,
+// AuthRepo will be called with the ssh.Session public key and the repo name.
+// Implementers return the appropriate AccessLevel.
+type Hooks interface {
+ AuthRepo(string, ssh.PublicKey) AccessLevel
+ Push(string, ssh.PublicKey)
+ Fetch(string, ssh.PublicKey)
+}
+
+// Middleware adds Git server functionality to the ssh.Server. Repos are stored
+// in the specified repo directory. The provided Hooks implementation will be
+// checked for access on a per repo basis for a ssh.Session public key.
+// Hooks.Push and Hooks.Fetch will be called on successful completion of
+// their commands.
+func Middleware(repoDir string, gh Hooks) wish.Middleware {
+ return func(sh ssh.Handler) ssh.Handler {
+ return func(s ssh.Session) {
+ cmd := s.Command()
+ if len(cmd) == 2 {
+ gc := cmd[0]
+ // repo should be in the form of "repo.git"
+ repo := strings.TrimPrefix(cmd[1], "/")
+ repo = filepath.Clean(repo)
+ if strings.Contains(repo, "/") {
+ Fatal(s, fmt.Errorf("%s: %s", ErrInvalidRepo, "user repos not supported"))
+ }
+ // git bare repositories should end in ".git"
+ // https://git-scm.com/docs/gitrepository-layout
+ if !strings.HasSuffix(repo, ".git") {
+ repo += ".git"
+ }
+ pk := s.PublicKey()
+ access := gh.AuthRepo(repo, pk)
+ switch gc {
+ case "git-receive-pack":
+ switch access {
+ case ReadWriteAccess, AdminAccess:
+ err := gitPack(s, gc, repoDir, repo)
+ if err != nil {
+ Fatal(s, ErrSystemMalfunction)
+ } else {
+ gh.Push(repo, pk)
+ }
+ default:
+ Fatal(s, ErrNotAuthed)
+ }
+ return
+ case "git-upload-archive", "git-upload-pack":
+ switch access {
+ case ReadOnlyAccess, ReadWriteAccess, AdminAccess:
+ err := gitPack(s, gc, repoDir, repo)
+ switch err {
+ case ErrInvalidRepo:
+ Fatal(s, ErrInvalidRepo)
+ case nil:
+ gh.Fetch(repo, pk)
+ default:
+ log.Printf("unknown git error: %s", err)
+ Fatal(s, ErrSystemMalfunction)
+ }
+ default:
+ Fatal(s, ErrNotAuthed)
+ }
+ return
+ }
+ }
+ sh(s)
+ }
+ }
+}
+
+func gitPack(s ssh.Session, gitCmd string, repoDir string, repo string) error {
+ cmd := strings.TrimPrefix(gitCmd, "git-")
+ rp := filepath.Join(repoDir, repo)
+ switch gitCmd {
+ case "git-upload-archive", "git-upload-pack":
+ exists, err := fileExists(rp)
+ if !exists {
+ return ErrInvalidRepo
+ }
+ if err != nil {
+ return err
+ }
+ return runGit(s, "", cmd, rp)
+ case "git-receive-pack":
+ err := ensureRepo(repoDir, repo)
+ if err != nil {
+ return err
+ }
+ err = runGit(s, "", cmd, rp)
+ if err != nil {
+ return err
+ }
+ err = ensureDefaultBranch(s, rp)
+ if err != nil {
+ return err
+ }
+ // Needed for git dumb http server
+ return runGit(s, rp, "update-server-info")
+ default:
+ return fmt.Errorf("unknown git command: %s", gitCmd)
+ }
+}
+
+func fileExists(path string) (bool, error) {
+ _, err := os.Stat(path)
+ if err == nil {
+ return true, nil
+ }
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return true, err
+}
+
+// Fatal prints to the session's STDOUT as a git response and exit 1.
+func Fatal(s ssh.Session, v ...interface{}) {
+ msg := fmt.Sprint(v...)
+ // hex length includes 4 byte length prefix and ending newline
+ pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
+ _, _ = wish.WriteString(s, pktLine)
+ s.Exit(1) // nolint: errcheck
+}
+
+func ensureRepo(dir string, repo string) error {
+ exists, err := fileExists(dir)
+ if err != nil {
+ return err
+ }
+ if !exists {
+ err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
+ if err != nil {
+ return err
+ }
+ }
+ rp := filepath.Join(dir, repo)
+ exists, err = fileExists(rp)
+ if err != nil {
+ return err
+ }
+ if !exists {
+ _, err := git.PlainInit(rp, true)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func runGit(s ssh.Session, dir string, args ...string) error {
+ usi := exec.CommandContext(s.Context(), "git", args...)
+ usi.Dir = dir
+ usi.Stdout = s
+ usi.Stdin = s
+ if err := usi.Run(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func ensureDefaultBranch(s ssh.Session, repoPath string) error {
+ r, err := git.PlainOpen(repoPath)
+ if err != nil {
+ return err
+ }
+ brs, err := r.Branches()
+ if err != nil {
+ return err
+ }
+ defer brs.Close()
+ fb, err := brs.Next()
+ if err != nil {
+ return err
+ }
+ // Rename the default branch to the first branch available
+ _, err = r.Head()
+ if err == plumbing.ErrReferenceNotFound {
+ err = runGit(s, repoPath, "branch", "-M", fb.Name().Short())
+ if err != nil {
+ return err
+ }
+ }
+ if err != nil && err != plumbing.ErrReferenceNotFound {
+ return err
+ }
+ return nil
+}
@@ -0,0 +1,227 @@
+package git
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/charmbracelet/keygen"
+ "github.com/charmbracelet/wish"
+ "github.com/gliderlabs/ssh"
+)
+
+func TestGitMiddleware(t *testing.T) {
+ pubkey, pkPath := createKeyPair(t)
+
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ requireNoError(t, err)
+ remote := "ssh://" + l.Addr().String()
+
+ repoDir := t.TempDir()
+ hooks := &testHooks{
+ pushes: []action{},
+ fetches: []action{},
+ access: []accessDetails{
+ {pubkey, "repo1", AdminAccess},
+ {pubkey, "repo2", AdminAccess},
+ {pubkey, "repo3", AdminAccess},
+ {pubkey, "repo4", AdminAccess},
+ {pubkey, "repo5", NoAccess},
+ {pubkey, "repo6", ReadOnlyAccess},
+ {pubkey, "repo7", AdminAccess},
+ },
+ }
+ srv, err := wish.NewServer(
+ wish.WithMiddleware(Middleware(repoDir, hooks)),
+ wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
+ return true
+ }),
+ )
+ requireNoError(t, err)
+ go func() { srv.Serve(l) }()
+ t.Cleanup(func() { _ = srv.Close() })
+
+ t.Run("create repo on master", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "master"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo1"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "master"))
+ requireHasAction(t, hooks.pushes, pubkey, "repo1")
+ })
+
+ t.Run("create repo on main", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo2"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
+ requireHasAction(t, hooks.pushes, pubkey, "repo2")
+ })
+
+ t.Run("create and clone repo", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo3"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
+
+ cwd = t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo3"))
+
+ requireHasAction(t, hooks.pushes, pubkey, "repo3")
+ requireHasAction(t, hooks.fetches, pubkey, "repo3")
+ })
+
+ t.Run("clone repo that doesn't exist", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo4"))
+ })
+
+ t.Run("clone repo with no access", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo5"))
+ })
+
+ t.Run("push repo with with readonly", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo6"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
+ requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
+ })
+
+ t.Run("create and clone repo on weird branch", func(t *testing.T) {
+ cwd := t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "a-weird-branch-name"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo7"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "a-weird-branch-name"))
+
+ cwd = t.TempDir()
+ requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo7"))
+
+ requireHasAction(t, hooks.pushes, pubkey, "repo7")
+ requireHasAction(t, hooks.fetches, pubkey, "repo7")
+ })
+}
+
+func runGitHelper(t *testing.T, pk, cwd string, args ...string) error {
+ t.Helper()
+
+ allArgs := []string{
+ "-c", "user.name='wish'",
+ "-c", "user.email='test@wish'",
+ "-c", "commit.gpgSign=false",
+ "-c", "tag.gpgSign=false",
+ "-c", "log.showSignature=false",
+ "-c", "ssh.variant=ssh",
+ }
+ allArgs = append(allArgs, args...)
+
+ cmd := exec.Command("git", allArgs...)
+ cmd.Dir = cwd
+ cmd.Env = []string{fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -i %s -F /dev/null`, pk)}
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Log("git out:", string(out))
+ }
+ return err
+}
+
+func requireNoError(t *testing.T, err error) {
+ t.Helper()
+
+ if err != nil {
+ t.Fatalf("expected no error, got %q", err.Error())
+ }
+}
+
+func requireError(t *testing.T, err error) {
+ t.Helper()
+
+ if err == nil {
+ t.Fatalf("expected an error, got nil")
+ }
+}
+
+func requireHasAction(t *testing.T, actions []action, key ssh.PublicKey, repo string) {
+ t.Helper()
+
+ for _, action := range actions {
+ r1 := repo
+ if !strings.HasSuffix(r1, ".git") {
+ r1 += ".git"
+ }
+ r2 := action.repo
+ if !strings.HasSuffix(r2, ".git") {
+ r2 += ".git"
+ }
+ if r1 == r2 && ssh.KeysEqual(key, action.key) {
+ return
+ }
+ }
+ t.Fatalf("expected action for %q, got none", repo)
+}
+
+func createKeyPair(t *testing.T) (ssh.PublicKey, string) {
+ t.Helper()
+
+ keyDir := t.TempDir()
+ _, err := keygen.NewWithWrite(filepath.Join(keyDir, "id"), nil, keygen.Ed25519)
+ requireNoError(t, err)
+ pk := filepath.Join(keyDir, "id_ed25519")
+ pubBytes, err := os.ReadFile(filepath.Join(keyDir, "id_ed25519.pub"))
+ requireNoError(t, err)
+ pubkey, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes)
+ requireNoError(t, err)
+ return pubkey, pk
+}
+
+type accessDetails struct {
+ key ssh.PublicKey
+ repo string
+ level AccessLevel
+}
+
+type action struct {
+ key ssh.PublicKey
+ repo string
+}
+
+type testHooks struct {
+ sync.Mutex
+ pushes []action
+ fetches []action
+ access []accessDetails
+}
+
+func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) AccessLevel {
+ for _, dets := range h.access {
+ r1 := strings.TrimSuffix(dets.repo, ".git")
+ r2 := strings.TrimSuffix(repo, ".git")
+ if r1 == r2 && ssh.KeysEqual(key, dets.key) {
+ return dets.level
+ }
+ }
+ return NoAccess
+}
+
+func (h *testHooks) Push(repo string, key ssh.PublicKey) {
+ h.Lock()
+ defer h.Unlock()
+
+ h.pushes = append(h.pushes, action{key, repo})
+}
+
+func (h *testHooks) Fetch(repo string, key ssh.PublicKey) {
+ h.Lock()
+ defer h.Unlock()
+
+ h.fetches = append(h.fetches, action{key, repo})
+}