git.go

  1package git
  2
import (
	"bufio"
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"strings"

	"github.com/gliderlabs/ssh"

	"smoothie/server/middleware"
)
 14
 15func Middleware(repoDir string, authorizedKeysPath string) middleware.Middleware {
 16	authedKeys := make([]ssh.PublicKey, 0)
 17	hasAuth, err := fileExists(authorizedKeysPath)
 18	if err != nil {
 19		log.Fatal(err)
 20	}
 21	if hasAuth {
 22		f, err := os.Open(authorizedKeysPath)
 23		if err != nil {
 24			log.Fatal(err)
 25		}
 26		defer f.Close()
 27		scanner := bufio.NewScanner(f)
 28		for scanner.Scan() {
 29			pt := scanner.Text()
 30			log.Printf("Adding authorized key: %s", pt)
 31			pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pt))
 32			if err != nil {
 33				log.Fatal(err)
 34			}
 35			authedKeys = append(authedKeys, pk)
 36		}
 37		if err := scanner.Err(); err != nil {
 38			log.Fatal(err)
 39		}
 40	}
 41	return func(sh ssh.Handler) ssh.Handler {
 42		return func(s ssh.Session) {
 43			cmd := s.Command()
 44			if len(cmd) == 2 {
 45				switch cmd[0] {
 46				case "git-upload-pack", "git-upload-archive", "git-receive-pack":
 47					if hasAuth && cmd[0] == "git-receive-pack" {
 48						authed := false
 49						for _, pk := range authedKeys {
 50							if ssh.KeysEqual(pk, s.PublicKey()) {
 51								authed = true
 52							}
 53						}
 54						if !authed {
 55							fatalGit(s, fmt.Errorf("you are not authorized to do this"))
 56							break
 57						}
 58					}
 59					r := cmd[1]
 60					rp := fmt.Sprintf("%s%s", repoDir, r)
 61					ctx := s.Context()
 62					err := ensureRepo(ctx, repoDir, r)
 63					if err != nil {
 64						fatalGit(s, err)
 65						break
 66					}
 67					c := exec.CommandContext(ctx, cmd[0], rp)
 68					c.Dir = "./"
 69					c.Stdout = s
 70					c.Stdin = s
 71					err = c.Run()
 72					if err != nil {
 73						fatalGit(s, err)
 74						break
 75					}
 76				}
 77			}
 78			sh(s)
 79		}
 80	}
 81}
 82
 83func fileExists(path string) (bool, error) {
 84	_, err := os.Stat(path)
 85	if err == nil {
 86		return true, nil
 87	}
 88	if os.IsNotExist(err) {
 89		return false, nil
 90	}
 91	return true, err
 92}
 93
 94func fatalGit(s ssh.Session, err error) {
 95	// hex length includes 4 byte length prefix and ending newline
 96	msg := err.Error()
 97	pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
 98	_, _ = s.Write([]byte(pktLine))
 99	s.Exit(1)
100}
101
102func ensureRepo(ctx context.Context, dir string, repo string) error {
103	exists, err := fileExists(dir)
104	if err != nil {
105		return err
106	}
107	if !exists {
108		err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
109		if err != nil {
110			return err
111		}
112	}
113	rp := fmt.Sprintf("%s%s", dir, repo)
114	exists, err = fileExists(rp)
115	if err != nil {
116		return err
117	}
118	if !exists {
119		c := exec.CommandContext(ctx, "git", "init", "--bare", rp)
120		err = c.Run()
121		if err != nil {
122			return err
123		}
124	}
125	return nil
126}