fix.go

   1// Copyright 2013 The Go Authors. All rights reserved.
   2// Use of this source code is governed by a BSD-style
   3// license that can be found in the LICENSE file.
   4
   5package imports
   6
   7import (
   8	"bytes"
   9	"context"
  10	"fmt"
  11	"go/ast"
  12	"go/build"
  13	"go/parser"
  14	"go/token"
  15	"io/ioutil"
  16	"log"
  17	"os"
  18	"os/exec"
  19	"path"
  20	"path/filepath"
  21	"sort"
  22	"strconv"
  23	"strings"
  24	"sync"
  25	"time"
  26	"unicode"
  27	"unicode/utf8"
  28
  29	"golang.org/x/tools/go/ast/astutil"
  30	"golang.org/x/tools/go/packages"
  31	"golang.org/x/tools/internal/gopathwalk"
  32)
  33
  34// Debug controls verbose logging.
  35var Debug = false
  36
  37// LocalPrefix is a comma-separated string of import path prefixes, which, if
  38// set, instructs Process to sort the import paths with the given prefixes
  39// into another group after 3rd-party packages.
  40var LocalPrefix string
  41
  42func localPrefixes() []string {
  43	if LocalPrefix != "" {
  44		return strings.Split(LocalPrefix, ",")
  45	}
  46	return nil
  47}
  48
  49// importToGroup is a list of functions which map from an import path to
  50// a group number.
  51var importToGroup = []func(importPath string) (num int, ok bool){
  52	func(importPath string) (num int, ok bool) {
  53		for _, p := range localPrefixes() {
  54			if strings.HasPrefix(importPath, p) || strings.TrimSuffix(p, "/") == importPath {
  55				return 3, true
  56			}
  57		}
  58		return
  59	},
  60	func(importPath string) (num int, ok bool) {
  61		if strings.HasPrefix(importPath, "appengine") {
  62			return 2, true
  63		}
  64		return
  65	},
  66	func(importPath string) (num int, ok bool) {
  67		if strings.Contains(importPath, ".") {
  68			return 1, true
  69		}
  70		return
  71	},
  72}
  73
  74func importGroup(importPath string) int {
  75	for _, fn := range importToGroup {
  76		if n, ok := fn(importPath); ok {
  77			return n
  78		}
  79	}
  80	return 0
  81}
  82
  83// An importInfo represents a single import statement.
  84type importInfo struct {
  85	importPath string // import path, e.g. "crypto/rand".
  86	name       string // import name, e.g. "crand", or "" if none.
  87}
  88
  89// A packageInfo represents what's known about a package.
  90type packageInfo struct {
  91	name    string          // real package name, if known.
  92	exports map[string]bool // known exports.
  93}
  94
  95// parseOtherFiles parses all the Go files in srcDir except filename, including
  96// test files if filename looks like a test.
  97func parseOtherFiles(fset *token.FileSet, srcDir, filename string) []*ast.File {
  98	// This could use go/packages but it doesn't buy much, and it fails
  99	// with https://golang.org/issue/26296 in LoadFiles mode in some cases.
 100	considerTests := strings.HasSuffix(filename, "_test.go")
 101
 102	fileBase := filepath.Base(filename)
 103	packageFileInfos, err := ioutil.ReadDir(srcDir)
 104	if err != nil {
 105		return nil
 106	}
 107
 108	var files []*ast.File
 109	for _, fi := range packageFileInfos {
 110		if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") {
 111			continue
 112		}
 113		if !considerTests && strings.HasSuffix(fi.Name(), "_test.go") {
 114			continue
 115		}
 116
 117		f, err := parser.ParseFile(fset, filepath.Join(srcDir, fi.Name()), nil, 0)
 118		if err != nil {
 119			continue
 120		}
 121
 122		files = append(files, f)
 123	}
 124
 125	return files
 126}
 127
 128// addGlobals puts the names of package vars into the provided map.
 129func addGlobals(f *ast.File, globals map[string]bool) {
 130	for _, decl := range f.Decls {
 131		genDecl, ok := decl.(*ast.GenDecl)
 132		if !ok {
 133			continue
 134		}
 135
 136		for _, spec := range genDecl.Specs {
 137			valueSpec, ok := spec.(*ast.ValueSpec)
 138			if !ok {
 139				continue
 140			}
 141			globals[valueSpec.Names[0].Name] = true
 142		}
 143	}
 144}
 145
 146// collectReferences builds a map of selector expressions, from
 147// left hand side (X) to a set of right hand sides (Sel).
 148func collectReferences(f *ast.File) references {
 149	refs := references{}
 150
 151	var visitor visitFn
 152	visitor = func(node ast.Node) ast.Visitor {
 153		if node == nil {
 154			return visitor
 155		}
 156		switch v := node.(type) {
 157		case *ast.SelectorExpr:
 158			xident, ok := v.X.(*ast.Ident)
 159			if !ok {
 160				break
 161			}
 162			if xident.Obj != nil {
 163				// If the parser can resolve it, it's not a package ref.
 164				break
 165			}
 166			if !ast.IsExported(v.Sel.Name) {
 167				// Whatever this is, it's not exported from a package.
 168				break
 169			}
 170			pkgName := xident.Name
 171			r := refs[pkgName]
 172			if r == nil {
 173				r = make(map[string]bool)
 174				refs[pkgName] = r
 175			}
 176			r[v.Sel.Name] = true
 177		}
 178		return visitor
 179	}
 180	ast.Walk(visitor, f)
 181	return refs
 182}
 183
 184// collectImports returns all the imports in f, keyed by their package name as
 185// determined by pathToName. Unnamed imports (., _) and "C" are ignored.
 186func collectImports(f *ast.File) []*importInfo {
 187	var imports []*importInfo
 188	for _, imp := range f.Imports {
 189		var name string
 190		if imp.Name != nil {
 191			name = imp.Name.Name
 192		}
 193		if imp.Path.Value == `"C"` || name == "_" || name == "." {
 194			continue
 195		}
 196		path := strings.Trim(imp.Path.Value, `"`)
 197		imports = append(imports, &importInfo{
 198			name:       name,
 199			importPath: path,
 200		})
 201	}
 202	return imports
 203}
 204
 205// findMissingImport searches pass's candidates for an import that provides
 206// pkg, containing all of syms.
 207func (p *pass) findMissingImport(pkg string, syms map[string]bool) *importInfo {
 208	for _, candidate := range p.candidates {
 209		pkgInfo, ok := p.knownPackages[candidate.importPath]
 210		if !ok {
 211			continue
 212		}
 213		if p.importIdentifier(candidate) != pkg {
 214			continue
 215		}
 216
 217		allFound := true
 218		for right := range syms {
 219			if !pkgInfo.exports[right] {
 220				allFound = false
 221				break
 222			}
 223		}
 224
 225		if allFound {
 226			return candidate
 227		}
 228	}
 229	return nil
 230}
 231
 232// references is set of references found in a Go file. The first map key is the
 233// left hand side of a selector expression, the second key is the right hand
 234// side, and the value should always be true.
 235type references map[string]map[string]bool
 236
 237// A pass contains all the inputs and state necessary to fix a file's imports.
 238// It can be modified in some ways during use; see comments below.
 239type pass struct {
 240	// Inputs. These must be set before a call to load, and not modified after.
 241	fset                 *token.FileSet // fset used to parse f and its siblings.
 242	f                    *ast.File      // the file being fixed.
 243	srcDir               string         // the directory containing f.
 244	fixEnv               *fixEnv        // the environment to use for go commands, etc.
 245	loadRealPackageNames bool           // if true, load package names from disk rather than guessing them.
 246	otherFiles           []*ast.File    // sibling files.
 247
 248	// Intermediate state, generated by load.
 249	existingImports map[string]*importInfo
 250	allRefs         references
 251	missingRefs     references
 252
 253	// Inputs to fix. These can be augmented between successive fix calls.
 254	lastTry       bool                    // indicates that this is the last call and fix should clean up as best it can.
 255	candidates    []*importInfo           // candidate imports in priority order.
 256	knownPackages map[string]*packageInfo // information about all known packages.
 257}
 258
 259// loadPackageNames saves the package names for everything referenced by imports.
 260func (p *pass) loadPackageNames(imports []*importInfo) error {
 261	var unknown []string
 262	for _, imp := range imports {
 263		if _, ok := p.knownPackages[imp.importPath]; ok {
 264			continue
 265		}
 266		unknown = append(unknown, imp.importPath)
 267	}
 268
 269	names, err := p.fixEnv.getResolver().loadPackageNames(unknown, p.srcDir)
 270	if err != nil {
 271		return err
 272	}
 273
 274	for path, name := range names {
 275		p.knownPackages[path] = &packageInfo{
 276			name:    name,
 277			exports: map[string]bool{},
 278		}
 279	}
 280	return nil
 281}
 282
 283// importIdentifier returns the identifier that imp will introduce. It will
 284// guess if the package name has not been loaded, e.g. because the source
 285// is not available.
 286func (p *pass) importIdentifier(imp *importInfo) string {
 287	if imp.name != "" {
 288		return imp.name
 289	}
 290	known := p.knownPackages[imp.importPath]
 291	if known != nil && known.name != "" {
 292		return known.name
 293	}
 294	return importPathToAssumedName(imp.importPath)
 295}
 296
 297// load reads in everything necessary to run a pass, and reports whether the
 298// file already has all the imports it needs. It fills in p.missingRefs with the
 299// file's missing symbols, if any, or removes unused imports if not.
 300func (p *pass) load() bool {
 301	p.knownPackages = map[string]*packageInfo{}
 302	p.missingRefs = references{}
 303	p.existingImports = map[string]*importInfo{}
 304
 305	// Load basic information about the file in question.
 306	p.allRefs = collectReferences(p.f)
 307
 308	// Load stuff from other files in the same package:
 309	// global variables so we know they don't need resolving, and imports
 310	// that we might want to mimic.
 311	globals := map[string]bool{}
 312	for _, otherFile := range p.otherFiles {
 313		// Don't load globals from files that are in the same directory
 314		// but a different package. Using them to suggest imports is OK.
 315		if p.f.Name.Name == otherFile.Name.Name {
 316			addGlobals(otherFile, globals)
 317		}
 318		p.candidates = append(p.candidates, collectImports(otherFile)...)
 319	}
 320
 321	// Resolve all the import paths we've seen to package names, and store
 322	// f's imports by the identifier they introduce.
 323	imports := collectImports(p.f)
 324	if p.loadRealPackageNames {
 325		err := p.loadPackageNames(append(imports, p.candidates...))
 326		if err != nil {
 327			if Debug {
 328				log.Printf("loading package names: %v", err)
 329			}
 330			return false
 331		}
 332	}
 333	for _, imp := range imports {
 334		p.existingImports[p.importIdentifier(imp)] = imp
 335	}
 336
 337	// Find missing references.
 338	for left, rights := range p.allRefs {
 339		if globals[left] {
 340			continue
 341		}
 342		_, ok := p.existingImports[left]
 343		if !ok {
 344			p.missingRefs[left] = rights
 345			continue
 346		}
 347	}
 348	if len(p.missingRefs) != 0 {
 349		return false
 350	}
 351
 352	return p.fix()
 353}
 354
 355// fix attempts to satisfy missing imports using p.candidates. If it finds
 356// everything, or if p.lastTry is true, it adds the imports it found,
 357// removes anything unused, and returns true.
 358func (p *pass) fix() bool {
 359	// Find missing imports.
 360	var selected []*importInfo
 361	for left, rights := range p.missingRefs {
 362		if imp := p.findMissingImport(left, rights); imp != nil {
 363			selected = append(selected, imp)
 364		}
 365	}
 366
 367	if !p.lastTry && len(selected) != len(p.missingRefs) {
 368		return false
 369	}
 370
 371	// Found everything, or giving up. Add the new imports and remove any unused.
 372	for _, imp := range p.existingImports {
 373		// We deliberately ignore globals here, because we can't be sure
 374		// they're in the same package. People do things like put multiple
 375		// main packages in the same directory, and we don't want to
 376		// remove imports if they happen to have the same name as a var in
 377		// a different package.
 378		if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok {
 379			astutil.DeleteNamedImport(p.fset, p.f, imp.name, imp.importPath)
 380		}
 381	}
 382
 383	for _, imp := range selected {
 384		astutil.AddNamedImport(p.fset, p.f, imp.name, imp.importPath)
 385	}
 386
 387	if p.loadRealPackageNames {
 388		for _, imp := range p.f.Imports {
 389			if imp.Name != nil {
 390				continue
 391			}
 392			path := strings.Trim(imp.Path.Value, `""`)
 393			ident := p.importIdentifier(&importInfo{importPath: path})
 394			if ident != importPathToAssumedName(path) {
 395				imp.Name = &ast.Ident{Name: ident, NamePos: imp.Pos()}
 396			}
 397		}
 398	}
 399
 400	return true
 401}
 402
 403// assumeSiblingImportsValid assumes that siblings' use of packages is valid,
 404// adding the exports they use.
 405func (p *pass) assumeSiblingImportsValid() {
 406	for _, f := range p.otherFiles {
 407		refs := collectReferences(f)
 408		imports := collectImports(f)
 409		importsByName := map[string]*importInfo{}
 410		for _, imp := range imports {
 411			importsByName[p.importIdentifier(imp)] = imp
 412		}
 413		for left, rights := range refs {
 414			if imp, ok := importsByName[left]; ok {
 415				if _, ok := stdlib[imp.importPath]; ok {
 416					// We have the stdlib in memory; no need to guess.
 417					rights = stdlib[imp.importPath]
 418				}
 419				p.addCandidate(imp, &packageInfo{
 420					// no name; we already know it.
 421					exports: rights,
 422				})
 423			}
 424		}
 425	}
 426}
 427
 428// addCandidate adds a candidate import to p, and merges in the information
 429// in pkg.
 430func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) {
 431	p.candidates = append(p.candidates, imp)
 432	if existing, ok := p.knownPackages[imp.importPath]; ok {
 433		if existing.name == "" {
 434			existing.name = pkg.name
 435		}
 436		for export := range pkg.exports {
 437			existing.exports[export] = true
 438		}
 439	} else {
 440		p.knownPackages[imp.importPath] = pkg
 441	}
 442}
 443
 444// fixImports adds and removes imports from f so that all its references are
 445// satisfied and there are no unused imports.
 446//
 447// This is declared as a variable rather than a function so goimports can
 448// easily be extended by adding a file with an init function.
 449var fixImports = fixImportsDefault
 450
// fixImportsDefault is the default value of fixImports. It works on f in
// place, escalating through progressively more expensive strategies and
// returning nil as soon as one of them reports success:
//
//  1. look at f alone, guessing package names from import paths;
//  2. add the sibling files' globals and imports, then stdlib candidates;
//  3. repeat with real package names loaded through env;
//  4. search for external candidates and apply whatever was found.
//
// An error is returned only when filename cannot be made absolute or when
// the external candidate search fails.
func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *fixEnv) error {
	abs, err := filepath.Abs(filename)
	if err != nil {
		return err
	}
	srcDir := filepath.Dir(abs)
	if Debug {
		log.Printf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir)
	}

	// First pass: looking only at f, and using the naive algorithm to
	// derive package names from import paths, see if the file is already
	// complete. We can't add any imports yet, because we don't know
	// if missing references are actually package vars.
	p := &pass{fset: fset, f: f, srcDir: srcDir}
	if p.load() {
		return nil
	}

	// Parsed once and shared by the second and third passes.
	otherFiles := parseOtherFiles(fset, srcDir, filename)

	// Second pass: add information from other files in the same package,
	// like their package vars and imports.
	p.otherFiles = otherFiles
	if p.load() {
		return nil
	}

	// Now we can try adding imports from the stdlib.
	p.assumeSiblingImportsValid()
	addStdlibCandidates(p, p.missingRefs)
	if p.fix() {
		return nil
	}

	// Third pass: get real package names where we had previously used
	// the naive algorithm. This is the first step that will use the
	// environment, so we provide it here for the first time.
	// A fresh pass is used, so candidates gathered above are not carried over.
	p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env}
	p.loadRealPackageNames = true
	p.otherFiles = otherFiles
	if p.load() {
		return nil
	}

	addStdlibCandidates(p, p.missingRefs)
	p.assumeSiblingImportsValid()
	if p.fix() {
		return nil
	}

	// Go look for candidates in $GOPATH, etc. We don't necessarily load
	// the real exports of sibling imports, so keep assuming their contents.
	if err := addExternalCandidates(p, p.missingRefs, filename); err != nil {
		return err
	}

	// Final attempt: with lastTry set, fix applies what it has and always
	// returns true, so its result is not checked.
	p.lastTry = true
	p.fix()
	return nil
}
 512
 513// fixEnv contains environment variables and settings that affect the use of
 514// the go command, the go/build package, etc.
 515type fixEnv struct {
 516	// If non-empty, these will be used instead of the
 517	// process-wide values.
 518	GOPATH, GOROOT, GO111MODULE, GOPROXY, GOFLAGS string
 519	WorkingDir                                    string
 520
 521	// If true, use go/packages regardless of the environment.
 522	ForceGoPackages bool
 523
 524	resolver resolver
 525}
 526
 527func (e *fixEnv) env() []string {
 528	env := os.Environ()
 529	add := func(k, v string) {
 530		if v != "" {
 531			env = append(env, k+"="+v)
 532		}
 533	}
 534	add("GOPATH", e.GOPATH)
 535	add("GOROOT", e.GOROOT)
 536	add("GO111MODULE", e.GO111MODULE)
 537	add("GOPROXY", e.GOPROXY)
 538	add("GOFLAGS", e.GOFLAGS)
 539	if e.WorkingDir != "" {
 540		add("PWD", e.WorkingDir)
 541	}
 542	return env
 543}
 544
 545func (e *fixEnv) getResolver() resolver {
 546	if e.resolver != nil {
 547		return e.resolver
 548	}
 549	if e.ForceGoPackages {
 550		return &goPackagesResolver{env: e}
 551	}
 552
 553	out, err := e.invokeGo("env", "GOMOD")
 554	if err != nil || len(bytes.TrimSpace(out.Bytes())) == 0 {
 555		return &gopathResolver{env: e}
 556	}
 557	return &moduleResolver{env: e}
 558}
 559
 560func (e *fixEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config {
 561	return &packages.Config{
 562		Mode: mode,
 563		Dir:  e.WorkingDir,
 564		Env:  e.env(),
 565	}
 566}
 567
 568func (e *fixEnv) buildContext() *build.Context {
 569	ctx := build.Default
 570	ctx.GOROOT = e.GOROOT
 571	ctx.GOPATH = e.GOPATH
 572	return &ctx
 573}
 574
 575func (e *fixEnv) invokeGo(args ...string) (*bytes.Buffer, error) {
 576	cmd := exec.Command("go", args...)
 577	stdout := &bytes.Buffer{}
 578	stderr := &bytes.Buffer{}
 579	cmd.Stdout = stdout
 580	cmd.Stderr = stderr
 581	cmd.Env = e.env()
 582	cmd.Dir = e.WorkingDir
 583
 584	if Debug {
 585		defer func(start time.Time) { log.Printf("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now())
 586	}
 587	if err := cmd.Run(); err != nil {
 588		return nil, fmt.Errorf("running go: %v (stderr:\n%s)", err, stderr)
 589	}
 590	return stdout, nil
 591}
 592
 593func cmdDebugStr(cmd *exec.Cmd) string {
 594	env := make(map[string]string)
 595	for _, kv := range cmd.Env {
 596		split := strings.Split(kv, "=")
 597		k, v := split[0], split[1]
 598		env[k] = v
 599	}
 600
 601	return fmt.Sprintf("GOROOT=%v GOPATH=%v GO111MODULE=%v GOPROXY=%v PWD=%v go %v", env["GOROOT"], env["GOPATH"], env["GO111MODULE"], env["GOPROXY"], env["PWD"], cmd.Args)
 602}
 603
 604func addStdlibCandidates(pass *pass, refs references) {
 605	add := func(pkg string) {
 606		pass.addCandidate(
 607			&importInfo{importPath: pkg},
 608			&packageInfo{name: path.Base(pkg), exports: stdlib[pkg]})
 609	}
 610	for left := range refs {
 611		if left == "rand" {
 612			// Make sure we try crypto/rand before math/rand.
 613			add("crypto/rand")
 614			add("math/rand")
 615			continue
 616		}
 617		for importPath := range stdlib {
 618			if path.Base(importPath) == left {
 619				add(importPath)
 620			}
 621		}
 622	}
 623}
 624
 625// A resolver does the build-system-specific parts of goimports.
 626type resolver interface {
 627	// loadPackageNames loads the package names in importPaths.
 628	loadPackageNames(importPaths []string, srcDir string) (map[string]string, error)
 629	// scan finds (at least) the packages satisfying refs. The returned slice is unordered.
 630	scan(refs references) ([]*pkg, error)
 631}
 632
 633// gopathResolver implements resolver for GOPATH and module workspaces using go/packages.
 634type goPackagesResolver struct {
 635	env *fixEnv
 636}
 637
 638func (r *goPackagesResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) {
 639	cfg := r.env.newPackagesConfig(packages.LoadFiles)
 640	pkgs, err := packages.Load(cfg, importPaths...)
 641	if err != nil {
 642		return nil, err
 643	}
 644	names := map[string]string{}
 645	for _, pkg := range pkgs {
 646		names[VendorlessPath(pkg.PkgPath)] = pkg.Name
 647	}
 648	// We may not have found all the packages. Guess the rest.
 649	for _, path := range importPaths {
 650		if _, ok := names[path]; ok {
 651			continue
 652		}
 653		names[path] = importPathToAssumedName(path)
 654	}
 655	return names, nil
 656
 657}
 658
 659func (r *goPackagesResolver) scan(refs references) ([]*pkg, error) {
 660	var loadQueries []string
 661	for pkgName := range refs {
 662		loadQueries = append(loadQueries, "iamashamedtousethedisabledqueryname="+pkgName)
 663	}
 664	sort.Strings(loadQueries)
 665	cfg := r.env.newPackagesConfig(packages.LoadFiles)
 666	goPackages, err := packages.Load(cfg, loadQueries...)
 667	if err != nil {
 668		return nil, err
 669	}
 670
 671	var scan []*pkg
 672	for _, goPackage := range goPackages {
 673		scan = append(scan, &pkg{
 674			dir:             filepath.Dir(goPackage.CompiledGoFiles[0]),
 675			importPathShort: VendorlessPath(goPackage.PkgPath),
 676			goPackage:       goPackage,
 677		})
 678	}
 679	return scan, nil
 680}
 681
// addExternalCandidates scans for packages with the resolver from
// pass.fixEnv and, for every entry of refs, searches the scan result
// concurrently for a package that provides the referenced symbols. Each
// package found is added to pass as a candidate.
//
// It returns the scan error, or the first error reported by a search; the
// first search error also cancels the searches still running. All goroutines
// started here have finished by the time the function returns.
func addExternalCandidates(pass *pass, refs references, filename string) error {
	dirScan, err := pass.fixEnv.getResolver().scan(refs)
	if err != nil {
		return err
	}

	// Search for imports matching potential package references.
	type result struct {
		imp *importInfo
		pkg *packageInfo
	}
	// One slot per reference: each goroutine sends at most once, so a send
	// never blocks.
	results := make(chan result, len(refs))

	ctx, cancel := context.WithCancel(context.TODO())
	var wg sync.WaitGroup
	// On every return path, stop outstanding searches and wait for them.
	defer func() {
		cancel()
		wg.Wait()
	}()
	var (
		firstErr     error
		firstErrOnce sync.Once
	)
	for pkgName, symbols := range refs {
		wg.Add(1)
		go func(pkgName string, symbols map[string]bool) {
			defer wg.Done()

			found, err := findImport(ctx, pass.fixEnv, dirScan, pkgName, symbols, filename)

			if err != nil {
				// Keep only the first error and stop the other searches.
				firstErrOnce.Do(func() {
					firstErr = err
					cancel()
				})
				return
			}

			if found == nil {
				return // No matching package.
			}

			imp := &importInfo{
				importPath: found.importPathShort,
			}

			// Only the symbols that were asked for are recorded as
			// exports, not the package's full export set.
			pkg := &packageInfo{
				name:    pkgName,
				exports: symbols,
			}
			results <- result{imp, pkg}
		}(pkgName, symbols)
	}
	// Close results once every search is done so the loop below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()

	// Candidates are added here, on the calling goroutine only.
	for result := range results {
		pass.addCandidate(result.imp, result.pkg)
	}
	// results is closed only after all searches have finished, so firstErr
	// is no longer being written when it is read here.
	return firstErr
}
 745
 746// notIdentifier reports whether ch is an invalid identifier character.
 747func notIdentifier(ch rune) bool {
 748	return !('a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' ||
 749		'0' <= ch && ch <= '9' ||
 750		ch == '_' ||
 751		ch >= utf8.RuneSelf && (unicode.IsLetter(ch) || unicode.IsDigit(ch)))
 752}
 753
 754// importPathToAssumedName returns the assumed package name of an import path.
 755// It does this using only string parsing of the import path.
 756// It picks the last element of the path that does not look like a major
 757// version, and then picks the valid identifier off the start of that element.
 758// It is used to determine if a local rename should be added to an import for
 759// clarity.
 760// This function could be moved to a standard package and exported if we want
 761// for use in other tools.
 762func importPathToAssumedName(importPath string) string {
 763	base := path.Base(importPath)
 764	if strings.HasPrefix(base, "v") {
 765		if _, err := strconv.Atoi(base[1:]); err == nil {
 766			dir := path.Dir(importPath)
 767			if dir != "." {
 768				base = path.Base(dir)
 769			}
 770		}
 771	}
 772	base = strings.TrimPrefix(base, "go-")
 773	if i := strings.IndexFunc(base, notIdentifier); i >= 0 {
 774		base = base[:i]
 775	}
 776	return base
 777}
 778
 779// gopathResolver implements resolver for GOPATH workspaces.
 780type gopathResolver struct {
 781	env *fixEnv
 782}
 783
 784func (r *gopathResolver) loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) {
 785	names := map[string]string{}
 786	for _, path := range importPaths {
 787		names[path] = importPathToName(r.env, path, srcDir)
 788	}
 789	return names, nil
 790}
 791
 792// importPathToNameGoPath finds out the actual package name, as declared in its .go files.
 793// If there's a problem, it returns "".
 794func importPathToName(env *fixEnv, importPath, srcDir string) (packageName string) {
 795	// Fast path for standard library without going to disk.
 796	if _, ok := stdlib[importPath]; ok {
 797		return path.Base(importPath) // stdlib packages always match their paths.
 798	}
 799
 800	buildPkg, err := env.buildContext().Import(importPath, srcDir, build.FindOnly)
 801	if err != nil {
 802		return ""
 803	}
 804	pkgName, err := packageDirToName(buildPkg.Dir)
 805	if err != nil {
 806		return ""
 807	}
 808	return pkgName
 809}
 810
 811// packageDirToName is a faster version of build.Import if
 812// the only thing desired is the package name. It uses build.FindOnly
 813// to find the directory and then only parses one file in the package,
 814// trusting that the files in the directory are consistent.
 815func packageDirToName(dir string) (packageName string, err error) {
 816	d, err := os.Open(dir)
 817	if err != nil {
 818		return "", err
 819	}
 820	names, err := d.Readdirnames(-1)
 821	d.Close()
 822	if err != nil {
 823		return "", err
 824	}
 825	sort.Strings(names) // to have predictable behavior
 826	var lastErr error
 827	var nfile int
 828	for _, name := range names {
 829		if !strings.HasSuffix(name, ".go") {
 830			continue
 831		}
 832		if strings.HasSuffix(name, "_test.go") {
 833			continue
 834		}
 835		nfile++
 836		fullFile := filepath.Join(dir, name)
 837
 838		fset := token.NewFileSet()
 839		f, err := parser.ParseFile(fset, fullFile, nil, parser.PackageClauseOnly)
 840		if err != nil {
 841			lastErr = err
 842			continue
 843		}
 844		pkgName := f.Name.Name
 845		if pkgName == "documentation" {
 846			// Special case from go/build.ImportDir, not
 847			// handled by ctx.MatchFile.
 848			continue
 849		}
 850		if pkgName == "main" {
 851			// Also skip package main, assuming it's a +build ignore generator or example.
 852			// Since you can't import a package main anyway, there's no harm here.
 853			continue
 854		}
 855		return pkgName, nil
 856	}
 857	if lastErr != nil {
 858		return "", lastErr
 859	}
 860	return "", fmt.Errorf("no importable package found in %d Go files", nfile)
 861}
 862
 863type pkg struct {
 864	goPackage       *packages.Package
 865	dir             string // absolute file path to pkg directory ("/usr/lib/go/src/net/http")
 866	importPathShort string // vendorless import path ("net/http", "a/b")
 867}
 868
 869type pkgDistance struct {
 870	pkg      *pkg
 871	distance int // relative distance to target
 872}
 873
 874// byDistanceOrImportPathShortLength sorts by relative distance breaking ties
 875// on the short import path length and then the import string itself.
 876type byDistanceOrImportPathShortLength []pkgDistance
 877
 878func (s byDistanceOrImportPathShortLength) Len() int { return len(s) }
 879func (s byDistanceOrImportPathShortLength) Less(i, j int) bool {
 880	di, dj := s[i].distance, s[j].distance
 881	if di == -1 {
 882		return false
 883	}
 884	if dj == -1 {
 885		return true
 886	}
 887	if di != dj {
 888		return di < dj
 889	}
 890
 891	vi, vj := s[i].pkg.importPathShort, s[j].pkg.importPathShort
 892	if len(vi) != len(vj) {
 893		return len(vi) < len(vj)
 894	}
 895	return vi < vj
 896}
 897func (s byDistanceOrImportPathShortLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
 898
 899func distance(basepath, targetpath string) int {
 900	p, err := filepath.Rel(basepath, targetpath)
 901	if err != nil {
 902		return -1
 903	}
 904	if p == "." {
 905		return 0
 906	}
 907	return strings.Count(p, string(filepath.Separator)) + 1
 908}
 909
 910func (r *gopathResolver) scan(_ references) ([]*pkg, error) {
 911	dupCheck := make(map[string]bool)
 912	var result []*pkg
 913
 914	var mu sync.Mutex
 915
 916	add := func(root gopathwalk.Root, dir string) {
 917		mu.Lock()
 918		defer mu.Unlock()
 919
 920		if _, dup := dupCheck[dir]; dup {
 921			return
 922		}
 923		dupCheck[dir] = true
 924		importpath := filepath.ToSlash(dir[len(root.Path)+len("/"):])
 925		result = append(result, &pkg{
 926			importPathShort: VendorlessPath(importpath),
 927			dir:             dir,
 928		})
 929	}
 930	gopathwalk.Walk(gopathwalk.SrcDirsRoots(r.env.buildContext()), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false})
 931	return result, nil
 932}
 933
 934// VendorlessPath returns the devendorized version of the import path ipath.
 935// For example, VendorlessPath("foo/bar/vendor/a/b") returns "a/b".
 936func VendorlessPath(ipath string) string {
 937	// Devendorize for use in import statement.
 938	if i := strings.LastIndex(ipath, "/vendor/"); i >= 0 {
 939		return ipath[i+len("/vendor/"):]
 940	}
 941	if strings.HasPrefix(ipath, "vendor/") {
 942		return ipath[len("vendor/"):]
 943	}
 944	return ipath
 945}
 946
 947// loadExports returns the set of exported symbols in the package at dir.
 948// It returns nil on error or if the package name in dir does not match expectPackage.
 949func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pkg) (map[string]bool, error) {
 950	if Debug {
 951		log.Printf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage)
 952	}
 953	if pkg.goPackage != nil {
 954		exports := map[string]bool{}
 955		fset := token.NewFileSet()
 956		for _, fname := range pkg.goPackage.CompiledGoFiles {
 957			f, err := parser.ParseFile(fset, fname, nil, 0)
 958			if err != nil {
 959				return nil, fmt.Errorf("parsing %s: %v", fname, err)
 960			}
 961			for name := range f.Scope.Objects {
 962				if ast.IsExported(name) {
 963					exports[name] = true
 964				}
 965			}
 966		}
 967		return exports, nil
 968	}
 969
 970	exports := make(map[string]bool)
 971
 972	// Look for non-test, buildable .go files which could provide exports.
 973	all, err := ioutil.ReadDir(pkg.dir)
 974	if err != nil {
 975		return nil, err
 976	}
 977	var files []os.FileInfo
 978	for _, fi := range all {
 979		name := fi.Name()
 980		if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
 981			continue
 982		}
 983		match, err := env.buildContext().MatchFile(pkg.dir, fi.Name())
 984		if err != nil || !match {
 985			continue
 986		}
 987		files = append(files, fi)
 988	}
 989
 990	if len(files) == 0 {
 991		return nil, fmt.Errorf("dir %v contains no buildable, non-test .go files", pkg.dir)
 992	}
 993
 994	fset := token.NewFileSet()
 995	for _, fi := range files {
 996		select {
 997		case <-ctx.Done():
 998			return nil, ctx.Err()
 999		default:
1000		}
1001
1002		fullFile := filepath.Join(pkg.dir, fi.Name())
1003		f, err := parser.ParseFile(fset, fullFile, nil, 0)
1004		if err != nil {
1005			return nil, fmt.Errorf("parsing %s: %v", fullFile, err)
1006		}
1007		pkgName := f.Name.Name
1008		if pkgName == "documentation" {
1009			// Special case from go/build.ImportDir, not
1010			// handled by MatchFile above.
1011			continue
1012		}
1013		if pkgName != expectPackage {
1014			return nil, fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", pkg.dir, expectPackage, pkgName)
1015		}
1016		for name := range f.Scope.Objects {
1017			if ast.IsExported(name) {
1018				exports[name] = true
1019			}
1020		}
1021	}
1022
1023	if Debug {
1024		exportList := make([]string, 0, len(exports))
1025		for k := range exports {
1026			exportList = append(exportList, k)
1027		}
1028		sort.Strings(exportList)
1029		log.Printf("loaded exports in dir %v (package %v): %v", pkg.dir, expectPackage, strings.Join(exportList, ", "))
1030	}
1031	return exports, nil
1032}
1033
1034// findImport searches for a package with the given symbols.
1035// If no package is found, findImport returns ("", false, nil)
1036func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) {
1037	pkgDir, err := filepath.Abs(filename)
1038	if err != nil {
1039		return nil, err
1040	}
1041	pkgDir = filepath.Dir(pkgDir)
1042
1043	// Find candidate packages, looking only at their directory names first.
1044	var candidates []pkgDistance
1045	for _, pkg := range dirScan {
1046		if pkg.dir != pkgDir && pkgIsCandidate(filename, pkgName, pkg) {
1047			candidates = append(candidates, pkgDistance{
1048				pkg:      pkg,
1049				distance: distance(pkgDir, pkg.dir),
1050			})
1051		}
1052	}
1053
1054	// Sort the candidates by their import package length,
1055	// assuming that shorter package names are better than long
1056	// ones.  Note that this sorts by the de-vendored name, so
1057	// there's no "penalty" for vendoring.
1058	sort.Sort(byDistanceOrImportPathShortLength(candidates))
1059	if Debug {
1060		for i, c := range candidates {
1061			log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
1062		}
1063	}
1064
1065	// Collect exports for packages with matching names.
1066
1067	rescv := make([]chan *pkg, len(candidates))
1068	for i := range candidates {
1069		rescv[i] = make(chan *pkg, 1)
1070	}
1071	const maxConcurrentPackageImport = 4
1072	loadExportsSem := make(chan struct{}, maxConcurrentPackageImport)
1073
1074	ctx, cancel := context.WithCancel(ctx)
1075	var wg sync.WaitGroup
1076	defer func() {
1077		cancel()
1078		wg.Wait()
1079	}()
1080
1081	wg.Add(1)
1082	go func() {
1083		defer wg.Done()
1084		for i, c := range candidates {
1085			select {
1086			case loadExportsSem <- struct{}{}:
1087			case <-ctx.Done():
1088				return
1089			}
1090
1091			wg.Add(1)
1092			go func(c pkgDistance, resc chan<- *pkg) {
1093				defer func() {
1094					<-loadExportsSem
1095					wg.Done()
1096				}()
1097
1098				exports, err := loadExports(ctx, env, pkgName, c.pkg)
1099				if err != nil {
1100					if Debug {
1101						log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
1102					}
1103					resc <- nil
1104					return
1105				}
1106
1107				// If it doesn't have the right
1108				// symbols, send nil to mean no match.
1109				for symbol := range symbols {
1110					if !exports[symbol] {
1111						resc <- nil
1112						return
1113					}
1114				}
1115				resc <- c.pkg
1116			}(c, rescv[i])
1117		}
1118	}()
1119
1120	for _, resc := range rescv {
1121		pkg := <-resc
1122		if pkg == nil {
1123			continue
1124		}
1125		return pkg, nil
1126	}
1127	return nil, nil
1128}
1129
1130// pkgIsCandidate reports whether pkg is a candidate for satisfying the
1131// finding which package pkgIdent in the file named by filename is trying
1132// to refer to.
1133//
1134// This check is purely lexical and is meant to be as fast as possible
1135// because it's run over all $GOPATH directories to filter out poor
1136// candidates in order to limit the CPU and I/O later parsing the
1137// exports in candidate packages.
1138//
1139// filename is the file being formatted.
1140// pkgIdent is the package being searched for, like "client" (if
1141// searching for "client.New")
1142func pkgIsCandidate(filename, pkgIdent string, pkg *pkg) bool {
1143	// Check "internal" and "vendor" visibility:
1144	if !canUse(filename, pkg.dir) {
1145		return false
1146	}
1147
1148	// Speed optimization to minimize disk I/O:
1149	// the last two components on disk must contain the
1150	// package name somewhere.
1151	//
1152	// This permits mismatch naming like directory
1153	// "go-foo" being package "foo", or "pkg.v3" being "pkg",
1154	// or directory "google.golang.org/api/cloudbilling/v1"
1155	// being package "cloudbilling", but doesn't
1156	// permit a directory "foo" to be package
1157	// "bar", which is strongly discouraged
1158	// anyway. There's no reason goimports needs
1159	// to be slow just to accommodate that.
1160	lastTwo := lastTwoComponents(pkg.importPathShort)
1161	if strings.Contains(lastTwo, pkgIdent) {
1162		return true
1163	}
1164	if hasHyphenOrUpperASCII(lastTwo) && !hasHyphenOrUpperASCII(pkgIdent) {
1165		lastTwo = lowerASCIIAndRemoveHyphen(lastTwo)
1166		if strings.Contains(lastTwo, pkgIdent) {
1167			return true
1168		}
1169	}
1170
1171	return false
1172}
1173
// hasHyphenOrUpperASCII reports whether s contains a '-' byte or an
// ASCII upper-case letter ('A' through 'Z').
func hasHyphenOrUpperASCII(s string) bool {
	for _, c := range []byte(s) {
		switch {
		case c == '-':
			return true
		case c >= 'A' && c <= 'Z':
			return true
		}
	}
	return false
}
1183
// lowerASCIIAndRemoveHyphen returns s with every '-' byte removed and
// every ASCII upper-case letter replaced by its lower-case form. All
// other bytes are copied through unchanged.
func lowerASCIIAndRemoveHyphen(s string) (ret string) {
	out := make([]byte, 0, len(s))
	for _, c := range []byte(s) {
		if c == '-' {
			continue
		}
		if c >= 'A' && c <= 'Z' {
			c += 'a' - 'A'
		}
		out = append(out, c)
	}
	return string(out)
}
1199
// canUse reports whether the package in dir is usable from filename,
// respecting the Go "internal" and "vendor" visibility rules.
func canUse(filename, dir string) bool {
	// Cheap pre-check that allocates nothing: a dir that mentions neither
	// "vendor" nor "internal" anywhere needs no further inspection. Names
	// such as "notinternal" get past this test and are sorted out by the
	// stricter checks below.
	if !strings.Contains(dir, "vendor") && !strings.Contains(dir, "internal") {
		return true
	}

	// restricted reports whether the slash-separated path p contains
	// "/vendor/" or "/internal/", or ends in "/internal".
	restricted := func(p string) bool {
		return strings.Contains(p, "/vendor/") ||
			strings.Contains(p, "/internal/") ||
			strings.HasSuffix(p, "/internal")
	}

	if !restricted(filepath.ToSlash(dir)) {
		return true
	}

	// A vendor or internal directory is visible only from the children of
	// its parent. Expressed as a relative path from the current location
	// to the target, ../vendor and ../internal are acceptable, while
	// ../foo/vendor, ../foo/internal, bar/vendor and bar/internal are not.
	// So once every leading ../ has been stripped, vendor or internal may
	// only appear at the very start of what is left.
	fileAbs, err := filepath.Abs(filename)
	if err != nil {
		return false
	}
	dirAbs, err := filepath.Abs(dir)
	if err != nil {
		return false
	}
	relPath, err := filepath.Rel(fileAbs, dirAbs)
	if err != nil {
		return false
	}

	relPath = filepath.ToSlash(relPath)
	const up = "../"
	if cut := strings.LastIndex(relPath, up); cut >= 0 {
		relPath = relPath[cut+len(up):]
	}
	return !restricted(relPath)
}
1239
// lastTwoComponents returns at most the last two path components
// of v, using either / or \ as the path separator. When v holds two or
// more separators, the result begins at the second-to-last one (the
// separator itself is included); otherwise v is returned unchanged.
func lastTwoComponents(v string) string {
	remaining := 2 // separators still to be found, scanning from the end
	for pos := len(v); pos > 0; pos-- {
		switch v[pos-1] {
		case '/', '\\':
			remaining--
			if remaining == 0 {
				return v[pos-1:]
			}
		}
	}
	return v
}
1254
// visitFn is a function type with the same shape as the Visit method of
// ast.Visitor, so that a plain function can be used as a visitor.
type visitFn func(node ast.Node) ast.Visitor

// Visit calls f with n and returns whatever f returns.
func (f visitFn) Visit(n ast.Node) ast.Visitor {
	return f(n)
}