1// Copyright 2013-2023 The Cobra Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package cobra
16
17import (
18 "fmt"
19 "sort"
20 "strings"
21
22 flag "github.com/spf13/pflag"
23)
24
const (
	// requiredAsGroupAnnotation tags flags belonging to a group whose members
	// must either all be set or all be left unset.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"

	// oneRequiredAnnotation tags flags belonging to a group of which at least
	// one member must be set.
	oneRequiredAnnotation = "cobra_annotation_one_required"

	// mutuallyExclusiveAnnotation tags flags belonging to a group of which at
	// most one member may be set.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
30
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
// if the command is invoked with a subset (but not all) of the given flags.
//
// Persistent flags are merged into the command's flag set first, so they may be
// named here. Naming a flag that is not defined panics. Calling this several
// times adds one group entry per call, so a flag may belong to several groups.
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
	c.mergePersistentFlags()
	fs := c.Flags()
	// The group is recorded as its member names joined by single spaces.
	group := strings.Join(flagNames, " ")
	for _, name := range flagNames {
		f := fs.Lookup(name)
		if f == nil {
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", name))
		}
		groups := append(f.Annotations[requiredAsGroupAnnotation], group)
		// SetAnnotation only errs if the flag isn't found, which the lookup above rules out.
		if err := fs.SetAnnotation(name, requiredAsGroupAnnotation, groups); err != nil {
			panic(err)
		}
	}
}
46
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
// if the command is invoked without at least one flag from the given set of flags.
//
// Persistent flags are merged into the command's flag set first, so they may be
// named here. Naming a flag that is not defined panics. Calling this several
// times adds one group entry per call, so a flag may belong to several groups.
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
	c.mergePersistentFlags()
	fs := c.Flags()
	// The group is recorded as its member names joined by single spaces.
	group := strings.Join(flagNames, " ")
	for _, name := range flagNames {
		f := fs.Lookup(name)
		if f == nil {
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", name))
		}
		groups := append(f.Annotations[oneRequiredAnnotation], group)
		// SetAnnotation only errs if the flag isn't found, which the lookup above rules out.
		if err := fs.SetAnnotation(name, oneRequiredAnnotation, groups); err != nil {
			panic(err)
		}
	}
}
62
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
// if the command is invoked with more than one flag from the given set of flags.
//
// Persistent flags are merged into the command's flag set first, so they may be
// named here. Naming a flag that is not defined panics.
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
	c.mergePersistentFlags()
	fs := c.Flags()
	// The group is recorded as its member names joined by single spaces.
	group := strings.Join(flagNames, " ")
	for _, name := range flagNames {
		f := fs.Lookup(name)
		if f == nil {
			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", name))
		}
		// Each call appends a single new entry, which lets a flag belong to
		// several mutually exclusive groups at once.
		groups := append(f.Annotations[mutuallyExclusiveAnnotation], group)
		// SetAnnotation only errs if the flag isn't found, which the lookup above rules out.
		if err := fs.SetAnnotation(name, mutuallyExclusiveAnnotation, groups); err != nil {
			panic(err)
		}
	}
}
78
79// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
80// first error encountered.
81func (c *Command) ValidateFlagGroups() error {
82 if c.DisableFlagParsing {
83 return nil
84 }
85
86 flags := c.Flags()
87
88 // groupStatus format is the list of flags as a unique ID,
89 // then a map of each flag name and whether it is set or not.
90 groupStatus := map[string]map[string]bool{}
91 oneRequiredGroupStatus := map[string]map[string]bool{}
92 mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
93 flags.VisitAll(func(pflag *flag.Flag) {
94 processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
95 processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
96 processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
97 })
98
99 if err := validateRequiredFlagGroups(groupStatus); err != nil {
100 return err
101 }
102 if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103 return err
104 }
105 if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
106 return err
107 }
108 return nil
109}
110
111func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
112 for _, fname := range flagnames {
113 f := fs.Lookup(fname)
114 if f == nil {
115 return false
116 }
117 }
118 return true
119}
120
// processFlagForGroupAnnotation records, in groupStatus, whether pflag has been
// changed for every group listed under the given annotation on pflag.
//
// groupStatus maps a group ID (space-separated member flag names) to a map of
// member name -> changed. The first time a group is seen, its entry is created
// with every member set to false, but only if all members are defined in flags.
// Groups that reference an undefined flag are skipped entirely.
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
	groups, ok := pflag.Annotations[annotation]
	if !ok {
		return
	}

	for _, group := range groups {
		status := groupStatus[group]
		if status == nil {
			members := strings.Split(group, " ")
			if !hasAllFlags(flags, members...) {
				continue
			}
			status = make(map[string]bool, len(members))
			for _, member := range members {
				status[member] = false
			}
			groupStatus[group] = status
		}
		status[pflag.Name] = pflag.Changed
	}
}
143
144func validateRequiredFlagGroups(data map[string]map[string]bool) error {
145 keys := sortedKeys(data)
146 for _, flagList := range keys {
147 flagnameAndStatus := data[flagList]
148
149 unset := []string{}
150 for flagname, isSet := range flagnameAndStatus {
151 if !isSet {
152 unset = append(unset, flagname)
153 }
154 }
155 if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
156 continue
157 }
158
159 // Sort values, so they can be tested/scripted against consistently.
160 sort.Strings(unset)
161 return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
162 }
163
164 return nil
165}
166
167func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168 keys := sortedKeys(data)
169 for _, flagList := range keys {
170 flagnameAndStatus := data[flagList]
171 var set []string
172 for flagname, isSet := range flagnameAndStatus {
173 if isSet {
174 set = append(set, flagname)
175 }
176 }
177 if len(set) >= 1 {
178 continue
179 }
180
181 // Sort values, so they can be tested/scripted against consistently.
182 sort.Strings(set)
183 return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184 }
185 return nil
186}
187
// validateExclusiveFlagGroups returns an error for the first group (in sorted
// group-ID order) in which more than one flag was set. It returns nil when every
// group has at most one flag set. The offending flag names in the error are
// sorted so that the output is stable.
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
	for _, flagList := range sortedKeys(data) {
		var chosen []string
		for name, isSet := range data[flagList] {
			if isSet {
				chosen = append(chosen, name)
			}
		}
		if len(chosen) <= 1 {
			continue
		}

		sort.Strings(chosen)
		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, chosen)
	}
	return nil
}
208
// sortedKeys returns the keys of m in ascending lexical order, so that callers
// iterate groups deterministically despite Go's randomized map iteration.
// The result is an empty, non-nil slice for an empty or nil map.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
219
// enforceFlagGroupsForCompletion adjusts flag metadata so that the standard
// completion logic behaves appropriately for flag groups:
//   - when any flag of a required-together group is present, every flag in the
//     group is marked required (so the missing ones get suggested);
//   - when none of the flags in a one-required group are present, every flag in
//     the group is marked required;
//   - when a flag in a mutually exclusive group is present, the other flags in
//     the group are marked hidden.
//
// It does nothing when DisableFlagParsing is set.
func (c *Command) enforceFlagGroupsForCompletion() {
	if c.DisableFlagParsing {
		return
	}

	flags := c.Flags()
	groupStatus := map[string]map[string]bool{}
	oneRequiredGroupStatus := map[string]map[string]bool{}
	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
	flags.VisitAll(func(pflag *flag.Flag) {
		processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
		processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
	})

	// If a flag that is part of a group is present, we make all the flags of
	// that group required so that the shell completion suggests the rest.
	for flagList, flagnameAndStatus := range groupStatus {
		for _, isSet := range flagnameAndStatus {
			if !isSet {
				continue
			}
			// Errors are ignored (best effort). Groups are only recorded when all
			// their flags exist (see hasAllFlags in processFlagForGroupAnnotation).
			for _, fName := range strings.Split(flagList, " ") {
				_ = c.MarkFlagRequired(fName)
			}
			// One set member is enough; re-marking the group would be redundant.
			break
		}
	}

	// If none of the flags of a one-required group are present, we make all the
	// flags of that group required so that the shell completion suggests them.
	for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
		anySet := false
		for _, isSet := range flagnameAndStatus {
			if isSet {
				anySet = true
				break
			}
		}
		if anySet {
			continue
		}
		for _, fName := range strings.Split(flagList, " ") {
			_ = c.MarkFlagRequired(fName)
		}
	}

	// If a flag that is mutually exclusive to others is present, we hide the
	// other flags of that group so the shell completion does not suggest them.
	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
		for flagName, isSet := range flagnameAndStatus {
			if !isSet {
				continue
			}
			// Don't hide the flag that is already set: it may be an array or
			// slice flag and therefore must continue being suggested.
			for _, fName := range strings.Split(flagList, " ") {
				if fName == flagName {
					continue
				}
				// Named "other" rather than "flag" to avoid shadowing the pflag import.
				if other := flags.Lookup(fName); other != nil {
					other.Hidden = true
				}
			}
		}
	}
}