1package sqlite3
2
3import (
4 "bytes"
5 "math"
6 "reflect"
7 "strconv"
8 "strings"
9 "time"
10 "unsafe"
11
12 "github.com/ncruces/go-sqlite3/internal/util"
13)
14
15// Quote escapes and quotes a value
16// making it safe to embed in SQL text.
17// Strings with embedded NUL characters are truncated.
18//
19// https://sqlite.org/lang_corefunc.html#quote
20func Quote(value any) string {
21 switch v := value.(type) {
22 case nil:
23 return "NULL"
24 case bool:
25 if v {
26 return "1"
27 } else {
28 return "0"
29 }
30
31 case int:
32 return strconv.Itoa(v)
33 case int64:
34 return strconv.FormatInt(v, 10)
35 case float64:
36 switch {
37 case math.IsNaN(v):
38 return "NULL"
39 case math.IsInf(v, 1):
40 return "9.0e999"
41 case math.IsInf(v, -1):
42 return "-9.0e999"
43 }
44 return strconv.FormatFloat(v, 'g', -1, 64)
45 case time.Time:
46 return "'" + v.Format(time.RFC3339Nano) + "'"
47
48 case string:
49 if i := strings.IndexByte(v, 0); i >= 0 {
50 v = v[:i]
51 }
52
53 buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
54 buf[0] = '\''
55 i := 1
56 for _, b := range []byte(v) {
57 if b == '\'' {
58 buf[i] = b
59 i += 1
60 }
61 buf[i] = b
62 i += 1
63 }
64 buf[len(buf)-1] = '\''
65 return unsafe.String(&buf[0], len(buf))
66
67 case []byte:
68 buf := make([]byte, 3+2*len(v))
69 buf[1] = '\''
70 buf[0] = 'x'
71 i := 2
72 for _, b := range v {
73 const hex = "0123456789ABCDEF"
74 buf[i+0] = hex[b/16]
75 buf[i+1] = hex[b%16]
76 i += 2
77 }
78 buf[len(buf)-1] = '\''
79 return unsafe.String(&buf[0], len(buf))
80
81 case ZeroBlob:
82 buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
83 buf[1] = '\''
84 buf[0] = 'x'
85 buf[len(buf)-1] = '\''
86 return unsafe.String(&buf[0], len(buf))
87 }
88
89 v := reflect.ValueOf(value)
90 k := v.Kind()
91
92 if k == reflect.Interface || k == reflect.Pointer {
93 if v.IsNil() {
94 return "NULL"
95 }
96 v = v.Elem()
97 k = v.Kind()
98 }
99
100 switch {
101 case v.CanInt():
102 return strconv.FormatInt(v.Int(), 10)
103 case v.CanUint():
104 return strconv.FormatUint(v.Uint(), 10)
105 case v.CanFloat():
106 return Quote(v.Float())
107 case k == reflect.Bool:
108 return Quote(v.Bool())
109 case k == reflect.String:
110 return Quote(v.String())
111 case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
112 v.Type().Elem().Kind() == reflect.Uint8:
113 return Quote(v.Bytes())
114 }
115
116 panic(util.ValueErr)
117}
118
119// QuoteIdentifier escapes and quotes an identifier
120// making it safe to embed in SQL text.
121// Strings with embedded NUL characters panic.
122func QuoteIdentifier(id string) string {
123 if strings.IndexByte(id, 0) >= 0 {
124 panic(util.ValueErr)
125 }
126
127 buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
128 buf[0] = '"'
129 i := 1
130 for _, b := range []byte(id) {
131 if b == '"' {
132 buf[i] = b
133 i += 1
134 }
135 buf[i] = b
136 i += 1
137 }
138 buf[len(buf)-1] = '"'
139 return unsafe.String(&buf[0], len(buf))
140}