quote.go

  1package sqlite3
  2
  3import (
  4	"bytes"
  5	"math"
  6	"reflect"
  7	"strconv"
  8	"strings"
  9	"time"
 10	"unsafe"
 11
 12	"github.com/ncruces/go-sqlite3/internal/util"
 13)
 14
 15// Quote escapes and quotes a value
 16// making it safe to embed in SQL text.
 17// Strings with embedded NUL characters are truncated.
 18//
 19// https://sqlite.org/lang_corefunc.html#quote
 20func Quote(value any) string {
 21	switch v := value.(type) {
 22	case nil:
 23		return "NULL"
 24	case bool:
 25		if v {
 26			return "1"
 27		} else {
 28			return "0"
 29		}
 30
 31	case int:
 32		return strconv.Itoa(v)
 33	case int64:
 34		return strconv.FormatInt(v, 10)
 35	case float64:
 36		switch {
 37		case math.IsNaN(v):
 38			return "NULL"
 39		case math.IsInf(v, 1):
 40			return "9.0e999"
 41		case math.IsInf(v, -1):
 42			return "-9.0e999"
 43		}
 44		return strconv.FormatFloat(v, 'g', -1, 64)
 45	case time.Time:
 46		return "'" + v.Format(time.RFC3339Nano) + "'"
 47
 48	case string:
 49		if i := strings.IndexByte(v, 0); i >= 0 {
 50			v = v[:i]
 51		}
 52
 53		buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
 54		buf[0] = '\''
 55		i := 1
 56		for _, b := range []byte(v) {
 57			if b == '\'' {
 58				buf[i] = b
 59				i += 1
 60			}
 61			buf[i] = b
 62			i += 1
 63		}
 64		buf[len(buf)-1] = '\''
 65		return unsafe.String(&buf[0], len(buf))
 66
 67	case []byte:
 68		buf := make([]byte, 3+2*len(v))
 69		buf[1] = '\''
 70		buf[0] = 'x'
 71		i := 2
 72		for _, b := range v {
 73			const hex = "0123456789ABCDEF"
 74			buf[i+0] = hex[b/16]
 75			buf[i+1] = hex[b%16]
 76			i += 2
 77		}
 78		buf[len(buf)-1] = '\''
 79		return unsafe.String(&buf[0], len(buf))
 80
 81	case ZeroBlob:
 82		buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
 83		buf[1] = '\''
 84		buf[0] = 'x'
 85		buf[len(buf)-1] = '\''
 86		return unsafe.String(&buf[0], len(buf))
 87	}
 88
 89	v := reflect.ValueOf(value)
 90	k := v.Kind()
 91
 92	if k == reflect.Interface || k == reflect.Pointer {
 93		if v.IsNil() {
 94			return "NULL"
 95		}
 96		v = v.Elem()
 97		k = v.Kind()
 98	}
 99
100	switch {
101	case v.CanInt():
102		return strconv.FormatInt(v.Int(), 10)
103	case v.CanUint():
104		return strconv.FormatUint(v.Uint(), 10)
105	case v.CanFloat():
106		return Quote(v.Float())
107	case k == reflect.Bool:
108		return Quote(v.Bool())
109	case k == reflect.String:
110		return Quote(v.String())
111	case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
112		v.Type().Elem().Kind() == reflect.Uint8:
113		return Quote(v.Bytes())
114	}
115
116	panic(util.ValueErr)
117}
118
119// QuoteIdentifier escapes and quotes an identifier
120// making it safe to embed in SQL text.
121// Strings with embedded NUL characters panic.
122func QuoteIdentifier(id string) string {
123	if strings.IndexByte(id, 0) >= 0 {
124		panic(util.ValueErr)
125	}
126
127	buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
128	buf[0] = '"'
129	i := 1
130	for _, b := range []byte(id) {
131		if b == '"' {
132			buf[i] = b
133			i += 1
134		}
135		buf[i] = b
136		i += 1
137	}
138	buf[len(buf)-1] = '"'
139	return unsafe.String(&buf[0], len(buf))
140}