zsh_completions.go

  1package cobra
  2
import (
	"bytes"
	"fmt"
	"io"
	"os"
	"sort"
	"strings"
)
 10
 11// GenZshCompletionFile generates zsh completion file.
 12func (c *Command) GenZshCompletionFile(filename string) error {
 13	outFile, err := os.Create(filename)
 14	if err != nil {
 15		return err
 16	}
 17	defer outFile.Close()
 18
 19	return c.GenZshCompletion(outFile)
 20}
 21
 22// GenZshCompletion generates a zsh completion file and writes to the passed writer.
 23func (c *Command) GenZshCompletion(w io.Writer) error {
 24	buf := new(bytes.Buffer)
 25
 26	writeHeader(buf, c)
 27	maxDepth := maxDepth(c)
 28	writeLevelMapping(buf, maxDepth)
 29	writeLevelCases(buf, maxDepth, c)
 30
 31	_, err := buf.WriteTo(w)
 32	return err
 33}
 34
 35func writeHeader(w io.Writer, cmd *Command) {
 36	fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
 37}
 38
 39func maxDepth(c *Command) int {
 40	if len(c.Commands()) == 0 {
 41		return 0
 42	}
 43	maxDepthSub := 0
 44	for _, s := range c.Commands() {
 45		subDepth := maxDepth(s)
 46		if subDepth > maxDepthSub {
 47			maxDepthSub = subDepth
 48		}
 49	}
 50	return 1 + maxDepthSub
 51}
 52
 53func writeLevelMapping(w io.Writer, numLevels int) {
 54	fmt.Fprintln(w, `_arguments \`)
 55	for i := 1; i <= numLevels; i++ {
 56		fmt.Fprintf(w, `  '%d: :->level%d' \`, i, i)
 57		fmt.Fprintln(w)
 58	}
 59	fmt.Fprintf(w, `  '%d: :%s'`, numLevels+1, "_files")
 60	fmt.Fprintln(w)
 61}
 62
 63func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
 64	fmt.Fprintln(w, "case $state in")
 65	defer fmt.Fprintln(w, "esac")
 66
 67	for i := 1; i <= maxDepth; i++ {
 68		fmt.Fprintf(w, "  level%d)\n", i)
 69		writeLevel(w, root, i)
 70		fmt.Fprintln(w, "  ;;")
 71	}
 72	fmt.Fprintln(w, "  *)")
 73	fmt.Fprintln(w, "    _arguments '*: :_files'")
 74	fmt.Fprintln(w, "  ;;")
 75}
 76
 77func writeLevel(w io.Writer, root *Command, i int) {
 78	fmt.Fprintf(w, "    case $words[%d] in\n", i)
 79	defer fmt.Fprintln(w, "    esac")
 80
 81	commands := filterByLevel(root, i)
 82	byParent := groupByParent(commands)
 83
 84	for p, c := range byParent {
 85		names := names(c)
 86		fmt.Fprintf(w, "      %s)\n", p)
 87		fmt.Fprintf(w, "        _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
 88		fmt.Fprintln(w, "      ;;")
 89	}
 90	fmt.Fprintln(w, "      *)")
 91	fmt.Fprintln(w, "        _arguments '*: :_files'")
 92	fmt.Fprintln(w, "      ;;")
 93
 94}
 95
 96func filterByLevel(c *Command, l int) []*Command {
 97	cs := make([]*Command, 0)
 98	if l == 0 {
 99		cs = append(cs, c)
100		return cs
101	}
102	for _, s := range c.Commands() {
103		cs = append(cs, filterByLevel(s, l-1)...)
104	}
105	return cs
106}
107
108func groupByParent(commands []*Command) map[string][]*Command {
109	m := make(map[string][]*Command)
110	for _, c := range commands {
111		parent := c.Parent()
112		if parent == nil {
113			continue
114		}
115		m[parent.Name()] = append(m[parent.Name()], c)
116	}
117	return m
118}
119
120func names(commands []*Command) []string {
121	ns := make([]string, len(commands))
122	for i, c := range commands {
123		ns[i] = c.Name()
124	}
125	return ns
126}