1package main
2
3import (
4 "bufio"
5 "fmt"
6 "os"
7 "strings"
8
9 "github.com/charmbracelet/soft-serve/config"
10 "github.com/spf13/cobra"
11 gossh "golang.org/x/crypto/ssh"
12)
13
14var (
15 internalCmd = &cobra.Command{
16 Use: "internal",
17 Short: "Internal Soft Serve API",
18 Long: `Soft Serve internal API.
19This command is used to communicate with the Soft Serve SSH server.`,
20 Hidden: true,
21 }
22
23 hookCmd = &cobra.Command{
24 Use: "hook",
25 Short: "Run git server hooks",
26 Long: "Handles git server hooks. This includes pre-receive, update, and post-receive.",
27 }
28
29 preReceiveCmd = &cobra.Command{
30 Use: "pre-receive",
31 Short: "Run git pre-receive hook",
32 RunE: func(cmd *cobra.Command, args []string) error {
33 c, s, err := commonInit()
34 if err != nil {
35 return err
36 }
37 defer c.Close() //nolint:errcheck
38 defer s.Close() //nolint:errcheck
39 in, err := s.StdinPipe()
40 if err != nil {
41 return err
42 }
43 scanner := bufio.NewScanner(os.Stdin)
44 for scanner.Scan() {
45 in.Write([]byte(scanner.Text()))
46 in.Write([]byte("\n"))
47 }
48 in.Close() //nolint:errcheck
49 b, err := s.Output("internal hook pre-receive")
50 if err != nil {
51 return err
52 }
53 cmd.Print(string(b))
54 return nil
55 },
56 }
57
58 updateCmd = &cobra.Command{
59 Use: "update",
60 Short: "Run git update hook",
61 Args: cobra.ExactArgs(3),
62 RunE: func(cmd *cobra.Command, args []string) error {
63 refName := args[0]
64 oldSha := args[1]
65 newSha := args[2]
66 c, s, err := commonInit()
67 if err != nil {
68 return err
69 }
70 defer c.Close() //nolint:errcheck
71 defer s.Close() //nolint:errcheck
72 b, err := s.Output(fmt.Sprintf("internal hook update %s %s %s", refName, oldSha, newSha))
73 if err != nil {
74 return err
75 }
76 cmd.Print(string(b))
77 return nil
78 },
79 }
80
81 postReceiveCmd = &cobra.Command{
82 Use: "post-receive",
83 Short: "Run git post-receive hook",
84 RunE: func(cmd *cobra.Command, args []string) error {
85 c, s, err := commonInit()
86 if err != nil {
87 return err
88 }
89 defer c.Close() //nolint:errcheck
90 defer s.Close() //nolint:errcheck
91 in, err := s.StdinPipe()
92 if err != nil {
93 return err
94 }
95 scanner := bufio.NewScanner(os.Stdin)
96 for scanner.Scan() {
97 in.Write([]byte(scanner.Text()))
98 in.Write([]byte("\n"))
99 }
100 in.Close() //nolint:errcheck
101 b, err := s.Output("internal hook post-receive")
102 if err != nil {
103 return err
104 }
105 cmd.Print(string(b))
106 return nil
107 },
108 }
109)
110
111func init() {
112 hookCmd.AddCommand(
113 preReceiveCmd,
114 updateCmd,
115 postReceiveCmd,
116 )
117 internalCmd.AddCommand(
118 hookCmd,
119 )
120}
121
122func commonInit() (c *gossh.Client, s *gossh.Session, err error) {
123 cfg := config.DefaultConfig()
124 // Git runs the hook within the repository's directory.
125 // Get the working directory to determine the repository name.
126 wd, err := os.Getwd()
127 if err != nil {
128 return
129 }
130 if !strings.HasPrefix(wd, cfg.RepoPath) {
131 err = fmt.Errorf("hook must be run from within repository directory")
132 return
133 }
134 repoName := strings.TrimPrefix(wd, cfg.RepoPath)
135 repoName = strings.TrimPrefix(repoName, fmt.Sprintf("%c", os.PathSeparator))
136 c, err = newClient(cfg)
137 if err != nil {
138 return
139 }
140 s, err = newSession(c)
141 if err != nil {
142 return
143 }
144 s.Setenv("SOFT_SERVE_REPO_NAME", repoName)
145 return
146}
147
148func newClient(cfg *config.Config) (*gossh.Client, error) {
149 // Only accept the server's host key.
150 pubKey, err := os.ReadFile(cfg.KeyPath)
151 if err != nil {
152 return nil, err
153 }
154 hostKey, err := gossh.ParsePrivateKey(pubKey)
155 if err != nil {
156 return nil, err
157 }
158 pemKey, err := os.ReadFile(cfg.InternalKeyPath)
159 if err != nil {
160 return nil, err
161 }
162 k, err := gossh.ParsePrivateKey(pemKey)
163 if err != nil {
164 return nil, err
165 }
166 cc := &gossh.ClientConfig{
167 User: "internal",
168 Auth: []gossh.AuthMethod{
169 gossh.PublicKeys(k),
170 },
171 HostKeyCallback: gossh.FixedHostKey(hostKey.PublicKey()),
172 }
173 addr := fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.Port)
174 c, err := gossh.Dial("tcp", addr, cc)
175 if err != nil {
176 return nil, err
177 }
178 return c, nil
179}
180
181func newSession(c *gossh.Client) (*gossh.Session, error) {
182 s, err := c.NewSession()
183 if err != nil {
184 return nil, err
185 }
186 return s, nil
187}