zsh_completions.go

  1package cobra
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"io"
  7	"os"
  8	"sort"
  9	"strings"
 10)
 11
 12// GenZshCompletionFile generates zsh completion file.
 13func (c *Command) GenZshCompletionFile(filename string) error {
 14	outFile, err := os.Create(filename)
 15	if err != nil {
 16		return err
 17	}
 18	defer outFile.Close()
 19
 20	return c.GenZshCompletion(outFile)
 21}
 22
 23// GenZshCompletion generates a zsh completion file and writes to the passed writer.
 24func (c *Command) GenZshCompletion(w io.Writer) error {
 25	buf := new(bytes.Buffer)
 26
 27	writeHeader(buf, c)
 28	maxDepth := maxDepth(c)
 29	writeLevelMapping(buf, maxDepth)
 30	writeLevelCases(buf, maxDepth, c)
 31
 32	_, err := buf.WriteTo(w)
 33	return err
 34}
 35
 36func writeHeader(w io.Writer, cmd *Command) {
 37	fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
 38}
 39
 40func maxDepth(c *Command) int {
 41	if len(c.Commands()) == 0 {
 42		return 0
 43	}
 44	maxDepthSub := 0
 45	for _, s := range c.Commands() {
 46		subDepth := maxDepth(s)
 47		if subDepth > maxDepthSub {
 48			maxDepthSub = subDepth
 49		}
 50	}
 51	return 1 + maxDepthSub
 52}
 53
 54func writeLevelMapping(w io.Writer, numLevels int) {
 55	fmt.Fprintln(w, `_arguments \`)
 56	for i := 1; i <= numLevels; i++ {
 57		fmt.Fprintf(w, `  '%d: :->level%d' \`, i, i)
 58		fmt.Fprintln(w)
 59	}
 60	fmt.Fprintf(w, `  '%d: :%s'`, numLevels+1, "_files")
 61	fmt.Fprintln(w)
 62}
 63
 64func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
 65	fmt.Fprintln(w, "case $state in")
 66	defer fmt.Fprintln(w, "esac")
 67
 68	for i := 1; i <= maxDepth; i++ {
 69		fmt.Fprintf(w, "  level%d)\n", i)
 70		writeLevel(w, root, i)
 71		fmt.Fprintln(w, "  ;;")
 72	}
 73	fmt.Fprintln(w, "  *)")
 74	fmt.Fprintln(w, "    _arguments '*: :_files'")
 75	fmt.Fprintln(w, "  ;;")
 76}
 77
 78func writeLevel(w io.Writer, root *Command, i int) {
 79	fmt.Fprintf(w, "    case $words[%d] in\n", i)
 80	defer fmt.Fprintln(w, "    esac")
 81
 82	commands := filterByLevel(root, i)
 83	byParent := groupByParent(commands)
 84
 85	// sort the parents to keep a determinist order
 86	parents := make([]string, len(byParent))
 87	j := 0
 88	for parent := range byParent {
 89		parents[j] = parent
 90		j++
 91	}
 92	sort.StringSlice(parents).Sort()
 93
 94	for _, parent := range parents {
 95		names := names(byParent[parent])
 96		fmt.Fprintf(w, "      %s)\n", parent)
 97		fmt.Fprintf(w, "        _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
 98		fmt.Fprintln(w, "      ;;")
 99	}
100	fmt.Fprintln(w, "      *)")
101	fmt.Fprintln(w, "        _arguments '*: :_files'")
102	fmt.Fprintln(w, "      ;;")
103}
104
105func filterByLevel(c *Command, l int) []*Command {
106	cs := make([]*Command, 0)
107	if l == 0 {
108		cs = append(cs, c)
109		return cs
110	}
111	for _, s := range c.Commands() {
112		cs = append(cs, filterByLevel(s, l-1)...)
113	}
114	return cs
115}
116
117func groupByParent(commands []*Command) map[string][]*Command {
118	m := make(map[string][]*Command)
119	for _, c := range commands {
120		parent := c.Parent()
121		if parent == nil {
122			continue
123		}
124		m[parent.Name()] = append(m[parent.Name()], c)
125	}
126	return m
127}
128
129func names(commands []*Command) []string {
130	ns := make([]string, len(commands))
131	for i, c := range commands {
132		ns[i] = c.Name()
133	}
134	return ns
135}