gofunc.go

  1package wasm
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"math"
  9	"reflect"
 10
 11	"github.com/tetratelabs/wazero/api"
 12)
 13
 14type paramsKind byte
 15
 16const (
 17	paramsKindNoContext paramsKind = iota
 18	paramsKindContext
 19	paramsKindContextModule
 20)
 21
 22// Below are reflection code to get the interface type used to parse functions and set values.
 23
 24var (
 25	moduleType    = reflect.TypeOf((*api.Module)(nil)).Elem()
 26	goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
 27	errorType     = reflect.TypeOf((*error)(nil)).Elem()
 28)
 29
 30// compile-time check to ensure reflectGoModuleFunction implements
 31// api.GoModuleFunction.
 32var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil)
 33
 34type reflectGoModuleFunction struct {
 35	fn              *reflect.Value
 36	params, results []ValueType
 37}
 38
 39// Call implements the same method as documented on api.GoModuleFunction.
 40func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, stack []uint64) {
 41	callGoFunc(ctx, mod, f.fn, stack)
 42}
 43
 44// EqualTo is exposed for testing.
 45func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool {
 46	if f2, ok := that.(*reflectGoModuleFunction); !ok {
 47		return false
 48	} else {
 49		// TODO compare reflect pointers
 50		return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
 51	}
 52}
 53
 54// compile-time check to ensure reflectGoFunction implements api.GoFunction.
 55var _ api.GoFunction = (*reflectGoFunction)(nil)
 56
 57type reflectGoFunction struct {
 58	fn              *reflect.Value
 59	pk              paramsKind
 60	params, results []ValueType
 61}
 62
 63// EqualTo is exposed for testing.
 64func (f *reflectGoFunction) EqualTo(that interface{}) bool {
 65	if f2, ok := that.(*reflectGoFunction); !ok {
 66		return false
 67	} else {
 68		// TODO compare reflect pointers
 69		return f.pk == f2.pk &&
 70			bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
 71	}
 72}
 73
 74// Call implements the same method as documented on api.GoFunction.
 75func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
 76	if f.pk == paramsKindNoContext {
 77		ctx = nil
 78	}
 79	callGoFunc(ctx, nil, f.fn, stack)
 80}
 81
 82// callGoFunc executes the reflective function by converting params to Go
 83// types. The results of the function call are converted back to api.ValueType.
 84func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
 85	tp := fn.Type()
 86
 87	var in []reflect.Value
 88	pLen := tp.NumIn()
 89	if pLen != 0 {
 90		in = make([]reflect.Value, pLen)
 91
 92		i := 0
 93		if ctx != nil {
 94			in[0] = newContextVal(ctx)
 95			i++
 96		}
 97		if mod != nil {
 98			in[1] = newModuleVal(mod)
 99			i++
100		}
101
102		for j := 0; i < pLen; i++ {
103			next := tp.In(i)
104			val := reflect.New(next).Elem()
105			k := next.Kind()
106			raw := stack[j]
107			j++
108
109			switch k {
110			case reflect.Float32:
111				val.SetFloat(float64(math.Float32frombits(uint32(raw))))
112			case reflect.Float64:
113				val.SetFloat(math.Float64frombits(raw))
114			case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
115				val.SetUint(raw)
116			case reflect.Int32, reflect.Int64:
117				val.SetInt(int64(raw))
118			default:
119				panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
120			}
121			in[i] = val
122		}
123	}
124
125	// Execute the host function and push back the call result onto the stack.
126	for i, ret := range fn.Call(in) {
127		switch ret.Kind() {
128		case reflect.Float32:
129			stack[i] = uint64(math.Float32bits(float32(ret.Float())))
130		case reflect.Float64:
131			stack[i] = math.Float64bits(ret.Float())
132		case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
133			stack[i] = ret.Uint()
134		case reflect.Int32, reflect.Int64:
135			stack[i] = uint64(ret.Int())
136		default:
137			panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
138		}
139	}
140}
141
142func newContextVal(ctx context.Context) reflect.Value {
143	val := reflect.New(goContextType).Elem()
144	val.Set(reflect.ValueOf(ctx))
145	return val
146}
147
148func newModuleVal(m api.Module) reflect.Value {
149	val := reflect.New(moduleType).Elem()
150	val.Set(reflect.ValueOf(m))
151	return val
152}
153
154// MustParseGoReflectFuncCode parses Code from the go function or panics.
155//
156// Exposing this simplifies FunctionDefinition of host functions in built-in host
157// modules and tests.
158func MustParseGoReflectFuncCode(fn interface{}) Code {
159	_, _, code, err := parseGoReflectFunc(fn)
160	if err != nil {
161		panic(err)
162	}
163	return code
164}
165
166func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code Code, err error) {
167	fnV := reflect.ValueOf(fn)
168	p := fnV.Type()
169
170	if fnV.Kind() != reflect.Func {
171		err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
172		return
173	}
174
175	pk, kindErr := kind(p)
176	if kindErr != nil {
177		err = kindErr
178		return
179	}
180
181	pOffset := 0
182	switch pk {
183	case paramsKindNoContext:
184	case paramsKindContext:
185		pOffset = 1
186	case paramsKindContextModule:
187		pOffset = 2
188	}
189
190	pCount := p.NumIn() - pOffset
191	if pCount > 0 {
192		params = make([]ValueType, pCount)
193	}
194	for i := 0; i < len(params); i++ {
195		pI := p.In(i + pOffset)
196		if t, ok := getTypeOf(pI.Kind()); ok {
197			params[i] = t
198			continue
199		}
200
201		// Now, we will definitely err, decide which message is best
202		var arg0Type reflect.Type
203		if hc := pI.Implements(moduleType); hc {
204			arg0Type = moduleType
205		} else if gc := pI.Implements(goContextType); gc {
206			arg0Type = goContextType
207		}
208
209		if arg0Type != nil {
210			err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
211		} else {
212			err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
213		}
214		return
215	}
216
217	rCount := p.NumOut()
218	if rCount > 0 {
219		results = make([]ValueType, rCount)
220	}
221	for i := 0; i < len(results); i++ {
222		rI := p.Out(i)
223		if t, ok := getTypeOf(rI.Kind()); ok {
224			results[i] = t
225			continue
226		}
227
228		// Now, we will definitely err, decide which message is best
229		if rI.Implements(errorType) {
230			err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
231		} else {
232			err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
233		}
234		return
235	}
236
237	code = Code{}
238	if pk == paramsKindContextModule {
239		code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
240	} else {
241		code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
242	}
243	return
244}
245
246func kind(p reflect.Type) (paramsKind, error) {
247	pCount := p.NumIn()
248	if pCount > 0 && p.In(0).Kind() == reflect.Interface {
249		p0 := p.In(0)
250		if p0.Implements(moduleType) {
251			return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
252		} else if p0.Implements(goContextType) {
253			if pCount >= 2 && p.In(1).Implements(moduleType) {
254				return paramsKindContextModule, nil
255			}
256			return paramsKindContext, nil
257		}
258	}
259	// Without context param allows portability with reflective runtimes.
260	// This allows people to more easily port to wazero.
261	return paramsKindNoContext, nil
262}
263
264func getTypeOf(kind reflect.Kind) (ValueType, bool) {
265	switch kind {
266	case reflect.Float64:
267		return ValueTypeF64, true
268	case reflect.Float32:
269		return ValueTypeF32, true
270	case reflect.Int32, reflect.Uint32:
271		return ValueTypeI32, true
272	case reflect.Int64, reflect.Uint64:
273		return ValueTypeI64, true
274	case reflect.Uintptr:
275		return ValueTypeExternref, true
276	default:
277		return 0x00, false
278	}
279}