1package cobra
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "os"
8 "sort"
9 "strings"
10)
11
12// GenZshCompletionFile generates zsh completion file.
13func (c *Command) GenZshCompletionFile(filename string) error {
14 outFile, err := os.Create(filename)
15 if err != nil {
16 return err
17 }
18 defer outFile.Close()
19
20 return c.GenZshCompletion(outFile)
21}
22
23// GenZshCompletion generates a zsh completion file and writes to the passed writer.
24func (c *Command) GenZshCompletion(w io.Writer) error {
25 buf := new(bytes.Buffer)
26
27 writeHeader(buf, c)
28 maxDepth := maxDepth(c)
29 writeLevelMapping(buf, maxDepth)
30 writeLevelCases(buf, maxDepth, c)
31
32 _, err := buf.WriteTo(w)
33 return err
34}
35
36func writeHeader(w io.Writer, cmd *Command) {
37 fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
38}
39
40func maxDepth(c *Command) int {
41 if len(c.Commands()) == 0 {
42 return 0
43 }
44 maxDepthSub := 0
45 for _, s := range c.Commands() {
46 subDepth := maxDepth(s)
47 if subDepth > maxDepthSub {
48 maxDepthSub = subDepth
49 }
50 }
51 return 1 + maxDepthSub
52}
53
// writeLevelMapping writes the top-level `_arguments` call to w.
//
// It emits one backslash-continued line per level from 1 to numLevels, each
// mapping that argument position to the state "level<N>", and a final line
// that maps position numLevels+1 to `_files`.
func writeLevelMapping(w io.Writer, numLevels int) {
	fmt.Fprintln(w, `_arguments \`)
	for level := 1; level <= numLevels; level++ {
		fmt.Fprintf(w, ` '%d: :->level%d' \`+"\n", level, level)
	}
	fmt.Fprintf(w, ` '%d: :%s'`+"\n", numLevels+1, "_files")
}
63
64func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
65 fmt.Fprintln(w, "case $state in")
66 defer fmt.Fprintln(w, "esac")
67
68 for i := 1; i <= maxDepth; i++ {
69 fmt.Fprintf(w, " level%d)\n", i)
70 writeLevel(w, root, i)
71 fmt.Fprintln(w, " ;;")
72 }
73 fmt.Fprintln(w, " *)")
74 fmt.Fprintln(w, " _arguments '*: :_files'")
75 fmt.Fprintln(w, " ;;")
76}
77
78func writeLevel(w io.Writer, root *Command, i int) {
79 fmt.Fprintf(w, " case $words[%d] in\n", i)
80 defer fmt.Fprintln(w, " esac")
81
82 commands := filterByLevel(root, i)
83 byParent := groupByParent(commands)
84
85 // sort the parents to keep a determinist order
86 parents := make([]string, len(byParent))
87 j := 0
88 for parent := range byParent {
89 parents[j] = parent
90 j++
91 }
92 sort.StringSlice(parents).Sort()
93
94 for _, parent := range parents {
95 names := names(byParent[parent])
96 fmt.Fprintf(w, " %s)\n", parent)
97 fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
98 fmt.Fprintln(w, " ;;")
99 }
100 fmt.Fprintln(w, " *)")
101 fmt.Fprintln(w, " _arguments '*: :_files'")
102 fmt.Fprintln(w, " ;;")
103}
104
105func filterByLevel(c *Command, l int) []*Command {
106 cs := make([]*Command, 0)
107 if l == 0 {
108 cs = append(cs, c)
109 return cs
110 }
111 for _, s := range c.Commands() {
112 cs = append(cs, filterByLevel(s, l-1)...)
113 }
114 return cs
115}
116
117func groupByParent(commands []*Command) map[string][]*Command {
118 m := make(map[string][]*Command)
119 for _, c := range commands {
120 parent := c.Parent()
121 if parent == nil {
122 continue
123 }
124 m[parent.Name()] = append(m[parent.Name()], c)
125 }
126 return m
127}
128
129func names(commands []*Command) []string {
130 ns := make([]string, len(commands))
131 for i, c := range commands {
132 ns[i] = c.Name()
133 }
134 return ns
135}