imports.go

  1// Copyright 2013 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5//go:generate go run mkstdlib.go
  6
  7// Package imports implements a Go pretty-printer (like package "go/format")
  8// that also adds or removes import statements as necessary.
  9package imports // import "golang.org/x/tools/imports"
 10
 11import (
 12	"bufio"
 13	"bytes"
 14	"fmt"
 15	"go/ast"
 16	"go/build"
 17	"go/format"
 18	"go/parser"
 19	"go/printer"
 20	"go/token"
 21	"io"
 22	"io/ioutil"
 23	"regexp"
 24	"strconv"
 25	"strings"
 26
 27	"golang.org/x/tools/go/ast/astutil"
 28)
 29
// Options specifies options for processing files.
//
// A nil *Options passed to Process is treated as
// &Options{Comments: true, TabIndent: true, TabWidth: 8}.
type Options struct {
	// Fragment accepts a fragment of a source file (no package statement):
	// the input may be a declaration list or a statement list.
	Fragment  bool // Accept fragment of a source file (no package statement)
	AllErrors bool // Report all errors (not just the first 10 on different lines)

	// Comments enables parser.ParseComments, so comments survive into the output.
	Comments  bool // Print comments (true if nil *Options provided)
	TabIndent bool // Use tabs for indent (true if nil *Options provided)
	TabWidth  int  // Tab width (8 if nil *Options provided)

	// FormatOnly skips the import-fixing step; imports are still sorted
	// and grouped, and the output is still formatted.
	FormatOnly bool // Disable the insertion and deletion of imports
}
 41
 42// Process formats and adjusts imports for the provided file.
 43// If opt is nil the defaults are used.
 44//
 45// Note that filename's directory influences which imports can be chosen,
 46// so it is important that filename be accurate.
 47// To process data ``as if'' it were in filename, pass the data as a non-nil src.
 48func Process(filename string, src []byte, opt *Options) ([]byte, error) {
 49	env := &fixEnv{GOPATH: build.Default.GOPATH, GOROOT: build.Default.GOROOT}
 50	return process(filename, src, opt, env)
 51}
 52
 53func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, error) {
 54	if opt == nil {
 55		opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
 56	}
 57	if src == nil {
 58		b, err := ioutil.ReadFile(filename)
 59		if err != nil {
 60			return nil, err
 61		}
 62		src = b
 63	}
 64
 65	fileSet := token.NewFileSet()
 66	file, adjust, err := parse(fileSet, filename, src, opt)
 67	if err != nil {
 68		return nil, err
 69	}
 70
 71	if !opt.FormatOnly {
 72		if err := fixImports(fileSet, file, filename, env); err != nil {
 73			return nil, err
 74		}
 75	}
 76
 77	sortImports(fileSet, file)
 78	imps := astutil.Imports(fileSet, file)
 79	var spacesBefore []string // import paths we need spaces before
 80	for _, impSection := range imps {
 81		// Within each block of contiguous imports, see if any
 82		// import lines are in different group numbers. If so,
 83		// we'll need to put a space between them so it's
 84		// compatible with gofmt.
 85		lastGroup := -1
 86		for _, importSpec := range impSection {
 87			importPath, _ := strconv.Unquote(importSpec.Path.Value)
 88			groupNum := importGroup(importPath)
 89			if groupNum != lastGroup && lastGroup != -1 {
 90				spacesBefore = append(spacesBefore, importPath)
 91			}
 92			lastGroup = groupNum
 93		}
 94
 95	}
 96
 97	printerMode := printer.UseSpaces
 98	if opt.TabIndent {
 99		printerMode |= printer.TabIndent
100	}
101	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
102
103	var buf bytes.Buffer
104	err = printConfig.Fprint(&buf, fileSet, file)
105	if err != nil {
106		return nil, err
107	}
108	out := buf.Bytes()
109	if adjust != nil {
110		out = adjust(src, out)
111	}
112	if len(spacesBefore) > 0 {
113		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
114		if err != nil {
115			return nil, err
116		}
117	}
118
119	out, err = format.Source(out)
120	if err != nil {
121		return nil, err
122	}
123	return out, nil
124}
125
// parse parses src, which was read from filename, as a Go source file.
//
// If parsing fails only because the package clause is missing and
// opt.Fragment is set, parse retries src as a declaration list (prefixed
// with "package main;"). If that fails with "expected declaration", it
// retries src as a statement list (wrapped in "package p; func _() {" ... "}").
//
// The returned adjust function, when non-nil, takes the original src and
// the printer output for the synthesized file. It strips the synthetic
// wrapping and then calls matchSpace so the result keeps orig's leading
// blank lines, indentation and trailing space. adjust is nil when src parsed
// as a complete file. It is also nil when a declaration-list fragment
// contains func main(); in that case the synthesized "package main" clause
// is kept in the output.
func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
	parserMode := parser.Mode(0)
	if opt.Comments {
		parserMode |= parser.ParseComments
	}
	if opt.AllErrors {
		parserMode |= parser.AllErrors
	}

	// Try as whole source file.
	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err == nil {
		return file, nil, nil
	}
	// If the error is that the source file didn't begin with a
	// package line and we accept fragmented input, fall through to
	// try as a source fragment.  Stop and return on any other error.
	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
		return nil, nil, err
	}

	// If this is a declaration list, make it a source file
	// by inserting a package clause.
	// Insert using a ;, not a newline, so that parse errors are on
	// the correct line.
	const prefix = "package main;"
	psrc := append([]byte(prefix), src...)
	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
	if err == nil {
		// Gofmt will turn the ; into a \n.
		// Do that ourselves now and update the file contents,
		// so that positions and line numbers are correct going forward.
		psrc[len(prefix)-1] = '\n'
		fset.File(file.Package).SetLinesForContent(psrc)

		// If a main function exists, we will assume this is a main
		// package and leave the file.
		if containsMainFunc(file) {
			return file, nil, nil
		}

		adjust := func(orig, src []byte) []byte {
			// Remove the package clause.
			// len(prefix) == len("package main\n"), the printed form of the
			// synthetic clause; matchSpace trims any remaining blank lines.
			src = src[len(prefix):]
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}
	// If the error is that the source file didn't begin with a
	// declaration, fall through to try as a statement list.
	// Stop and return on any other error.
	if !strings.Contains(err.Error(), "expected declaration") {
		return nil, nil, err
	}

	// If this is a statement list, make it a source file
	// by inserting a package clause and turning the list
	// into a function body.  This handles expressions too.
	// Insert using a ;, not a newline, so that the line numbers
	// in fsrc match the ones in src.
	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
	if err == nil {
		adjust := func(orig, src []byte) []byte {
			// Remove the wrapping.
			// Gofmt has turned the ; into a \n\n.
			src = src[len("package p\n\nfunc _() {"):]
			src = src[:len(src)-len("}\n")]
			// Gofmt has also indented the function body one level.
			// Remove that indent.
			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}

	// Failed, and out of options.
	return nil, nil, err
}
207
// containsMainFunc reports whether file declares a top-level function with
// the signature 'func main()': named main, with no receiver, no parameters
// and no results.
//
// Methods named main (e.g. 'func (T) main()') are not counted, because they
// cannot be a program's entry point and must not cause a fragment to be
// treated as package main.
func containsMainFunc(file *ast.File) bool {
	for _, decl := range file.Decls {
		f, ok := decl.(*ast.FuncDecl)
		if !ok || f.Name.Name != "main" {
			continue
		}
		// A method named main is not the program entry point.
		if f.Recv != nil {
			continue
		}
		// FieldList.NumFields is nil-safe, so an absent result list counts as zero.
		if f.Type.Params.NumFields() != 0 || f.Type.Results.NumFields() != 0 {
			continue
		}
		return true
	}
	return false
}
231
// cutSpace splits b into its leading whitespace, its interior and its
// trailing whitespace, where whitespace means ' ', '\t' or '\n'.
// If b consists only of whitespace, all of it is returned as after.
func cutSpace(b []byte) (before, middle, after []byte) {
	isSpace := func(c byte) bool { return c == ' ' || c == '\t' || c == '\n' }

	start := 0
	for ; start < len(b); start++ {
		if !isSpace(b[start]) {
			break
		}
	}
	end := len(b)
	for ; end > 0; end-- {
		if !isSpace(b[end-1]) {
			break
		}
	}

	if start > end {
		// Only reachable when every byte is whitespace (end has reached 0).
		return nil, nil, b[end:]
	}
	return b[:start], b[start:end], b[end:]
}

// matchSpace reformats src to use the same surrounding whitespace as orig:
//  1. the blank lines that begin orig are inserted at the beginning of src;
//  2. the indentation of orig's first non-blank line is prepended to every
//     non-blank line of src;
//  3. orig's trailing whitespace replaces src's trailing whitespace.
func matchSpace(orig []byte, src []byte) []byte {
	lead, _, trail := cutSpace(orig)
	// Everything up to and including the last newline in lead is whole blank
	// lines; the rest is the indentation of orig's first code line.
	cut := bytes.LastIndexByte(lead, '\n') + 1
	blankLines, indent := lead[:cut], lead[cut:]

	_, body, _ := cutSpace(src)

	var out bytes.Buffer
	out.Write(blankLines)
	for len(body) > 0 {
		var line []byte
		if nl := bytes.IndexByte(body, '\n'); nl >= 0 {
			line, body = body[:nl+1], body[nl+1:]
		} else {
			line, body = body, nil
		}
		// Blank lines stay unindented.
		if line[0] != '\n' {
			out.Write(indent)
		}
		out.Write(line)
	}
	out.Write(trail)
	return out.Bytes()
}
277
// impLine matches one import spec line inside a parenthesized import block:
// leading whitespace, an optional name (such as an alias, "_" or "."), and a
// quoted path. The first submatch is the path without quotes.
var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+)"`)

// addImportSpaces copies the source read from r and returns it with a blank
// line inserted before the import spec of each path in breaks.
//
// breaks must be ordered as the paths appear in the source. Only lines from
// the first line starting with "import" up to the first line starting with
// var, func, const or type are examined.
//
// The final line is copied even when it is not newline-terminated.
func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
	var out bytes.Buffer
	in := bufio.NewReader(r)
	inImports := false
	done := false
	for {
		// ReadString returns any data read before EOF along with io.EOF, so
		// the last line must be handled before leaving the loop. Otherwise an
		// unterminated final line would be silently dropped.
		s, err := in.ReadString('\n')
		if err != nil && err != io.EOF {
			return nil, err
		}

		if !inImports && !done && strings.HasPrefix(s, "import") {
			inImports = true
		}
		if inImports && (strings.HasPrefix(s, "var") ||
			strings.HasPrefix(s, "func") ||
			strings.HasPrefix(s, "const") ||
			strings.HasPrefix(s, "type")) {
			done = true
			inImports = false
		}
		if inImports && len(breaks) > 0 {
			if m := impLine.FindStringSubmatch(s); m != nil && m[1] == breaks[0] {
				out.WriteByte('\n')
				breaks = breaks[1:]
			}
		}

		fmt.Fprint(&out, s)

		if err == io.EOF {
			break
		}
	}
	return out.Bytes(), nil
}