1package git
2
3import (
4 "bufio"
5 "context"
6 "fmt"
7 "log"
8 "os"
9 "os/exec"
10 "smoothie/server/middleware"
11 "strings"
12
13 "github.com/gliderlabs/ssh"
14)
15
16func Middleware(repoDir, authorizedKeys, authorizedKeysFile string) middleware.Middleware {
17 authedKeys := make([]ssh.PublicKey, 0)
18 hasAuth, err := fileExists(authorizedKeysFile)
19 if err != nil {
20 log.Fatal(err)
21 }
22 if hasAuth || authorizedKeys != "" {
23 var scanner *bufio.Scanner
24 if authorizedKeys == "" {
25 log.Printf("Importing authorized keys from file: %s", authorizedKeysFile)
26 f, err := os.Open(authorizedKeysFile)
27 if err != nil {
28 log.Fatal(err)
29 }
30 defer f.Close()
31 scanner = bufio.NewScanner(f)
32 } else {
33 log.Printf("Importing authorized keys from environment")
34 scanner = bufio.NewScanner(strings.NewReader(authorizedKeys))
35 }
36 for scanner.Scan() {
37 pt := scanner.Text()
38 log.Printf("Adding authorized key: %s", pt)
39 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pt))
40 if err != nil {
41 log.Fatal(err)
42 }
43 authedKeys = append(authedKeys, pk)
44 }
45 if err := scanner.Err(); err != nil {
46 log.Fatal(err)
47 }
48 }
49 return func(sh ssh.Handler) ssh.Handler {
50 return func(s ssh.Session) {
51 cmd := s.Command()
52 if len(cmd) == 2 {
53 switch cmd[0] {
54 case "git-upload-pack", "git-upload-archive", "git-receive-pack":
55 if hasAuth && cmd[0] == "git-receive-pack" {
56 authed := false
57 for _, pk := range authedKeys {
58 if ssh.KeysEqual(pk, s.PublicKey()) {
59 authed = true
60 }
61 }
62 if !authed {
63 fatalGit(s, fmt.Errorf("you are not authorized to do this"))
64 break
65 }
66 }
67 r := cmd[1]
68 rp := fmt.Sprintf("%s%s", repoDir, r)
69 ctx := s.Context()
70 err := ensureRepo(ctx, repoDir, r)
71 if err != nil {
72 fatalGit(s, err)
73 break
74 }
75 c := exec.CommandContext(ctx, cmd[0], rp)
76 c.Dir = "./"
77 c.Stdout = s
78 c.Stdin = s
79 err = c.Run()
80 if err != nil {
81 fatalGit(s, err)
82 break
83 }
84 }
85 }
86 sh(s)
87 }
88 }
89}
90
// fileExists reports whether path exists. A path that does not exist yields
// (false, nil). Any other Stat failure (for example permission denied or an
// invalid path) is returned as the error, paired with false because existence
// could not be determined.
func fileExists(path string) (bool, error) {
	_, err := os.Stat(path)
	switch {
	case err == nil:
		return true, nil
	case os.IsNotExist(err):
		return false, nil
	default:
		return false, err
	}
}
101
102func fatalGit(s ssh.Session, err error) {
103 // hex length includes 4 byte length prefix and ending newline
104 msg := err.Error()
105 pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
106 _, _ = s.Write([]byte(pktLine))
107 s.Exit(1)
108}
109
110func ensureRepo(ctx context.Context, dir string, repo string) error {
111 exists, err := fileExists(dir)
112 if err != nil {
113 return err
114 }
115 if !exists {
116 err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
117 if err != nil {
118 return err
119 }
120 }
121 rp := fmt.Sprintf("%s%s", dir, repo)
122 exists, err = fileExists(rp)
123 if err != nil {
124 return err
125 }
126 if !exists {
127 c := exec.CommandContext(ctx, "git", "init", "--bare", rp)
128 err = c.Run()
129 if err != nil {
130 return err
131 }
132 }
133 return nil
134}