script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/test"
 20	"github.com/rogpeppe/go-internal/testscript"
 21	"golang.org/x/crypto/ssh"
 22)
 23
// update is set by the -update flag. TestScript forwards it to
// testscript.Params.UpdateScripts so testscript may rewrite the script
// files in place instead of failing on mismatched expected content.
var update = flag.Bool("update", false, "update script files")
 25
 26func TestScript(t *testing.T) {
 27	flag.Parse()
 28	var lock sync.Mutex
 29
 30	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 31		path := filepath.Join(t.TempDir(), name)
 32		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 33		if err != nil {
 34			t.Fatal(err)
 35		}
 36		return path, pair
 37	}
 38
 39	key, admin1 := mkkey("admin1")
 40	_, admin2 := mkkey("admin2")
 41	_, user1 := mkkey("user1")
 42
 43	testscript.Run(t, testscript.Params{
 44		Dir:           "./testdata/",
 45		UpdateScripts: *update,
 46		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 47			"soft":     cmdSoft(admin1.Signer()),
 48			"git":      cmdGit(key),
 49			"mkfile":   cmdMkfile,
 50			"dos2unix": cmdDos2Unix,
 51		},
 52		Setup: func(e *testscript.Env) error {
 53			sshPort := test.RandomPort()
 54			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 55			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 56			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 57			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 58			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 59			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 60			data := t.TempDir()
 61			cfg := config.Config{
 62				Name:             "Test Soft Serve",
 63				DataPath:         data,
 64				InitialAdminKeys: []string{admin1.AuthorizedKey()},
 65				SSH: config.SSHConfig{
 66					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
 67					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
 68					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
 69					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
 70				},
 71				Git: config.GitConfig{
 72					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
 73					IdleTimeout:    3,
 74					MaxConnections: 32,
 75				},
 76				HTTP: config.HTTPConfig{
 77					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 78					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
 79				},
 80				Stats: config.StatsConfig{
 81					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 82				},
 83				Log: config.LogConfig{
 84					Format:     "text",
 85					TimeFormat: time.DateTime,
 86				},
 87			}
 88			ctx := config.WithContext(context.Background(), &cfg)
 89
 90			// prevent race condition in lipgloss...
 91			// this will probably be autofixed when we start using the colors
 92			// from the ssh session instead of the server.
 93			// XXX: take another look at this soon
 94			lock.Lock()
 95			srv, err := server.NewServer(ctx)
 96			if err != nil {
 97				return err
 98			}
 99			lock.Unlock()
100
101			go func() {
102				if err := srv.Start(); err != nil {
103					e.T().Fatal(err)
104				}
105			}()
106
107			e.Defer(func() {
108				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
109				defer cancel()
110				if err := srv.Shutdown(ctx); err != nil {
111					e.T().Fatal(err)
112				}
113			})
114
115			// wait until the server is up
116			for {
117				conn, _ := net.DialTimeout(
118					"tcp",
119					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
120					time.Second,
121				)
122				if conn != nil {
123					conn.Close()
124					break
125				}
126			}
127
128			return nil
129		},
130	})
131}
132
133func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
134	return func(ts *testscript.TestScript, neg bool, args []string) {
135		cli, err := ssh.Dial(
136			"tcp",
137			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
138			&ssh.ClientConfig{
139				User:            "admin",
140				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
141				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
142			},
143		)
144		ts.Check(err)
145		defer cli.Close()
146
147		sess, err := cli.NewSession()
148		ts.Check(err)
149		defer sess.Close()
150
151		sess.Stdout = ts.Stdout()
152		sess.Stderr = ts.Stderr()
153
154		check(ts, sess.Run(strings.Join(args, " ")), neg)
155	}
156}
157
158// P.S. Windows sucks!
159func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
160	if neg {
161		ts.Fatalf("unsupported: ! dos2unix")
162	}
163	if len(args) < 1 {
164		ts.Fatalf("usage: dos2unix paths...")
165	}
166	for _, arg := range args {
167		filename := ts.MkAbs(arg)
168		data, err := os.ReadFile(filename)
169		if err != nil {
170			ts.Fatalf("%s: %v", filename, err)
171		}
172
173		// Replace all '\r\n' with '\n'.
174		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
175
176		if err := os.WriteFile(filename, data, 0o644); err != nil {
177			ts.Fatalf("%s: %v", filename, err)
178		}
179	}
180}
181
// sshConfig is the OpenSSH client configuration that cmdGit writes before
// every git invocation. Its single %q verb is filled with the known_hosts
// file path from $SSH_KNOWN_HOSTS_FILE.
//
// Unknown host keys are accepted (StrictHostKeyChecking no), no SSH agent is
// consulted (IdentityAgent none), and only explicitly given identities are
// offered (IdentitiesOnly yes). This keeps the tests independent of the
// developer's own SSH setup.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
190
191func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
192	return func(ts *testscript.TestScript, neg bool, args []string) {
193		ts.Check(os.WriteFile(
194			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
195			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
196			0o600,
197		))
198		sshArgs := []string{
199			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
200			"-i", filepath.ToSlash(key),
201		}
202		ts.Setenv(
203			"GIT_SSH_COMMAND",
204			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
205		)
206		args = append([]string{
207			"-c", "user.email=john@example.com",
208			"-c", "user.name=John Doe",
209		}, args...)
210		check(ts, ts.Exec("git", args...), neg)
211	}
212}
213
214func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
215	if len(args) < 2 {
216		ts.Fatalf("usage: mkfile path content")
217	}
218	check(ts, os.WriteFile(
219		ts.MkAbs(args[0]),
220		[]byte(strings.Join(args[1:], " ")),
221		0o644,
222	), neg)
223}
224
225func check(ts *testscript.TestScript, err error, neg bool) {
226	if neg && err == nil {
227		ts.Fatalf("expected error, got nil")
228	}
229	if !neg {
230		ts.Check(err)
231	}
232}