script_test.go

  1package testscript
  2
  3import (
  4	"context"
  5	"flag"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"testing"
 13	"time"
 14
 15	"github.com/charmbracelet/keygen"
 16	"github.com/charmbracelet/soft-serve/server"
 17	"github.com/charmbracelet/soft-serve/server/config"
 18	"github.com/charmbracelet/soft-serve/server/test"
 19	"github.com/rogpeppe/go-internal/testscript"
 20	"golang.org/x/crypto/ssh"
 21)
 22
 23var update = flag.Bool("update", false, "update script files")
 24
 25func TestScript(t *testing.T) {
 26	flag.Parse()
 27	var lock sync.Mutex
 28
 29	t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
 30
 31	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 32		path := filepath.Join(t.TempDir(), name)
 33		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 34		if err != nil {
 35			t.Fatal(err)
 36		}
 37		return path, pair
 38	}
 39
 40	key, admin1 := mkkey("admin1")
 41	_, admin2 := mkkey("admin2")
 42	_, user1 := mkkey("user1")
 43
 44	sshArgs := []string{
 45		"-F", "/dev/null",
 46		"-o", "StrictHostKeyChecking=no",
 47		"-o", "UserKnownHostsFile=/dev/null",
 48		"-o", "IdentityAgent=none",
 49		"-o", "IdentitiesOnly=yes",
 50		"-o", "ServerAliveInterval=60",
 51		"-i", key,
 52	}
 53
 54	check := func(ts *testscript.TestScript, err error, neg bool) {
 55		if neg && err == nil {
 56			ts.Fatalf("expected error, got nil")
 57		}
 58		if !neg {
 59			ts.Check(err)
 60		}
 61	}
 62
 63	testscript.Run(t, testscript.Params{
 64		Dir:           "./testdata/",
 65		UpdateScripts: *update,
 66		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 67			"soft": func(ts *testscript.TestScript, neg bool, args []string) {
 68				cli, err := ssh.Dial(
 69					"tcp",
 70					net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
 71					&ssh.ClientConfig{
 72						User: "admin",
 73						Auth: []ssh.AuthMethod{
 74							ssh.PublicKeys(admin1.Signer()),
 75						},
 76						HostKeyCallback: ssh.InsecureIgnoreHostKey(),
 77					},
 78				)
 79				ts.Check(err)
 80				defer cli.Close()
 81
 82				sess, err := cli.NewSession()
 83				ts.Check(err)
 84				defer sess.Close()
 85
 86				sess.Stdout = ts.Stdout()
 87				sess.Stderr = ts.Stderr()
 88
 89				check(ts, sess.Run(strings.Join(args, " ")), neg)
 90			},
 91			"git": func(ts *testscript.TestScript, neg bool, args []string) {
 92				ts.Setenv(
 93					"GIT_SSH_COMMAND",
 94					strings.Join(append([]string{"ssh"}, sshArgs...), " "),
 95				)
 96				args = append([]string{
 97					"-c", "user.email=john@example.com",
 98					"-c", "user.name=John Doe",
 99				}, args...)
100				check(ts, ts.Exec("git", args...), neg)
101			},
102			"mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
103				if len(args) != 1 {
104					ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
105				}
106				check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
107			},
108		},
109		Setup: func(e *testscript.Env) error {
110			sshPort := test.RandomPort()
111			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
112			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
113			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
114			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
115			data := t.TempDir()
116			cfg := config.Config{
117				Name:             "Test Soft Serve",
118				DataPath:         data,
119				InitialAdminKeys: []string{admin1.AuthorizedKey()},
120				SSH: config.SSHConfig{
121					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
122					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
123					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
124					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
125				},
126				Git: config.GitConfig{
127					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
128					IdleTimeout:    3,
129					MaxConnections: 32,
130				},
131				HTTP: config.HTTPConfig{
132					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
133					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
134				},
135				Stats: config.StatsConfig{
136					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
137				},
138				Log: config.LogConfig{
139					Format:     "text",
140					TimeFormat: time.DateTime,
141				},
142			}
143			ctx := config.WithContext(context.Background(), &cfg)
144
145			// prevent race condition in lipgloss...
146			// this will probably be autofixed when we start using the colors
147			// from the ssh session instead of the server.
148			// XXX: take another look at this soon
149			lock.Lock()
150			srv, err := server.NewServer(ctx)
151			if err != nil {
152				return err
153			}
154			lock.Unlock()
155
156			go func() {
157				if err := srv.Start(); err != nil {
158					e.T().Fatal(err)
159				}
160			}()
161
162			e.Defer(func() {
163				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
164				defer cancel()
165				if err := srv.Shutdown(ctx); err != nil {
166					e.T().Fatal(err)
167				}
168			})
169
170			// wait until the server is up
171			for {
172				conn, _ := net.DialTimeout(
173					"tcp",
174					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
175					time.Second,
176				)
177				if conn != nil {
178					conn.Close()
179					break
180				}
181			}
182
183			return nil
184		},
185	})
186}