1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/backend"
 19	"github.com/charmbracelet/soft-serve/server/config"
 20	"github.com/charmbracelet/soft-serve/server/db"
 21	"github.com/charmbracelet/soft-serve/server/db/migrate"
 22	"github.com/charmbracelet/soft-serve/server/test"
 23	"github.com/rogpeppe/go-internal/testscript"
 24	"golang.org/x/crypto/ssh"
 25	_ "modernc.org/sqlite" // sqlite Driver
 26)
 27
 28var update = flag.Bool("update", false, "update script files")
 29
 30func TestScript(t *testing.T) {
 31	flag.Parse()
 32	var lock sync.Mutex
 33
 34	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 35		path := filepath.Join(t.TempDir(), name)
 36		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 37		if err != nil {
 38			t.Fatal(err)
 39		}
 40		return path, pair
 41	}
 42
 43	key, admin1 := mkkey("admin1")
 44	_, admin2 := mkkey("admin2")
 45	_, user1 := mkkey("user1")
 46
 47	testscript.Run(t, testscript.Params{
 48		Dir:           "./testdata/",
 49		UpdateScripts: *update,
 50		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 51			"soft":     cmdSoft(admin1.Signer()),
 52			"usoft":    cmdSoft(user1.Signer()),
 53			"git":      cmdGit(key),
 54			"mkfile":   cmdMkfile,
 55			"readfile": cmdReadfile,
 56			"dos2unix": cmdDos2Unix,
 57		},
 58		Setup: func(e *testscript.Env) error {
 59			data := t.TempDir()
 60
 61			sshPort := test.RandomPort()
 62			sshListen := fmt.Sprintf("localhost:%d", sshPort)
 63			gitPort := test.RandomPort()
 64			gitListen := fmt.Sprintf("localhost:%d", gitPort)
 65			httpPort := test.RandomPort()
 66			httpListen := fmt.Sprintf("localhost:%d", httpPort)
 67			statsPort := test.RandomPort()
 68			statsListen := fmt.Sprintf("localhost:%d", statsPort)
 69			serverName := "Test Soft Serve"
 70
 71			e.Setenv("DATA_PATH", data)
 72			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 73			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 74			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 75			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 76			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 77			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 78
 79			cfg := config.DefaultConfig()
 80			cfg.DataPath = data
 81			cfg.Name = serverName
 82			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
 83			cfg.SSH.ListenAddr = sshListen
 84			cfg.SSH.PublicURL = "ssh://" + sshListen
 85			cfg.Git.ListenAddr = gitListen
 86			cfg.HTTP.ListenAddr = httpListen
 87			cfg.HTTP.PublicURL = "http://" + httpListen
 88			cfg.Stats.ListenAddr = statsListen
 89			cfg.DB.Driver = "sqlite"
 90
 91			if err := cfg.Validate(); err != nil {
 92				return err
 93			}
 94
 95			ctx := config.WithContext(context.Background(), cfg)
 96
 97			// TODO: test postgres
 98			dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
 99			if err != nil {
100				return fmt.Errorf("open database: %w", err)
101			}
102
103			if err := migrate.Migrate(ctx, dbx); err != nil {
104				return fmt.Errorf("migrate database: %w", err)
105			}
106
107			ctx = db.WithContext(ctx, dbx)
108			be := backend.New(ctx, cfg, dbx)
109			ctx = backend.WithContext(ctx, be)
110
111			// prevent race condition in lipgloss...
112			// this will probably be autofixed when we start using the colors
113			// from the ssh session instead of the server.
114			// XXX: take another look at this soon
115			lock.Lock()
116			srv, err := server.NewServer(ctx)
117			if err != nil {
118				return err
119			}
120			lock.Unlock()
121
122			go func() {
123				if err := srv.Start(); err != nil {
124					e.T().Fatal(err)
125				}
126			}()
127
128			e.Defer(func() {
129				defer dbx.Close() // nolint: errcheck
130				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
131				defer cancel()
132				if err := srv.Shutdown(ctx); err != nil {
133					e.T().Fatal(err)
134				}
135			})
136
137			// wait until the server is up
138			for {
139				conn, _ := net.DialTimeout(
140					"tcp",
141					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
142					time.Second,
143				)
144				if conn != nil {
145					conn.Close()
146					break
147				}
148			}
149
150			return nil
151		},
152	})
153}
154
155func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
156	return func(ts *testscript.TestScript, neg bool, args []string) {
157		cli, err := ssh.Dial(
158			"tcp",
159			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
160			&ssh.ClientConfig{
161				User:            "admin",
162				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
163				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
164			},
165		)
166		ts.Check(err)
167		defer cli.Close()
168
169		sess, err := cli.NewSession()
170		ts.Check(err)
171		defer sess.Close()
172
173		sess.Stdout = ts.Stdout()
174		sess.Stderr = ts.Stderr()
175
176		check(ts, sess.Run(strings.Join(args, " ")), neg)
177	}
178}
179
180// P.S. Windows sucks!
181func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
182	if neg {
183		ts.Fatalf("unsupported: ! dos2unix")
184	}
185	if len(args) < 1 {
186		ts.Fatalf("usage: dos2unix paths...")
187	}
188	for _, arg := range args {
189		filename := ts.MkAbs(arg)
190		data, err := os.ReadFile(filename)
191		if err != nil {
192			ts.Fatalf("%s: %v", filename, err)
193		}
194
195		// Replace all '\r\n' with '\n'.
196		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
197
198		if err := os.WriteFile(filename, data, 0o644); err != nil {
199			ts.Fatalf("%s: %v", filename, err)
200		}
201	}
202}
203
// sshConfig is the ssh_config template that cmdGit writes before each git
// invocation. Its single %q verb receives the known_hosts file path. It is a
// constant so the template cannot be mutated at runtime and so go vet can
// check the format string at its Sprintf call site.
const sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
212
213func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
214	return func(ts *testscript.TestScript, neg bool, args []string) {
215		ts.Check(os.WriteFile(
216			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
217			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
218			0o600,
219		))
220		sshArgs := []string{
221			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
222			"-i", filepath.ToSlash(key),
223		}
224		ts.Setenv(
225			"GIT_SSH_COMMAND",
226			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
227		)
228		args = append([]string{
229			"-c", "user.email=john@example.com",
230			"-c", "user.name=John Doe",
231		}, args...)
232		check(ts, ts.Exec("git", args...), neg)
233	}
234}
235
236func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
237	if len(args) < 2 {
238		ts.Fatalf("usage: mkfile path content")
239	}
240	check(ts, os.WriteFile(
241		ts.MkAbs(args[0]),
242		[]byte(strings.Join(args[1:], " ")),
243		0o644,
244	), neg)
245}
246
247func check(ts *testscript.TestScript, err error, neg bool) {
248	if neg && err == nil {
249		ts.Fatalf("expected error, got nil")
250	}
251	if !neg {
252		ts.Check(err)
253	}
254}
255
256func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
257	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
258}