1package goose
2
3import (
4 "errors"
5 "fmt"
6 "path/filepath"
7)
8
9var (
10 registeredGoMigrations = make(map[int64]*Migration)
11)
12
13// ResetGlobalMigrations resets the global Go migrations registry.
14//
15// Not safe for concurrent use.
16func ResetGlobalMigrations() {
17 registeredGoMigrations = make(map[int64]*Migration)
18}
19
20// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the
21// same version has already been registered. Go migrations must be constructed using the
22// [NewGoMigration] function.
23//
24// Not safe for concurrent use.
25func SetGlobalMigrations(migrations ...*Migration) error {
26 for _, m := range migrations {
27 if _, ok := registeredGoMigrations[m.Version]; ok {
28 return fmt.Errorf("go migration with version %d already registered", m.Version)
29 }
30 if err := checkGoMigration(m); err != nil {
31 return fmt.Errorf("invalid go migration: %w", err)
32 }
33 registeredGoMigrations[m.Version] = m
34 }
35 return nil
36}
37
38func checkGoMigration(m *Migration) error {
39 if !m.construct {
40 return errors.New("must use NewGoMigration to construct migrations")
41 }
42 if !m.Registered {
43 return errors.New("must be registered")
44 }
45 if m.Type != TypeGo {
46 return fmt.Errorf("type must be %q", TypeGo)
47 }
48 if m.Version < 1 {
49 return errors.New("version must be greater than zero")
50 }
51 if m.Source != "" {
52 if filepath.Ext(m.Source) != ".go" {
53 return fmt.Errorf("source must have .go extension: %q", m.Source)
54 }
55 // If the source is set, expect it to be a path with a numeric component that matches the
56 // version. This field is not intended to be used for descriptive purposes.
57 version, err := NumericComponent(m.Source)
58 if err != nil {
59 return fmt.Errorf("invalid source: %w", err)
60 }
61 if version != m.Version {
62 return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
63 }
64 }
65 if err := checkGoFunc(m.goUp); err != nil {
66 return fmt.Errorf("up function: %w", err)
67 }
68 if err := checkGoFunc(m.goDown); err != nil {
69 return fmt.Errorf("down function: %w", err)
70 }
71 if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
72 return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
73 }
74 if m.UpFn != nil && m.UpFnNoTx != nil {
75 return errors.New("must specify exactly one of UpFn or UpFnNoTx")
76 }
77 if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
78 return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
79 }
80 if m.DownFn != nil && m.DownFnNoTx != nil {
81 return errors.New("must specify exactly one of DownFn or DownFnNoTx")
82 }
83 return nil
84}
85
86func checkGoFunc(f *GoFunc) error {
87 if f.RunTx != nil && f.RunDB != nil {
88 return errors.New("must specify exactly one of RunTx or RunDB")
89 }
90 switch f.Mode {
91 case TransactionEnabled, TransactionDisabled:
92 // No functions, but mode is set. This is not an error. It means the user wants to
93 // record a version with the given mode but not run any functions.
94 default:
95 return fmt.Errorf("invalid mode: %d", f.Mode)
96 }
97 if f.RunDB != nil && f.Mode != TransactionDisabled {
98 return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
99 }
100 if f.RunTx != nil && f.Mode != TransactionEnabled {
101 return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
102 }
103 return nil
104}