script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/log"
 18	"github.com/charmbracelet/soft-serve/server"
 19	"github.com/charmbracelet/soft-serve/server/backend"
 20	"github.com/charmbracelet/soft-serve/server/config"
 21	"github.com/charmbracelet/soft-serve/server/db"
 22	"github.com/charmbracelet/soft-serve/server/db/migrate"
 23	logr "github.com/charmbracelet/soft-serve/server/log"
 24	"github.com/charmbracelet/soft-serve/server/store"
 25	"github.com/charmbracelet/soft-serve/server/store/database"
 26	"github.com/charmbracelet/soft-serve/server/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"golang.org/x/crypto/ssh"
 29	_ "modernc.org/sqlite" // sqlite Driver
 30)
 31
 32var update = flag.Bool("update", false, "update script files")
 33
 34func TestScript(t *testing.T) {
 35	flag.Parse()
 36	var lock sync.Mutex
 37
 38	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 39		path := filepath.Join(t.TempDir(), name)
 40		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 41		if err != nil {
 42			t.Fatal(err)
 43		}
 44		return path, pair
 45	}
 46
 47	key, admin1 := mkkey("admin1")
 48	_, admin2 := mkkey("admin2")
 49	_, user1 := mkkey("user1")
 50
 51	testscript.Run(t, testscript.Params{
 52		Dir:           "./testdata/",
 53		UpdateScripts: *update,
 54		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 55			"soft":     cmdSoft(admin1.Signer()),
 56			"usoft":    cmdSoft(user1.Signer()),
 57			"git":      cmdGit(key),
 58			"mkfile":   cmdMkfile,
 59			"envfile":  cmdEnvfile,
 60			"readfile": cmdReadfile,
 61			"dos2unix": cmdDos2Unix,
 62		},
 63		Setup: func(e *testscript.Env) error {
 64			data := t.TempDir()
 65
 66			sshPort := test.RandomPort()
 67			sshListen := fmt.Sprintf("localhost:%d", sshPort)
 68			gitPort := test.RandomPort()
 69			gitListen := fmt.Sprintf("localhost:%d", gitPort)
 70			httpPort := test.RandomPort()
 71			httpListen := fmt.Sprintf("localhost:%d", httpPort)
 72			statsPort := test.RandomPort()
 73			statsListen := fmt.Sprintf("localhost:%d", statsPort)
 74			serverName := "Test Soft Serve"
 75
 76			e.Setenv("DATA_PATH", data)
 77			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 78			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
 79			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 80			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 81			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 82			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 83			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 84
 85			cfg := config.DefaultConfig()
 86			cfg.DataPath = data
 87			cfg.Name = serverName
 88			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
 89			cfg.SSH.ListenAddr = sshListen
 90			cfg.SSH.PublicURL = "ssh://" + sshListen
 91			cfg.Git.ListenAddr = gitListen
 92			cfg.HTTP.ListenAddr = httpListen
 93			cfg.HTTP.PublicURL = "http://" + httpListen
 94			cfg.Stats.ListenAddr = statsListen
 95			cfg.DB.Driver = "sqlite"
 96			cfg.LFS.Enabled = true
 97			// TODO: run tests with both SSH enabled/disabled
 98			cfg.LFS.SSHEnabled = false
 99
100			if err := cfg.Validate(); err != nil {
101				return err
102			}
103
104			ctx := config.WithContext(context.Background(), cfg)
105
106			logger, f, err := logr.NewLogger(cfg)
107			if err != nil {
108				log.Errorf("failed to create logger: %v", err)
109			}
110
111			ctx = log.WithContext(ctx, logger)
112			if f != nil {
113				defer f.Close() // nolint: errcheck
114			}
115
116			// TODO: test postgres
117			dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
118			if err != nil {
119				return fmt.Errorf("open database: %w", err)
120			}
121
122			if err := migrate.Migrate(ctx, dbx); err != nil {
123				return fmt.Errorf("migrate database: %w", err)
124			}
125
126			ctx = db.WithContext(ctx, dbx)
127			datastore := database.New(ctx, dbx)
128			ctx = store.WithContext(ctx, datastore)
129			be := backend.New(ctx, cfg, dbx)
130			ctx = backend.WithContext(ctx, be)
131
132			// prevent race condition in lipgloss...
133			// this will probably be autofixed when we start using the colors
134			// from the ssh session instead of the server.
135			// XXX: take another look at this soon
136			lock.Lock()
137			srv, err := server.NewServer(ctx)
138			if err != nil {
139				return err
140			}
141			lock.Unlock()
142
143			go func() {
144				if err := srv.Start(); err != nil {
145					e.T().Fatal(err)
146				}
147			}()
148
149			e.Defer(func() {
150				defer dbx.Close() // nolint: errcheck
151				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
152				defer cancel()
153				if err := srv.Shutdown(ctx); err != nil {
154					e.T().Fatal(err)
155				}
156			})
157
158			// wait until the server is up
159			for {
160				conn, _ := net.DialTimeout(
161					"tcp",
162					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
163					time.Second,
164				)
165				if conn != nil {
166					conn.Close()
167					break
168				}
169			}
170
171			return nil
172		},
173	})
174}
175
176func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
177	return func(ts *testscript.TestScript, neg bool, args []string) {
178		cli, err := ssh.Dial(
179			"tcp",
180			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
181			&ssh.ClientConfig{
182				User:            "admin",
183				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
184				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
185			},
186		)
187		ts.Check(err)
188		defer cli.Close()
189
190		sess, err := cli.NewSession()
191		ts.Check(err)
192		defer sess.Close()
193
194		sess.Stdout = ts.Stdout()
195		sess.Stderr = ts.Stderr()
196
197		check(ts, sess.Run(strings.Join(args, " ")), neg)
198	}
199}
200
201// P.S. Windows sucks!
202func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
203	if neg {
204		ts.Fatalf("unsupported: ! dos2unix")
205	}
206	if len(args) < 1 {
207		ts.Fatalf("usage: dos2unix paths...")
208	}
209	for _, arg := range args {
210		filename := ts.MkAbs(arg)
211		data, err := os.ReadFile(filename)
212		if err != nil {
213			ts.Fatalf("%s: %v", filename, err)
214		}
215
216		// Replace all '\r\n' with '\n'.
217		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
218
219		if err := os.WriteFile(filename, data, 0o644); err != nil {
220			ts.Fatalf("%s: %v", filename, err)
221		}
222	}
223}
224
// sshConfig is the OpenSSH client configuration template that the "git"
// command writes before every git invocation.
//
// Its single %q verb receives the known_hosts file path. Unknown host keys
// are accepted without prompting, and the SSH agent is disabled so that only
// the identity passed with -i is offered. It is a constant because it is a
// fixed format string that nothing should modify.
const sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
233
234func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
235	return func(ts *testscript.TestScript, neg bool, args []string) {
236		ts.Check(os.WriteFile(
237			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
238			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
239			0o600,
240		))
241		sshArgs := []string{
242			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
243			"-i", filepath.ToSlash(key),
244		}
245		ts.Setenv(
246			"GIT_SSH_COMMAND",
247			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
248		)
249		// Disable git prompting for credentials.
250		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
251		args = append([]string{
252			"-c", "user.email=john@example.com",
253			"-c", "user.name=John Doe",
254		}, args...)
255		check(ts, ts.Exec("git", args...), neg)
256	}
257}
258
259func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
260	if len(args) < 2 {
261		ts.Fatalf("usage: mkfile path content")
262	}
263	check(ts, os.WriteFile(
264		ts.MkAbs(args[0]),
265		[]byte(strings.Join(args[1:], " ")),
266		0o644,
267	), neg)
268}
269
270func check(ts *testscript.TestScript, err error, neg bool) {
271	if neg && err == nil {
272		ts.Fatalf("expected error, got nil")
273	}
274	if !neg {
275		ts.Check(err)
276	}
277}
278
279func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
280	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
281}
282
283func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
284	if len(args) < 1 {
285		ts.Fatalf("usage: envfile key=file...")
286	}
287
288	for _, arg := range args {
289		parts := strings.SplitN(arg, "=", 2)
290		if len(parts) != 2 {
291			ts.Fatalf("usage: envfile key=file...")
292		}
293		key := parts[0]
294		file := parts[1]
295		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
296	}
297}