flag_groups.go

  1// Copyright 2013-2023 The Cobra Authors
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package cobra
 16
 17import (
 18	"fmt"
 19	"sort"
 20	"strings"
 21
 22	flag "github.com/spf13/pflag"
 23)
 24
 // Annotation keys that Cobra stores on pflag flags to record flag-group
// membership. The value under each key is a list of groups; every group is
// encoded as the space-joined names of its member flags.
const (
	// requiredAsGroupAnnotation tags flags that must be set together.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
	// oneRequiredAnnotation tags flags of which at least one must be set.
	oneRequiredAnnotation = "cobra_annotation_one_required"
	// mutuallyExclusiveAnnotation tags flags of which at most one may be set.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
 30
 31// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
 32// if the command is invoked with a subset (but not all) of the given flags.
 33func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
 34	c.mergePersistentFlags()
 35	for _, v := range flagNames {
 36		f := c.Flags().Lookup(v)
 37		if f == nil {
 38			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
 39		}
 40		if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
 41			// Only errs if the flag isn't found.
 42			panic(err)
 43		}
 44	}
 45}
 46
 47// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
 48// if the command is invoked without at least one flag from the given set of flags.
 49func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
 50	c.mergePersistentFlags()
 51	for _, v := range flagNames {
 52		f := c.Flags().Lookup(v)
 53		if f == nil {
 54			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
 55		}
 56		if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
 57			// Only errs if the flag isn't found.
 58			panic(err)
 59		}
 60	}
 61}
 62
 63// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
 64// if the command is invoked with more than one flag from the given set of flags.
 65func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
 66	c.mergePersistentFlags()
 67	for _, v := range flagNames {
 68		f := c.Flags().Lookup(v)
 69		if f == nil {
 70			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
 71		}
 72		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
 73		if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
 74			panic(err)
 75		}
 76	}
 77}
 78
 79// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
 80// first error encountered.
 81func (c *Command) ValidateFlagGroups() error {
 82	if c.DisableFlagParsing {
 83		return nil
 84	}
 85
 86	flags := c.Flags()
 87
 88	// groupStatus format is the list of flags as a unique ID,
 89	// then a map of each flag name and whether it is set or not.
 90	groupStatus := map[string]map[string]bool{}
 91	oneRequiredGroupStatus := map[string]map[string]bool{}
 92	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
 93	flags.VisitAll(func(pflag *flag.Flag) {
 94		processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
 95		processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
 96		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
 97	})
 98
 99	if err := validateRequiredFlagGroups(groupStatus); err != nil {
100		return err
101	}
102	if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103		return err
104	}
105	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
106		return err
107	}
108	return nil
109}
110
111func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
112	for _, fname := range flagnames {
113		f := fs.Lookup(fname)
114		if f == nil {
115			return false
116		}
117	}
118	return true
119}
120
121func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
122	groupInfo, found := pflag.Annotations[annotation]
123	if found {
124		for _, group := range groupInfo {
125			if groupStatus[group] == nil {
126				flagnames := strings.Split(group, " ")
127
128				// Only consider this flag group at all if all the flags are defined.
129				if !hasAllFlags(flags, flagnames...) {
130					continue
131				}
132
133				groupStatus[group] = make(map[string]bool, len(flagnames))
134				for _, name := range flagnames {
135					groupStatus[group][name] = false
136				}
137			}
138
139			groupStatus[group][pflag.Name] = pflag.Changed
140		}
141	}
142}
143
144func validateRequiredFlagGroups(data map[string]map[string]bool) error {
145	keys := sortedKeys(data)
146	for _, flagList := range keys {
147		flagnameAndStatus := data[flagList]
148
149		unset := []string{}
150		for flagname, isSet := range flagnameAndStatus {
151			if !isSet {
152				unset = append(unset, flagname)
153			}
154		}
155		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
156			continue
157		}
158
159		// Sort values, so they can be tested/scripted against consistently.
160		sort.Strings(unset)
161		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
162	}
163
164	return nil
165}
166
167func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168	keys := sortedKeys(data)
169	for _, flagList := range keys {
170		flagnameAndStatus := data[flagList]
171		var set []string
172		for flagname, isSet := range flagnameAndStatus {
173			if isSet {
174				set = append(set, flagname)
175			}
176		}
177		if len(set) >= 1 {
178			continue
179		}
180
181		// Sort values, so they can be tested/scripted against consistently.
182		sort.Strings(set)
183		return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184	}
185	return nil
186}
187
188func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
189	keys := sortedKeys(data)
190	for _, flagList := range keys {
191		flagnameAndStatus := data[flagList]
192		var set []string
193		for flagname, isSet := range flagnameAndStatus {
194			if isSet {
195				set = append(set, flagname)
196			}
197		}
198		if len(set) == 0 || len(set) == 1 {
199			continue
200		}
201
202		// Sort values, so they can be tested/scripted against consistently.
203		sort.Strings(set)
204		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
205	}
206	return nil
207}
208
// sortedKeys returns the keys of m in ascending lexical order. The result is a
// non-nil, possibly empty, slice.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
219
220// enforceFlagGroupsForCompletion will do the following:
221// - when a flag in a group is present, other flags in the group will be marked required
222// - when none of the flags in a one-required group are present, all flags in the group will be marked required
223// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
224// This allows the standard completion logic to behave appropriately for flag groups
225func (c *Command) enforceFlagGroupsForCompletion() {
226	if c.DisableFlagParsing {
227		return
228	}
229
230	flags := c.Flags()
231	groupStatus := map[string]map[string]bool{}
232	oneRequiredGroupStatus := map[string]map[string]bool{}
233	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
234	c.Flags().VisitAll(func(pflag *flag.Flag) {
235		processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
236		processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
237		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
238	})
239
240	// If a flag that is part of a group is present, we make all the other flags
241	// of that group required so that the shell completion suggests them automatically
242	for flagList, flagnameAndStatus := range groupStatus {
243		for _, isSet := range flagnameAndStatus {
244			if isSet {
245				// One of the flags of the group is set, mark the other ones as required
246				for _, fName := range strings.Split(flagList, " ") {
247					_ = c.MarkFlagRequired(fName)
248				}
249			}
250		}
251	}
252
253	// If none of the flags of a one-required group are present, we make all the flags
254	// of that group required so that the shell completion suggests them automatically
255	for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
256		isSet := false
257
258		for _, isSet = range flagnameAndStatus {
259			if isSet {
260				break
261			}
262		}
263
264		// None of the flags of the group are set, mark all flags in the group
265		// as required
266		if !isSet {
267			for _, fName := range strings.Split(flagList, " ") {
268				_ = c.MarkFlagRequired(fName)
269			}
270		}
271	}
272
273	// If a flag that is mutually exclusive to others is present, we hide the other
274	// flags of that group so the shell completion does not suggest them
275	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
276		for flagName, isSet := range flagnameAndStatus {
277			if isSet {
278				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
279				// Don't mark the flag that is already set as hidden because it may be an
280				// array or slice flag and therefore must continue being suggested
281				for _, fName := range strings.Split(flagList, " ") {
282					if fName != flagName {
283						flag := c.Flags().Lookup(fName)
284						flag.Hidden = true
285					}
286				}
287			}
288		}
289	}
290}