git.go

  1package git
  2
  3import (
  4	"bufio"
  5	"context"
  6	"fmt"
  7	"log"
  8	"os"
  9	"os/exec"
 10	"smoothie/server/middleware"
 11	"strings"
 12
 13	"github.com/gliderlabs/ssh"
 14)
 15
 16func gitMiddleware(repoDir string, authedKeys []ssh.PublicKey) middleware.Middleware {
 17	return func(sh ssh.Handler) ssh.Handler {
 18		return func(s ssh.Session) {
 19			cmd := s.Command()
 20			if len(cmd) == 2 {
 21				switch cmd[0] {
 22				case "git-upload-pack", "git-upload-archive", "git-receive-pack":
 23					if len(authedKeys) > 0 && cmd[0] == "git-receive-pack" {
 24						authed := false
 25						for _, pk := range authedKeys {
 26							if ssh.KeysEqual(pk, s.PublicKey()) {
 27								authed = true
 28							}
 29						}
 30						if !authed {
 31							fatalGit(s, fmt.Errorf("you are not authorized to do this"))
 32							break
 33						}
 34					}
 35					r := cmd[1]
 36					rp := fmt.Sprintf("%s%s", repoDir, r)
 37					ctx := s.Context()
 38					err := ensureRepo(ctx, repoDir, r)
 39					if err != nil {
 40						fatalGit(s, err)
 41						break
 42					}
 43					c := exec.CommandContext(ctx, cmd[0], rp)
 44					c.Dir = "./"
 45					c.Stdout = s
 46					c.Stdin = s
 47					err = c.Run()
 48					if err != nil {
 49						fatalGit(s, err)
 50						break
 51					}
 52				}
 53			}
 54			sh(s)
 55		}
 56	}
 57}
 58
 59func Middleware(repoDir, authorizedKeys, authorizedKeysFile string) middleware.Middleware {
 60	ak1, err := parseKeysFromString(authorizedKeys)
 61	if err != nil {
 62		log.Fatal(err)
 63	}
 64	ak2, err := parseKeysFromFile(authorizedKeysFile)
 65	if err != nil {
 66		log.Fatal(err)
 67	}
 68	authedKeys := append(ak1, ak2...)
 69	return gitMiddleware(repoDir, authedKeys)
 70}
 71
 72func MiddlewareWithKeys(repoDir, authorizedKeys string) middleware.Middleware {
 73	return Middleware(repoDir, authorizedKeys, "")
 74}
 75
 76func MiddlewareWithKeyPath(repoDir, authorizedKeysFile string) middleware.Middleware {
 77	return Middleware(repoDir, "", authorizedKeysFile)
 78}
 79
 80func parseKeysFromFile(path string) ([]ssh.PublicKey, error) {
 81	authedKeys := make([]ssh.PublicKey, 0)
 82	hasAuth, err := fileExists(path)
 83	if err != nil {
 84		return nil, err
 85	}
 86	if hasAuth {
 87		f, err := os.Open(path)
 88		if err != nil {
 89			log.Fatal(err)
 90		}
 91		defer f.Close()
 92		scanner := bufio.NewScanner(f)
 93		err = addKeys(scanner, &authedKeys)
 94		if err != nil {
 95			return nil, err
 96		}
 97	}
 98	return authedKeys, nil
 99}
100
101func parseKeysFromString(keys string) ([]ssh.PublicKey, error) {
102	authedKeys := make([]ssh.PublicKey, 0)
103	scanner := bufio.NewScanner(strings.NewReader(keys))
104	err := addKeys(scanner, &authedKeys)
105	if err != nil {
106		return nil, err
107	}
108	return authedKeys, nil
109}
110
111func addKeys(s *bufio.Scanner, keys *[]ssh.PublicKey) error {
112	for s.Scan() {
113		pt := s.Text()
114		log.Printf("Adding authorized key: %s", pt)
115		pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pt))
116		if err != nil {
117			return err
118		}
119		*keys = append(*keys, pk)
120	}
121	if err := s.Err(); err != nil {
122		return err
123	}
124	return nil
125}
126
// fileExists reports whether os.Stat succeeds on path.
//
// A "does not exist" error yields (false, nil). Any other stat error is
// returned as (true, err), so callers must check err before trusting the
// boolean.
func fileExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return true, statErr
	}
}
137
138func fatalGit(s ssh.Session, err error) {
139	// hex length includes 4 byte length prefix and ending newline
140	msg := err.Error()
141	pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
142	_, _ = s.Write([]byte(pktLine))
143	s.Exit(1)
144}
145
// ensureRepo makes sure the repository root dir exists, creating it (and any
// parents) with mode 0700, and that the path dir+repo exists. If that path
// does not exist, it runs `git init --bare` on it.
//
// repo is appended to dir verbatim (no separator is inserted). ctx bounds the
// git subprocess. When git init fails, its output is logged on the server, and
// the returned error names only repo, since the error may be shown to the client.
func ensureRepo(ctx context.Context, dir string, repo string) error {
	// MkdirAll is a no-op when dir already exists as a directory.
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("creating repository root: %w", err)
	}
	rp := fmt.Sprintf("%s%s", dir, repo)
	_, err := os.Stat(rp)
	if err == nil {
		return nil
	}
	if !os.IsNotExist(err) {
		return fmt.Errorf("checking repository %q: %w", repo, err)
	}
	out, err := exec.CommandContext(ctx, "git", "init", "--bare", rp).CombinedOutput()
	if err != nil {
		// git's output can include server-side paths: keep it in the log.
		log.Printf("git init --bare %s: %v: %s", rp, err, strings.TrimSpace(string(out)))
		return fmt.Errorf("initializing repository %q: %w", repo, err)
	}
	return nil
}