script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io/ioutil"
  9	"net"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"sync"
 14	"testing"
 15	"time"
 16
 17	"github.com/charmbracelet/keygen"
 18	"github.com/charmbracelet/soft-serve/server"
 19	"github.com/charmbracelet/soft-serve/server/config"
 20	"github.com/charmbracelet/soft-serve/server/test"
 21	"github.com/rogpeppe/go-internal/testscript"
 22	"golang.org/x/crypto/ssh"
 23)
 24
 25var update = flag.Bool("update", false, "update script files")
 26
 27func TestScript(t *testing.T) {
 28	flag.Parse()
 29	var lock sync.Mutex
 30
 31	t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
 32
 33	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 34		path := filepath.Join(t.TempDir(), name)
 35		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 36		if err != nil {
 37			t.Fatal(err)
 38		}
 39		return path, pair
 40	}
 41
 42	key, admin1 := mkkey("admin1")
 43	_, admin2 := mkkey("admin2")
 44	_, user1 := mkkey("user1")
 45
 46	testscript.Run(t, testscript.Params{
 47		Dir:           "./testdata/",
 48		UpdateScripts: *update,
 49		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 50			"soft":     cmdSoft(admin1.Signer()),
 51			"git":      cmdGit(key),
 52			"mkreadme": cmdMkReadMe,
 53			"unix2dos": cmdUNIX2DOS,
 54		},
 55		Setup: func(e *testscript.Env) error {
 56			sshPort := test.RandomPort()
 57			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 58			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 59			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 60			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 61			data := t.TempDir()
 62			cfg := config.Config{
 63				Name:             "Test Soft Serve",
 64				DataPath:         data,
 65				InitialAdminKeys: []string{admin1.AuthorizedKey()},
 66				SSH: config.SSHConfig{
 67					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
 68					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
 69					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
 70					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
 71				},
 72				Git: config.GitConfig{
 73					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
 74					IdleTimeout:    3,
 75					MaxConnections: 32,
 76				},
 77				HTTP: config.HTTPConfig{
 78					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 79					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
 80				},
 81				Stats: config.StatsConfig{
 82					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 83				},
 84				Log: config.LogConfig{
 85					Format:     "text",
 86					TimeFormat: time.DateTime,
 87				},
 88			}
 89			ctx := config.WithContext(context.Background(), &cfg)
 90
 91			// prevent race condition in lipgloss...
 92			// this will probably be autofixed when we start using the colors
 93			// from the ssh session instead of the server.
 94			// XXX: take another look at this soon
 95			lock.Lock()
 96			srv, err := server.NewServer(ctx)
 97			if err != nil {
 98				return err
 99			}
100			lock.Unlock()
101
102			go func() {
103				if err := srv.Start(); err != nil {
104					e.T().Fatal(err)
105				}
106			}()
107
108			e.Defer(func() {
109				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
110				defer cancel()
111				if err := srv.Shutdown(ctx); err != nil {
112					e.T().Fatal(err)
113				}
114			})
115
116			// wait until the server is up
117			for {
118				conn, _ := net.DialTimeout(
119					"tcp",
120					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
121					time.Second,
122				)
123				if conn != nil {
124					conn.Close()
125					break
126				}
127			}
128
129			return nil
130		},
131	})
132}
133
134func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
135	return func(ts *testscript.TestScript, neg bool, args []string) {
136		cli, err := ssh.Dial(
137			"tcp",
138			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
139			&ssh.ClientConfig{
140				User:            "admin",
141				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
142				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
143			},
144		)
145		ts.Check(err)
146		defer cli.Close()
147
148		sess, err := cli.NewSession()
149		ts.Check(err)
150		defer sess.Close()
151
152		sess.Stdout = ts.Stdout()
153		sess.Stderr = ts.Stderr()
154
155		check(ts, sess.Run(strings.Join(args, " ")), neg)
156	}
157}
158
159// cmdUNIX2DOS converts files from UNIX line endings to DOS line endings.
160func cmdUNIX2DOS(ts *testscript.TestScript, neg bool, args []string) {
161	if neg {
162		ts.Fatalf("unsupported: ! unix2dos")
163	}
164	if len(args) < 1 {
165		ts.Fatalf("usage: unix2dos paths...")
166	}
167	for _, arg := range args {
168		filename := ts.MkAbs(arg)
169		data, err := ioutil.ReadFile(filename)
170		if err != nil {
171			ts.Fatalf("%s: %v", filename, err)
172		}
173		data = bytes.Join(bytes.Split(data, []byte{'\n'}), []byte{'\r', '\n'})
174		//nolint:gosec
175		if err := ioutil.WriteFile(filename, data, 0o666); err != nil {
176			ts.Fatalf("%s: %v", filename, err)
177		}
178	}
179}
180
181func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
182	return func(ts *testscript.TestScript, neg bool, args []string) {
183		sshArgs := []string{
184			"-F", "/dev/null",
185			"-o", "StrictHostKeyChecking=no",
186			"-o", "UserKnownHostsFile=/dev/null",
187			"-o", "IdentityAgent=none",
188			"-o", "IdentitiesOnly=yes",
189			"-o", "ServerAliveInterval=60",
190			"-i", key,
191		}
192		ts.Setenv(
193			"GIT_SSH_COMMAND",
194			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
195		)
196		args = append([]string{
197			"-c", "user.email=john@example.com",
198			"-c", "user.name=John Doe",
199		}, args...)
200		check(ts, ts.Exec("git", args...), neg)
201	}
202}
203
204func cmdMkReadMe(ts *testscript.TestScript, neg bool, args []string) {
205	if len(args) != 1 {
206		ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
207	}
208	check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
209}
210
211func check(ts *testscript.TestScript, err error, neg bool) {
212	if neg && err == nil {
213		ts.Fatalf("expected error, got nil")
214	}
215	if !neg {
216		ts.Check(err)
217	}
218}