dns_resolver.go

  1/*
  2 *
  3 * Copyright 2018 gRPC authors.
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     http://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package dns implements a dns resolver to be installed as the default resolver
 20// in grpc.
 21package dns
 22
 23import (
 24	"context"
 25	"encoding/json"
 26	"fmt"
 27	rand "math/rand/v2"
 28	"net"
 29	"net/netip"
 30	"os"
 31	"strconv"
 32	"strings"
 33	"sync"
 34	"time"
 35
 36	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
 37	"google.golang.org/grpc/grpclog"
 38	"google.golang.org/grpc/internal/backoff"
 39	"google.golang.org/grpc/internal/envconfig"
 40	"google.golang.org/grpc/internal/resolver/dns/internal"
 41	"google.golang.org/grpc/resolver"
 42	"google.golang.org/grpc/serviceconfig"
 43)
 44
 45var (
 46	// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
 47	// addresses from SRV records.  Must not be changed after init time.
 48	EnableSRVLookups = false
 49
 50	// MinResolutionInterval is the minimum interval at which re-resolutions are
 51	// allowed. This helps to prevent excessive re-resolution.
 52	MinResolutionInterval = 30 * time.Second
 53
 54	// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
 55	// If the timeout expires before a response is received, the request will be canceled.
 56	//
 57	// It is recommended to set this value at application startup. Avoid modifying this variable
 58	// after initialization as it's not thread-safe for concurrent modification.
 59	ResolvingTimeout = 30 * time.Second
 60
 61	logger = grpclog.Component("dns")
 62)
 63
 64func init() {
 65	resolver.Register(NewBuilder())
 66	internal.TimeAfterFunc = time.After
 67	internal.TimeNowFunc = time.Now
 68	internal.TimeUntilFunc = time.Until
 69	internal.NewNetResolver = newNetResolver
 70	internal.AddressDialer = addressDialer
 71}
 72
 73const (
 74	defaultPort       = "443"
 75	defaultDNSSvrPort = "53"
 76	golang            = "GO"
 77	// txtPrefix is the prefix string to be prepended to the host name for txt
 78	// record lookup.
 79	txtPrefix = "_grpc_config."
 80	// In DNS, service config is encoded in a TXT record via the mechanism
 81	// described in RFC-1464 using the attribute name grpc_config.
 82	txtAttribute = "grpc_config="
 83)
 84
 85var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
 86	return func(ctx context.Context, network, _ string) (net.Conn, error) {
 87		var dialer net.Dialer
 88		return dialer.DialContext(ctx, network, address)
 89	}
 90}
 91
 92var newNetResolver = func(authority string) (internal.NetResolver, error) {
 93	if authority == "" {
 94		return net.DefaultResolver, nil
 95	}
 96
 97	host, port, err := parseTarget(authority, defaultDNSSvrPort)
 98	if err != nil {
 99		return nil, err
100	}
101
102	authorityWithPort := net.JoinHostPort(host, port)
103
104	return &net.Resolver{
105		PreferGo: true,
106		Dial:     internal.AddressDialer(authorityWithPort),
107	}, nil
108}
109
110// NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
111func NewBuilder() resolver.Builder {
112	return &dnsBuilder{}
113}
114
115type dnsBuilder struct{}
116
117// Build creates and starts a DNS resolver that watches the name resolution of
118// the target.
119func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
120	host, port, err := parseTarget(target.Endpoint(), defaultPort)
121	if err != nil {
122		return nil, err
123	}
124
125	// IP address.
126	if ipAddr, err := formatIP(host); err == nil {
127		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
128		cc.UpdateState(resolver.State{Addresses: addr})
129		return deadResolver{}, nil
130	}
131
132	// DNS address (non-IP).
133	ctx, cancel := context.WithCancel(context.Background())
134	d := &dnsResolver{
135		host:                 host,
136		port:                 port,
137		ctx:                  ctx,
138		cancel:               cancel,
139		cc:                   cc,
140		rn:                   make(chan struct{}, 1),
141		disableServiceConfig: opts.DisableServiceConfig,
142	}
143
144	d.resolver, err = internal.NewNetResolver(target.URL.Host)
145	if err != nil {
146		return nil, err
147	}
148
149	d.wg.Add(1)
150	go d.watcher()
151	return d, nil
152}
153
154// Scheme returns the naming scheme of this resolver builder, which is "dns".
155func (b *dnsBuilder) Scheme() string {
156	return "dns"
157}
158
159// deadResolver is a resolver that does nothing.
160type deadResolver struct{}
161
162func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}
163
164func (deadResolver) Close() {}
165
166// dnsResolver watches for the name resolution update for a non-IP target.
167type dnsResolver struct {
168	host     string
169	port     string
170	resolver internal.NetResolver
171	ctx      context.Context
172	cancel   context.CancelFunc
173	cc       resolver.ClientConn
174	// rn channel is used by ResolveNow() to force an immediate resolution of the
175	// target.
176	rn chan struct{}
177	// wg is used to enforce Close() to return after the watcher() goroutine has
178	// finished. Otherwise, data race will be possible. [Race Example] in
179	// dns_resolver_test we replace the real lookup functions with mocked ones to
180	// facilitate testing. If Close() doesn't wait for watcher() goroutine
181	// finishes, race detector sometimes will warn lookup (READ the lookup
182	// function pointers) inside watcher() goroutine has data race with
183	// replaceNetFunc (WRITE the lookup function pointers).
184	wg                   sync.WaitGroup
185	disableServiceConfig bool
186}
187
188// ResolveNow invoke an immediate resolution of the target that this
189// dnsResolver watches.
190func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
191	select {
192	case d.rn <- struct{}{}:
193	default:
194	}
195}
196
197// Close closes the dnsResolver.
198func (d *dnsResolver) Close() {
199	d.cancel()
200	d.wg.Wait()
201}
202
203func (d *dnsResolver) watcher() {
204	defer d.wg.Done()
205	backoffIndex := 1
206	for {
207		state, err := d.lookup()
208		if err != nil {
209			// Report error to the underlying grpc.ClientConn.
210			d.cc.ReportError(err)
211		} else {
212			err = d.cc.UpdateState(*state)
213		}
214
215		var nextResolutionTime time.Time
216		if err == nil {
217			// Success resolving, wait for the next ResolveNow. However, also wait 30
218			// seconds at the very least to prevent constantly re-resolving.
219			backoffIndex = 1
220			nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
221			select {
222			case <-d.ctx.Done():
223				return
224			case <-d.rn:
225			}
226		} else {
227			// Poll on an error found in DNS Resolver or an error received from
228			// ClientConn.
229			nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
230			backoffIndex++
231		}
232		select {
233		case <-d.ctx.Done():
234			return
235		case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
236		}
237	}
238}
239
240func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
241	// Skip this particular host to avoid timeouts with some versions of
242	// systemd-resolved.
243	if !EnableSRVLookups || d.host == "metadata.google.internal." {
244		return nil, nil
245	}
246	var newAddrs []resolver.Address
247	_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
248	if err != nil {
249		err = handleDNSError(err, "SRV") // may become nil
250		return nil, err
251	}
252	for _, s := range srvs {
253		lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
254		if err != nil {
255			err = handleDNSError(err, "A") // may become nil
256			if err == nil {
257				// If there are other SRV records, look them up and ignore this
258				// one that does not exist.
259				continue
260			}
261			return nil, err
262		}
263		for _, a := range lbAddrs {
264			ip, err := formatIP(a)
265			if err != nil {
266				return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
267			}
268			addr := ip + ":" + strconv.Itoa(int(s.Port))
269			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
270		}
271	}
272	return newAddrs, nil
273}
274
275func handleDNSError(err error, lookupType string) error {
276	dnsErr, ok := err.(*net.DNSError)
277	if ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
278		// Timeouts and temporary errors should be communicated to gRPC to
279		// attempt another DNS query (with backoff).  Other errors should be
280		// suppressed (they may represent the absence of a TXT record).
281		return nil
282	}
283	if err != nil {
284		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
285		logger.Info(err)
286	}
287	return err
288}
289
290func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
291	ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
292	if err != nil {
293		if envconfig.TXTErrIgnore {
294			return nil
295		}
296		if err = handleDNSError(err, "TXT"); err != nil {
297			return &serviceconfig.ParseResult{Err: err}
298		}
299		return nil
300	}
301	var res string
302	for _, s := range ss {
303		res += s
304	}
305
306	// TXT record must have "grpc_config=" attribute in order to be used as
307	// service config.
308	if !strings.HasPrefix(res, txtAttribute) {
309		logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
310		// This is not an error; it is the equivalent of not having a service
311		// config.
312		return nil
313	}
314	sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
315	return d.cc.ParseServiceConfig(sc)
316}
317
318func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
319	addrs, err := d.resolver.LookupHost(ctx, d.host)
320	if err != nil {
321		err = handleDNSError(err, "A")
322		return nil, err
323	}
324	newAddrs := make([]resolver.Address, 0, len(addrs))
325	for _, a := range addrs {
326		ip, err := formatIP(a)
327		if err != nil {
328			return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
329		}
330		addr := ip + ":" + d.port
331		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
332	}
333	return newAddrs, nil
334}
335
336func (d *dnsResolver) lookup() (*resolver.State, error) {
337	ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
338	defer cancel()
339	srv, srvErr := d.lookupSRV(ctx)
340	addrs, hostErr := d.lookupHost(ctx)
341	if hostErr != nil && (srvErr != nil || len(srv) == 0) {
342		return nil, hostErr
343	}
344
345	state := resolver.State{Addresses: addrs}
346	if len(srv) > 0 {
347		state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
348	}
349	if !d.disableServiceConfig {
350		state.ServiceConfig = d.lookupTXT(ctx)
351	}
352	return &state, nil
353}
354
// formatIP validates addr as the textual form of an IP address and prepares
// it for use in a "host:port" string. An IPv4 address is returned unchanged.
// Any other valid address (IPv6) is returned wrapped in square brackets. When
// addr is not a valid IP address, the parse error is returned along with an
// empty string.
func formatIP(addr string) (string, error) {
	parsed, err := netip.ParseAddr(addr)
	switch {
	case err != nil:
		return "", err
	case parsed.Is4():
		return addr, nil
	default:
		return "[" + addr + "]", nil
	}
}
369
370// parseTarget takes the user input target string and default port, returns
371// formatted host and port info. If target doesn't specify a port, set the port
372// to be the defaultPort. If target is in IPv6 format and host-name is enclosed
373// in square brackets, brackets are stripped when setting the host.
374// examples:
375// target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
376// target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
377// target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
378// target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
379func parseTarget(target, defaultPort string) (host, port string, err error) {
380	if target == "" {
381		return "", "", internal.ErrMissingAddr
382	}
383	if _, err := netip.ParseAddr(target); err == nil {
384		// target is an IPv4 or IPv6(without brackets) address
385		return target, defaultPort, nil
386	}
387	if host, port, err = net.SplitHostPort(target); err == nil {
388		if port == "" {
389			// If the port field is empty (target ends with colon), e.g. "[::1]:",
390			// this is an error.
391			return "", "", internal.ErrEndsWithColon
392		}
393		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
394		if host == "" {
395			// Keep consistent with net.Dial(): If the host is empty, as in ":80",
396			// the local system is assumed.
397			host = "localhost"
398		}
399		return host, port, nil
400	}
401	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
402		// target doesn't have port
403		return host, port, nil
404	}
405	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
406}
407
// rawChoice is one entry of the JSON list carried by the service config TXT
// record. Every field is a pointer so that an absent key (nil) can be told
// apart from a present one.
type rawChoice struct {
	// ClientLanguage lists the client languages this choice applies to.
	ClientLanguage *[]string `json:"clientLanguage,omitempty"`
	// Percentage is the share of clients, in percent, that pick this choice.
	Percentage *int `json:"percentage,omitempty"`
	// ClientHostName lists the client host names this choice applies to.
	ClientHostName *[]string `json:"clientHostName,omitempty"`
	// ServiceConfig is the undecoded service config JSON of this choice.
	ServiceConfig *json.RawMessage `json:"serviceConfig,omitempty"`
}
414
// containsString reports whether b passes the filter list a. A nil list
// means "no restriction" and matches everything. Otherwise b must be equal to
// one of the list's elements, so an empty non-nil list matches nothing.
func containsString(a *[]string, b string) bool {
	if a == nil {
		return true
	}
	list := *a
	for i := range list {
		if list[i] == b {
			return true
		}
	}
	return false
}
426
427func chosenByPercentage(a *int) bool {
428	if a == nil {
429		return true
430	}
431	return rand.IntN(100)+1 <= *a
432}
433
434func canaryingSC(js string) string {
435	if js == "" {
436		return ""
437	}
438	var rcs []rawChoice
439	err := json.Unmarshal([]byte(js), &rcs)
440	if err != nil {
441		logger.Warningf("dns: error parsing service config json: %v", err)
442		return ""
443	}
444	cliHostname, err := os.Hostname()
445	if err != nil {
446		logger.Warningf("dns: error getting client hostname: %v", err)
447		return ""
448	}
449	var sc string
450	for _, c := range rcs {
451		if !containsString(c.ClientLanguage, golang) ||
452			!chosenByPercentage(c.Percentage) ||
453			!containsString(c.ClientHostName, cliHostname) ||
454			c.ServiceConfig == nil {
455			continue
456		}
457		sc = string(*c.ServiceConfig)
458		break
459	}
460	return sc
461}