@@ -178,13 +178,8 @@ func (db *DB) Migrate(ctx context.Context) error {
if !executedMigrations[migrationNumber] {
slog.Info("running migration", "file", migration, "number", migrationNumber)
- if err := db.executeMigration(ctx, migration); err != nil {
- return fmt.Errorf("failed to execute migration %s: %w", migration, err)
- }
-
- err = db.pool.Exec(ctx, "INSERT INTO migrations (migration_number, migration_name) VALUES (?, ?)", migrationNumber, migration)
- if err != nil {
- return fmt.Errorf("failed to record migration %s in migrations table: %w", migration, err)
+ if err := db.runMigration(ctx, migration, migrationNumber); err != nil {
+ return err
}
}
}
@@ -192,18 +187,25 @@ func (db *DB) Migrate(ctx context.Context) error {
return nil
}
-// executeMigration executes a single migration file
-func (db *DB) executeMigration(ctx context.Context, filename string) error {
+// runMigration executes a single migration file within a transaction,
+// including recording it in the migrations table.
+func (db *DB) runMigration(ctx context.Context, filename string, migrationNumber int) error {
content, err := schemaFS.ReadFile("schema/" + filename)
if err != nil {
return fmt.Errorf("failed to read migration file %s: %w", filename, err)
}
- if err := db.pool.Exec(ctx, string(content)); err != nil {
- return fmt.Errorf("failed to execute migration %s: %w", filename, err)
- }
+ return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ if _, err := tx.Exec(string(content)); err != nil {
+ return fmt.Errorf("failed to execute migration %s: %w", filename, err)
+ }
- return nil
+ if _, err := tx.Exec("INSERT INTO migrations (migration_number, migration_name) VALUES (?, ?)", migrationNumber, filename); err != nil {
+ return fmt.Errorf("failed to record migration %s in migrations table: %w", filename, err)
+ }
+
+ return nil
+ })
}
// Pool returns the underlying connection pool for advanced operations