@@ -3,6 +3,7 @@ package testscript
import (
"bytes"
"context"
+ "database/sql"
"flag"
"fmt"
"io"
@@ -30,7 +31,6 @@ import (
"github.com/rogpeppe/go-internal/testscript"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
- _ "modernc.org/sqlite" // sqlite Driver
)
var update = flag.Bool("update", false, "update script files")
@@ -101,6 +101,26 @@ func TestScript(t *testing.T) {
cfg.LFS.Enabled = true
cfg.LFS.SSHEnabled = true
+ dbDriver := os.Getenv("DB_DRIVER")
+ if dbDriver != "" {
+ cfg.DB.Driver = dbDriver
+ }
+
+ dbDsn := os.Getenv("DB_DATA_SOURCE")
+ if dbDsn != "" {
+ cfg.DB.DataSource = dbDsn
+ }
+
+ if cfg.DB.Driver == "postgres" {
+ err, cleanup := setupPostgres(e.T(), cfg)
+ if err != nil {
+ return err
+ }
+ if cleanup != nil {
+ e.Defer(cleanup)
+ }
+ }
+
if err := cfg.Validate(); err != nil {
return err
}
@@ -117,7 +137,6 @@ func TestScript(t *testing.T) {
defer f.Close() // nolint: errcheck
}
- // TODO: test postgres
dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
if err != nil {
return fmt.Errorf("open database: %w", err)
@@ -385,3 +404,68 @@ func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
check(ts, cmd.Execute(), neg)
}
+
+func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
+ // Indicates postgres
+ // Create a disposable database
+ dbName := fmt.Sprintf("softserve_test_%d", time.Now().UnixNano())
+ dbDsn := os.Getenv("DB_DATA_SOURCE")
+ if dbDsn == "" {
+ cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
+ }
+
+ dbUrl, err := url.Parse(cfg.DB.DataSource)
+ if err != nil {
+ return err, nil
+ }
+
+ connInfo := fmt.Sprintf("host=%s sslmode=disable", dbUrl.Hostname())
+ username := dbUrl.User.Username()
+ if username != "" {
+ connInfo += fmt.Sprintf(" user=%s", username)
+ password, ok := dbUrl.User.Password()
+ if ok {
+ username = fmt.Sprintf("%s:%s", username, password)
+ connInfo += fmt.Sprintf(" password=%s", password)
+ }
+ username = fmt.Sprintf("%s@", username)
+ } else {
+ connInfo += " user=postgres"
+ }
+
+ port := dbUrl.Port()
+ if port != "" {
+ connInfo += fmt.Sprintf(" port=%s", port)
+ port = fmt.Sprintf(":%s", port)
+ }
+
+ cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
+ dbUrl.Scheme,
+ username,
+ dbUrl.Hostname(),
+ port,
+ dbName,
+ )
+
+ // Create the database
+ db, err := sql.Open(cfg.DB.Driver, connInfo)
+ if err != nil {
+ return err, nil
+ }
+
+ if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
+ return err, nil
+ }
+
+ return nil, func() {
+ db, err := sql.Open(cfg.DB.Driver, connInfo)
+ if err != nil {
+ t.Log("failed to open database", dbName, err)
+ return
+ }
+
+ if _, err := db.Exec("DROP DATABASE " + dbName); err != nil {
+ t.Log("failed to drop database", dbName, err)
+ }
+ }
+}