imports.go

  1// Copyright 2013 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5//go:generate go run mkstdlib.go
  6
  7// Package imports implements a Go pretty-printer (like package "go/format")
  8// that also adds or removes import statements as necessary.
  9package imports // import "golang.org/x/tools/imports"
 10
 11import (
 12	"bufio"
 13	"bytes"
 14	"fmt"
 15	"go/ast"
 16	"go/format"
 17	"go/parser"
 18	"go/printer"
 19	"go/token"
 20	"io"
 21	"io/ioutil"
 22	"regexp"
 23	"strconv"
 24	"strings"
 25
 26	"golang.org/x/tools/go/ast/astutil"
 27)
 28
// Options specifies options for processing files.
//
// A nil *Options passed to Process is replaced with
// &Options{Comments: true, TabIndent: true, TabWidth: 8}.
type Options struct {
	// Fragment accepts a fragment of a source file (no package statement):
	// input that fails to parse with "expected 'package'" is retried as a
	// declaration list and then as a statement list.
	Fragment  bool
	AllErrors bool // Report all errors (not just the first 10 on different lines)

	// Comments, TabIndent and TabWidth default to true, true and 8 when a
	// nil *Options is provided.
	// NOTE(review): Process passes its final output through
	// go/format.Source, which formats with gofmt's own fixed settings;
	// confirm whether TabIndent and TabWidth still affect the bytes
	// Process returns.
	Comments  bool // Parse, and therefore print, comments
	TabIndent bool // Use tabs for indent
	TabWidth  int  // Tab width

	FormatOnly bool // Disable the insertion and deletion of imports
}
 40
 41// Process formats and adjusts imports for the provided file.
 42// If opt is nil the defaults are used.
 43//
 44// Note that filename's directory influences which imports can be chosen,
 45// so it is important that filename be accurate.
 46// To process data ``as if'' it were in filename, pass the data as a non-nil src.
 47func Process(filename string, src []byte, opt *Options) ([]byte, error) {
 48	if opt == nil {
 49		opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
 50	}
 51	if src == nil {
 52		b, err := ioutil.ReadFile(filename)
 53		if err != nil {
 54			return nil, err
 55		}
 56		src = b
 57	}
 58
 59	fileSet := token.NewFileSet()
 60	file, adjust, err := parse(fileSet, filename, src, opt)
 61	if err != nil {
 62		return nil, err
 63	}
 64
 65	if !opt.FormatOnly {
 66		_, err = fixImports(fileSet, file, filename)
 67		if err != nil {
 68			return nil, err
 69		}
 70	}
 71
 72	sortImports(fileSet, file)
 73	imps := astutil.Imports(fileSet, file)
 74	var spacesBefore []string // import paths we need spaces before
 75	for _, impSection := range imps {
 76		// Within each block of contiguous imports, see if any
 77		// import lines are in different group numbers. If so,
 78		// we'll need to put a space between them so it's
 79		// compatible with gofmt.
 80		lastGroup := -1
 81		for _, importSpec := range impSection {
 82			importPath, _ := strconv.Unquote(importSpec.Path.Value)
 83			groupNum := importGroup(importPath)
 84			if groupNum != lastGroup && lastGroup != -1 {
 85				spacesBefore = append(spacesBefore, importPath)
 86			}
 87			lastGroup = groupNum
 88		}
 89
 90	}
 91
 92	printerMode := printer.UseSpaces
 93	if opt.TabIndent {
 94		printerMode |= printer.TabIndent
 95	}
 96	printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
 97
 98	var buf bytes.Buffer
 99	err = printConfig.Fprint(&buf, fileSet, file)
100	if err != nil {
101		return nil, err
102	}
103	out := buf.Bytes()
104	if adjust != nil {
105		out = adjust(src, out)
106	}
107	if len(spacesBefore) > 0 {
108		out, err = addImportSpaces(bytes.NewReader(out), spacesBefore)
109		if err != nil {
110			return nil, err
111		}
112	}
113
114	out, err = format.Source(out)
115	if err != nil {
116		return nil, err
117	}
118	return out, nil
119}
120
// parse parses src, which was read from filename,
// as a Go source file or statement list.
//
// It tries, in order:
//  1. src as a complete source file;
//  2. if opt.Fragment is set and the error mentions "expected 'package'",
//     src prefixed with "package main;" (a declaration list);
//  3. if that fails with "expected declaration", src wrapped as the body
//     of "package p; func _() { ... }" (a statement list, which also
//     covers expressions).
//
// It returns the parsed file and, for fragments, an adjust function that
// Process applies to the printed output to strip the synthesized wrapper
// and restore orig's surrounding whitespace via matchSpace. adjust is nil
// for complete files and for declaration lists containing 'func main()';
// in the latter case the added "package main" clause stays in the output.
// On failure the error from the last attempted parse is returned.
func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
	parserMode := parser.Mode(0)
	if opt.Comments {
		parserMode |= parser.ParseComments
	}
	if opt.AllErrors {
		parserMode |= parser.AllErrors
	}

	// Try as whole source file.
	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err == nil {
		return file, nil, nil
	}
	// If the error is that the source file didn't begin with a
	// package line and we accept fragmented input, fall through to
	// try as a source fragment.  Stop and return on any other error.
	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
		return nil, nil, err
	}

	// If this is a declaration list, make it a source file
	// by inserting a package clause.
	// Insert using a ;, not a newline, so that parse errors are on
	// the correct line.
	const prefix = "package main;"
	psrc := append([]byte(prefix), src...)
	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
	if err == nil {
		// Gofmt will turn the ; into a \n.
		// Do that ourselves now and update the file contents,
		// so that positions and line numbers are correct going forward.
		psrc[len(prefix)-1] = '\n'
		fset.File(file.Package).SetLinesForContent(psrc)

		// If a main function exists, we will assume this is a main
		// package and leave the file (no adjust, so the synthesized
		// package clause is kept in the output).
		if containsMainFunc(file) {
			return file, nil, nil
		}

		adjust := func(orig, src []byte) []byte {
			// Remove the package clause. The printed "package main\n"
			// has the same length as prefix.
			src = src[len(prefix):]
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}
	// If the error is that the source file didn't begin with a
	// declaration, fall through to try as a statement list.
	// Stop and return on any other error.
	if !strings.Contains(err.Error(), "expected declaration") {
		return nil, nil, err
	}

	// If this is a statement list, make it a source file
	// by inserting a package clause and turning the list
	// into a function body.  This handles expressions too.
	// Insert using a ;, not a newline, so that the line numbers
	// in fsrc match the ones in src.
	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
	if err == nil {
		adjust := func(orig, src []byte) []byte {
			// Remove the wrapping.
			// Gofmt has turned the ; into a \n\n.
			src = src[len("package p\n\nfunc _() {"):]
			src = src[:len(src)-len("}\n")]
			// Gofmt has also indented the function body one level.
			// Remove that indent.
			// NOTE(review): this strips only a leading tab. Process sets
			// printer.TabIndent only when opt.TabIndent is true; when it is
			// false the body may be indented with spaces and left as-is —
			// confirm whether that case needs handling.
			src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}

	// Failed, and out of options.
	return nil, nil, err
}
202
// containsMainFunc reports whether file declares a top-level function with
// the signature 'func main()': named main, with no receiver, no parameters
// and no results. A method named main (e.g. 'func (T) main()') does not
// count, since it does not make the file a main package.
func containsMainFunc(file *ast.File) bool {
	for _, decl := range file.Decls {
		f, ok := decl.(*ast.FuncDecl)
		if !ok || f.Recv != nil || f.Name.Name != "main" {
			continue
		}
		if f.Type.Params != nil && len(f.Type.Params.List) != 0 {
			continue
		}
		if f.Type.Results != nil && len(f.Type.Results.List) != 0 {
			continue
		}
		return true
	}
	return false
}
226
// cutSpace splits b into its leading whitespace, the content in between,
// and its trailing whitespace, where whitespace is ' ', '\t' or '\n'.
// If b is non-empty and consists only of whitespace, before and middle are
// nil and after is all of b.
func cutSpace(b []byte) (before, middle, after []byte) {
	const ws = " \t\n"
	start := len(b) - len(bytes.TrimLeft(b, ws))
	end := len(bytes.TrimRight(b, ws))
	if start > end {
		return nil, nil, b[end:]
	}
	return b[:start], b[start:end], b[end:]
}
241
242// matchSpace reformats src to use the same space context as orig.
243// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
244// 2) matchSpace copies the indentation of the first non-blank line in orig
245//    to every non-blank line in src.
246// 3) matchSpace copies the trailing space from orig and uses it in place
247//   of src's trailing space.
248func matchSpace(orig []byte, src []byte) []byte {
249	before, _, after := cutSpace(orig)
250	i := bytes.LastIndex(before, []byte{'\n'})
251	before, indent := before[:i+1], before[i+1:]
252
253	_, src, _ = cutSpace(src)
254
255	var b bytes.Buffer
256	b.Write(before)
257	for len(src) > 0 {
258		line := src
259		if i := bytes.IndexByte(line, '\n'); i >= 0 {
260			line, src = line[:i+1], line[i+1:]
261		} else {
262			src = nil
263		}
264		if len(line) > 0 && line[0] != '\n' { // not blank
265			b.Write(indent)
266		}
267		b.Write(line)
268	}
269	b.Write(after)
270	return b.Bytes()
271}
272
// impLine matches one import spec line inside a parenthesized import block:
// leading whitespace, an optional import name (an identifier, "_" or "."),
// then the double-quoted import path, captured in group 1. The path is
// matched as [^"]+ so that a trailing comment containing quotes
// (e.g. `"fmt" // see "x"`) is not swallowed into the capture.
var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"([^"]+)"`)

// addImportSpaces copies the contents of r, inserting a blank line before
// each import spec line whose path equals the next entry of breaks.
//
// breaks must list import paths in the order they appear in the input; it
// is consumed front to back, so an entry that never matches also blocks the
// entries after it. Only lines from the first line starting with "import"
// up to the next line starting with "var", "func", "const" or "type" are
// considered. A final line without a trailing newline is preserved.
func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
	var out bytes.Buffer
	in := bufio.NewReader(r)
	inImports := false
	done := false
	for {
		s, err := in.ReadString('\n')
		if err != nil && err != io.EOF {
			return nil, err
		}
		// On io.EOF, s holds any final line that lacked a trailing newline.
		// It must still be processed and copied, otherwise it would be lost.

		if !inImports && !done && strings.HasPrefix(s, "import") {
			inImports = true
		}
		if inImports && (strings.HasPrefix(s, "var") ||
			strings.HasPrefix(s, "func") ||
			strings.HasPrefix(s, "const") ||
			strings.HasPrefix(s, "type")) {
			done = true
			inImports = false
		}
		if inImports && len(breaks) > 0 {
			if m := impLine.FindStringSubmatch(s); m != nil && m[1] == breaks[0] {
				out.WriteByte('\n')
				breaks = breaks[1:]
			}
		}

		fmt.Fprint(&out, s)
		if err == io.EOF {
			break
		}
	}
	return out.Bytes(), nil
}