git.go

  1package git
  2
  3import (
  4	"bufio"
  5	"context"
  6	"fmt"
  7	"log"
  8	"os"
  9	"os/exec"
 10	"smoothie/server/middleware"
 11	"strings"
 12
 13	"github.com/gliderlabs/ssh"
 14)
 15
 16func Middleware(repoDir, authorizedKeys, authorizedKeysFile string) middleware.Middleware {
 17	authedKeys := make([]ssh.PublicKey, 0)
 18	hasAuth, err := fileExists(authorizedKeysFile)
 19	if err != nil {
 20		log.Fatal(err)
 21	}
 22	if hasAuth || authorizedKeys != "" {
 23		var scanner *bufio.Scanner
 24		if authorizedKeys == "" {
 25			log.Printf("Importing authorized keys from file: %s", authorizedKeysFile)
 26			f, err := os.Open(authorizedKeysFile)
 27			if err != nil {
 28				log.Fatal(err)
 29			}
 30			defer f.Close()
 31			scanner = bufio.NewScanner(f)
 32		} else {
 33			log.Printf("Importing authorized keys from environment")
 34			scanner = bufio.NewScanner(strings.NewReader(authorizedKeys))
 35		}
 36		for scanner.Scan() {
 37			pt := scanner.Text()
 38			log.Printf("Adding authorized key: %s", pt)
 39			pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pt))
 40			if err != nil {
 41				log.Fatal(err)
 42			}
 43			authedKeys = append(authedKeys, pk)
 44		}
 45		if err := scanner.Err(); err != nil {
 46			log.Fatal(err)
 47		}
 48	}
 49	return func(sh ssh.Handler) ssh.Handler {
 50		return func(s ssh.Session) {
 51			cmd := s.Command()
 52			if len(cmd) == 2 {
 53				switch cmd[0] {
 54				case "git-upload-pack", "git-upload-archive", "git-receive-pack":
 55					if hasAuth && cmd[0] == "git-receive-pack" {
 56						authed := false
 57						for _, pk := range authedKeys {
 58							if ssh.KeysEqual(pk, s.PublicKey()) {
 59								authed = true
 60							}
 61						}
 62						if !authed {
 63							fatalGit(s, fmt.Errorf("you are not authorized to do this"))
 64							break
 65						}
 66					}
 67					r := cmd[1]
 68					rp := fmt.Sprintf("%s%s", repoDir, r)
 69					ctx := s.Context()
 70					err := ensureRepo(ctx, repoDir, r)
 71					if err != nil {
 72						fatalGit(s, err)
 73						break
 74					}
 75					c := exec.CommandContext(ctx, cmd[0], rp)
 76					c.Dir = "./"
 77					c.Stdout = s
 78					c.Stdin = s
 79					err = c.Run()
 80					if err != nil {
 81						fatalGit(s, err)
 82						break
 83					}
 84				}
 85			}
 86			sh(s)
 87		}
 88	}
 89}
 90
 91func fileExists(path string) (bool, error) {
 92	_, err := os.Stat(path)
 93	if err == nil {
 94		return true, nil
 95	}
 96	if os.IsNotExist(err) {
 97		return false, nil
 98	}
 99	return true, err
100}
101
102func fatalGit(s ssh.Session, err error) {
103	// hex length includes 4 byte length prefix and ending newline
104	msg := err.Error()
105	pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
106	_, _ = s.Write([]byte(pktLine))
107	s.Exit(1)
108}
109
110func ensureRepo(ctx context.Context, dir string, repo string) error {
111	exists, err := fileExists(dir)
112	if err != nil {
113		return err
114	}
115	if !exists {
116		err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
117		if err != nil {
118			return err
119		}
120	}
121	rp := fmt.Sprintf("%s%s", dir, repo)
122	exists, err = fileExists(rp)
123	if err != nil {
124		return err
125	}
126	if !exists {
127		c := exec.CommandContext(ctx, "git", "init", "--bare", rp)
128		err = c.Run()
129		if err != nil {
130			return err
131		}
132	}
133	return nil
134}