1package vfs
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "encoding/binary"
8 "strconv"
9
10 "github.com/tetratelabs/wazero/api"
11
12 "github.com/ncruces/go-sqlite3/internal/util"
13 "github.com/ncruces/go-sqlite3/util/sql3util"
14)
15
16func cksmWrapFile(name *Filename, flags OpenFlag, file File) File {
17 // Checksum only main databases and WALs.
18 if flags&(OPEN_MAIN_DB|OPEN_WAL) == 0 {
19 return file
20 }
21
22 cksm := cksmFile{File: file}
23
24 if flags&OPEN_WAL != 0 {
25 main, _ := name.DatabaseFile().(cksmFile)
26 cksm.cksmFlags = main.cksmFlags
27 } else {
28 cksm.cksmFlags = new(cksmFlags)
29 cksm.isDB = true
30 }
31
32 return cksm
33}
34
35type cksmFile struct {
36 File
37 *cksmFlags
38 isDB bool
39}
40
// cksmFlags is the mutable checksum state of a database.
type cksmFlags struct {
	// computeCksm enables writing a checksum into each page on WriteAt.
	computeCksm bool
	// verifyCksm enables checking the checksum of each page on ReadAt.
	verifyCksm bool
	// inCkpt is set between the checkpoint start and done file controls;
	// while set, checksums are neither computed nor verified.
	inCkpt bool
	// pageSize is the page size, in bytes, decoded from the database header.
	pageSize int
}
47
48func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
49 n, err = c.File.ReadAt(p, off)
50 p = p[:n]
51
52 if isHeader(c.isDB, p, off) {
53 c.init((*[100]byte)(p))
54 }
55
56 // Verify checksums.
57 if c.verifyCksm && !c.inCkpt && len(p) == c.pageSize {
58 cksm1 := cksmCompute(p[:len(p)-8])
59 cksm2 := *(*[8]byte)(p[len(p)-8:])
60 if cksm1 != cksm2 {
61 return 0, _IOERR_DATA
62 }
63 }
64 return n, err
65}
66
67func (c cksmFile) WriteAt(p []byte, off int64) (n int, err error) {
68 if isHeader(c.isDB, p, off) {
69 c.init((*[100]byte)(p))
70 }
71
72 // Compute checksums.
73 if c.computeCksm && !c.inCkpt && len(p) == c.pageSize {
74 *(*[8]byte)(p[len(p)-8:]) = cksmCompute(p[:len(p)-8])
75 }
76
77 return c.File.WriteAt(p, off)
78}
79
80func (c cksmFile) Pragma(name string, value string) (string, error) {
81 switch name {
82 case "checksum_verification":
83 b, ok := sql3util.ParseBool(value)
84 if ok {
85 c.verifyCksm = b && c.computeCksm
86 }
87 if !c.verifyCksm {
88 return "0", nil
89 }
90 return "1", nil
91
92 case "page_size":
93 if c.computeCksm {
94 // Do not allow page size changes on a checksum database.
95 return strconv.Itoa(c.pageSize), nil
96 }
97 }
98 return "", _NOTFOUND
99}
100
101func (c cksmFile) DeviceCharacteristics() DeviceCharacteristic {
102 ret := c.File.DeviceCharacteristics()
103 if c.verifyCksm {
104 ret &^= IOCAP_SUBPAGE_READ
105 }
106 return ret
107}
108
109func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode {
110 switch op {
111 case _FCNTL_CKPT_START:
112 c.inCkpt = true
113 case _FCNTL_CKPT_DONE:
114 c.inCkpt = false
115 case _FCNTL_PRAGMA:
116 rc := vfsFileControlImpl(ctx, mod, c, op, pArg)
117 if rc != _NOTFOUND {
118 return rc
119 }
120 }
121 return vfsFileControlImpl(ctx, mod, c.File, op, pArg)
122}
123
124func (f *cksmFlags) init(header *[100]byte) {
125 f.pageSize = 256 * int(binary.LittleEndian.Uint16(header[16:18]))
126 if r := header[20] == 8; r != f.computeCksm {
127 f.computeCksm = r
128 f.verifyCksm = r
129 }
130 if !sql3util.ValidPageSize(f.pageSize) {
131 f.computeCksm = false
132 f.verifyCksm = false
133 }
134}
135
136func isHeader(isDB bool, p []byte, off int64) bool {
137 check := sql3util.ValidPageSize(len(p))
138 if isDB {
139 check = off == 0 && len(p) >= 100
140 }
141 return check && bytes.HasPrefix(p, []byte("SQLite format 3\000"))
142}
143
144func cksmCompute(a []byte) (cksm [8]byte) {
145 var s1, s2 uint32
146 for len(a) >= 8 {
147 s1 += binary.LittleEndian.Uint32(a[0:4]) + s2
148 s2 += binary.LittleEndian.Uint32(a[4:8]) + s1
149 a = a[8:]
150 }
151 if len(a) != 0 {
152 panic(util.AssertErr())
153 }
154 binary.LittleEndian.PutUint32(cksm[0:4], s1)
155 binary.LittleEndian.PutUint32(cksm[4:8], s2)
156 return
157}
158
159func (c cksmFile) SharedMemory() SharedMemory {
160 if f, ok := c.File.(FileSharedMemory); ok {
161 return f.SharedMemory()
162 }
163 return nil
164}
165
166func (c cksmFile) Unwrap() File {
167 return c.File
168}