1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/backend"
 19	"github.com/charmbracelet/soft-serve/server/config"
 20	"github.com/charmbracelet/soft-serve/server/db"
 21	"github.com/charmbracelet/soft-serve/server/db/migrate"
 22	"github.com/charmbracelet/soft-serve/server/store"
 23	"github.com/charmbracelet/soft-serve/server/store/database"
 24	"github.com/charmbracelet/soft-serve/server/test"
 25	"github.com/rogpeppe/go-internal/testscript"
 26	"golang.org/x/crypto/ssh"
 27	_ "modernc.org/sqlite" // sqlite Driver
 28)
 29
// update is bound to the -update flag and forwarded to testscript's
// UpdateScripts option, which updates the script files' embedded expected
// content instead of failing the comparison.
var update = flag.Bool("update", false, "update script files")
 31
 32func TestScript(t *testing.T) {
 33	flag.Parse()
 34	var lock sync.Mutex
 35
 36	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 37		path := filepath.Join(t.TempDir(), name)
 38		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 39		if err != nil {
 40			t.Fatal(err)
 41		}
 42		return path, pair
 43	}
 44
 45	key, admin1 := mkkey("admin1")
 46	_, admin2 := mkkey("admin2")
 47	_, user1 := mkkey("user1")
 48
 49	testscript.Run(t, testscript.Params{
 50		Dir:           "./testdata/",
 51		UpdateScripts: *update,
 52		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 53			"soft":     cmdSoft(admin1.Signer()),
 54			"usoft":    cmdSoft(user1.Signer()),
 55			"git":      cmdGit(key),
 56			"mkfile":   cmdMkfile,
 57			"readfile": cmdReadfile,
 58			"dos2unix": cmdDos2Unix,
 59		},
 60		Setup: func(e *testscript.Env) error {
 61			data := t.TempDir()
 62
 63			sshPort := test.RandomPort()
 64			sshListen := fmt.Sprintf("localhost:%d", sshPort)
 65			gitPort := test.RandomPort()
 66			gitListen := fmt.Sprintf("localhost:%d", gitPort)
 67			httpPort := test.RandomPort()
 68			httpListen := fmt.Sprintf("localhost:%d", httpPort)
 69			statsPort := test.RandomPort()
 70			statsListen := fmt.Sprintf("localhost:%d", statsPort)
 71			serverName := "Test Soft Serve"
 72
 73			e.Setenv("DATA_PATH", data)
 74			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 75			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 76			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 77			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 78			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 79			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 80
 81			cfg := config.DefaultConfig()
 82			cfg.DataPath = data
 83			cfg.Name = serverName
 84			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
 85			cfg.SSH.ListenAddr = sshListen
 86			cfg.SSH.PublicURL = "ssh://" + sshListen
 87			cfg.Git.ListenAddr = gitListen
 88			cfg.HTTP.ListenAddr = httpListen
 89			cfg.HTTP.PublicURL = "http://" + httpListen
 90			cfg.Stats.ListenAddr = statsListen
 91			cfg.DB.Driver = "sqlite"
 92
 93			if err := cfg.Validate(); err != nil {
 94				return err
 95			}
 96
 97			ctx := config.WithContext(context.Background(), cfg)
 98
 99			// TODO: test postgres
100			dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
101			if err != nil {
102				return fmt.Errorf("open database: %w", err)
103			}
104
105			if err := migrate.Migrate(ctx, dbx); err != nil {
106				return fmt.Errorf("migrate database: %w", err)
107			}
108
109			ctx = db.WithContext(ctx, dbx)
110			datastore := database.New(ctx, dbx)
111			ctx = store.WithContext(ctx, datastore)
112			be := backend.New(ctx, cfg, dbx)
113			ctx = backend.WithContext(ctx, be)
114
115			// prevent race condition in lipgloss...
116			// this will probably be autofixed when we start using the colors
117			// from the ssh session instead of the server.
118			// XXX: take another look at this soon
119			lock.Lock()
120			srv, err := server.NewServer(ctx)
121			if err != nil {
122				return err
123			}
124			lock.Unlock()
125
126			go func() {
127				if err := srv.Start(); err != nil {
128					e.T().Fatal(err)
129				}
130			}()
131
132			e.Defer(func() {
133				defer dbx.Close() // nolint: errcheck
134				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
135				defer cancel()
136				if err := srv.Shutdown(ctx); err != nil {
137					e.T().Fatal(err)
138				}
139			})
140
141			// wait until the server is up
142			for {
143				conn, _ := net.DialTimeout(
144					"tcp",
145					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
146					time.Second,
147				)
148				if conn != nil {
149					conn.Close()
150					break
151				}
152			}
153
154			return nil
155		},
156	})
157}
158
159func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
160	return func(ts *testscript.TestScript, neg bool, args []string) {
161		cli, err := ssh.Dial(
162			"tcp",
163			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
164			&ssh.ClientConfig{
165				User:            "admin",
166				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
167				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
168			},
169		)
170		ts.Check(err)
171		defer cli.Close()
172
173		sess, err := cli.NewSession()
174		ts.Check(err)
175		defer sess.Close()
176
177		sess.Stdout = ts.Stdout()
178		sess.Stderr = ts.Stderr()
179
180		check(ts, sess.Run(strings.Join(args, " ")), neg)
181	}
182}
183
184// P.S. Windows sucks!
185func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
186	if neg {
187		ts.Fatalf("unsupported: ! dos2unix")
188	}
189	if len(args) < 1 {
190		ts.Fatalf("usage: dos2unix paths...")
191	}
192	for _, arg := range args {
193		filename := ts.MkAbs(arg)
194		data, err := os.ReadFile(filename)
195		if err != nil {
196			ts.Fatalf("%s: %v", filename, err)
197		}
198
199		// Replace all '\r\n' with '\n'.
200		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
201
202		if err := os.WriteFile(filename, data, 0o644); err != nil {
203			ts.Fatalf("%s: %v", filename, err)
204		}
205	}
206}
207
// sshConfig is the ssh_config template that cmdGit writes before each git
// invocation. Its single %q verb receives the known_hosts file path. The
// template disables strict host key checking and the SSH agent, restricts
// ssh to configured identities, and sends keep-alives every 60 seconds.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
216
217func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
218	return func(ts *testscript.TestScript, neg bool, args []string) {
219		ts.Check(os.WriteFile(
220			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
221			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
222			0o600,
223		))
224		sshArgs := []string{
225			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
226			"-i", filepath.ToSlash(key),
227		}
228		ts.Setenv(
229			"GIT_SSH_COMMAND",
230			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
231		)
232		args = append([]string{
233			"-c", "user.email=john@example.com",
234			"-c", "user.name=John Doe",
235		}, args...)
236		check(ts, ts.Exec("git", args...), neg)
237	}
238}
239
240func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
241	if len(args) < 2 {
242		ts.Fatalf("usage: mkfile path content")
243	}
244	check(ts, os.WriteFile(
245		ts.MkAbs(args[0]),
246		[]byte(strings.Join(args[1:], " ")),
247		0o644,
248	), neg)
249}
250
251func check(ts *testscript.TestScript, err error, neg bool) {
252	if neg && err == nil {
253		ts.Fatalf("expected error, got nil")
254	}
255	if !neg {
256		ts.Check(err)
257	}
258}
259
260func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
261	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
262}