1use futures::channel::oneshot;
2use futures::{select_biased, FutureExt};
3use gpui::{App, Context, Global, Subscription, Task, Window};
4use std::time::Duration;
5use std::{future::Future, pin::Pin, task::Poll};
6
/// App-wide feature-flag state: the current user's staff status and the
/// names of the flags that have been explicitly enabled.
#[derive(Default)]
struct FeatureFlags {
    /// Names of explicitly enabled flags, compared against `FeatureFlag::NAME`.
    flags: Vec<String>,
    /// Whether the current user is treated as staff.
    staff: bool,
}
12
13impl FeatureFlags {
14 fn has_flag<T: FeatureFlag>(&self) -> bool {
15 if self.staff && T::enabled_for_staff() {
16 return true;
17 }
18
19 self.flags.iter().any(|f| f.as_str() == T::NAME)
20 }
21}
22
23impl Global for FeatureFlags {}
24
/// A compile-time handle for a single named feature flag.
///
/// To create a feature flag, implement this trait on a trivial type and use it as
/// a generic parameter when calling [`FeatureFlagAppExt::has_flag`].
///
/// Feature flags are enabled for members of Zed staff by default. To disable this behavior
/// so you can test flags being disabled, set ZED_DISABLE_STAFF=1 in your environment,
/// which will force Zed to treat the current user as non-staff.
pub trait FeatureFlag {
    /// The flag's identifier, matched exactly against the names in the
    /// enabled-flags list.
    const NAME: &'static str;

    /// Returns whether this feature flag is enabled for Zed staff.
    ///
    /// Defaults to `true`. Override it to return `false` when staff should
    /// only get the flag if it is explicitly present in the enabled-flags list.
    fn enabled_for_staff() -> bool {
        true
    }
}
39
40pub struct Assistant2FeatureFlag;
41
42impl FeatureFlag for Assistant2FeatureFlag {
43 const NAME: &'static str = "assistant2";
44}
45
46pub struct PredictEditsFeatureFlag;
47impl FeatureFlag for PredictEditsFeatureFlag {
48 const NAME: &'static str = "predict-edits";
49}
50
51pub struct PredictEditsRateCompletionsFeatureFlag;
52impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
53 const NAME: &'static str = "predict-edits-rate-completions";
54}
55
56pub struct GitUiFeatureFlag;
57impl FeatureFlag for GitUiFeatureFlag {
58 const NAME: &'static str = "git-ui";
59}
60
61pub struct Remoting {}
62impl FeatureFlag for Remoting {
63 const NAME: &'static str = "remoting";
64}
65
66pub struct LanguageModels {}
67impl FeatureFlag for LanguageModels {
68 const NAME: &'static str = "language-models";
69}
70
71pub struct LlmClosedBeta {}
72impl FeatureFlag for LlmClosedBeta {
73 const NAME: &'static str = "llm-closed-beta";
74}
75
76pub struct ZedPro {}
77impl FeatureFlag for ZedPro {
78 const NAME: &'static str = "zed-pro";
79}
80
81pub struct NotebookFeatureFlag;
82
83impl FeatureFlag for NotebookFeatureFlag {
84 const NAME: &'static str = "notebooks";
85}
86
87pub struct AutoCommand {}
88impl FeatureFlag for AutoCommand {
89 const NAME: &'static str = "auto-command";
90
91 fn enabled_for_staff() -> bool {
92 false
93 }
94}
95
96pub trait FeatureFlagViewExt<V: 'static> {
97 fn observe_flag<T: FeatureFlag, F>(&mut self, window: &Window, callback: F) -> Subscription
98 where
99 F: Fn(bool, &mut V, &mut Window, &mut Context<V>) + Send + Sync + 'static;
100}
101
102impl<V> FeatureFlagViewExt<V> for Context<'_, V>
103where
104 V: 'static,
105{
106 fn observe_flag<T: FeatureFlag, F>(&mut self, window: &Window, callback: F) -> Subscription
107 where
108 F: Fn(bool, &mut V, &mut Window, &mut Context<V>) + 'static,
109 {
110 self.observe_global_in::<FeatureFlags>(window, move |v, window, cx| {
111 let feature_flags = cx.global::<FeatureFlags>();
112 callback(feature_flags.has_flag::<T>(), v, window, cx);
113 })
114 }
115}
116
117pub trait FeatureFlagAppExt {
118 fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
119
120 /// Waits for the specified feature flag to resolve, up to the given timeout.
121 fn wait_for_flag_or_timeout<T: FeatureFlag>(&mut self, timeout: Duration) -> Task<bool>;
122
123 fn update_flags(&mut self, staff: bool, flags: Vec<String>);
124 fn set_staff(&mut self, staff: bool);
125 fn has_flag<T: FeatureFlag>(&self) -> bool;
126 fn is_staff(&self) -> bool;
127
128 fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
129 where
130 F: FnMut(bool, &mut App) + 'static;
131}
132
133impl FeatureFlagAppExt for App {
134 fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
135 let feature_flags = self.default_global::<FeatureFlags>();
136 feature_flags.staff = staff;
137 feature_flags.flags = flags;
138 }
139
140 fn set_staff(&mut self, staff: bool) {
141 let feature_flags = self.default_global::<FeatureFlags>();
142 feature_flags.staff = staff;
143 }
144
145 fn has_flag<T: FeatureFlag>(&self) -> bool {
146 self.try_global::<FeatureFlags>()
147 .map(|flags| flags.has_flag::<T>())
148 .unwrap_or(false)
149 }
150
151 fn is_staff(&self) -> bool {
152 self.try_global::<FeatureFlags>()
153 .map(|flags| flags.staff)
154 .unwrap_or(false)
155 }
156
157 fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
158 where
159 F: FnMut(bool, &mut App) + 'static,
160 {
161 self.observe_global::<FeatureFlags>(move |cx| {
162 let feature_flags = cx.global::<FeatureFlags>();
163 callback(feature_flags.has_flag::<T>(), cx);
164 })
165 }
166
167 fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
168 let (tx, rx) = oneshot::channel::<bool>();
169 let mut tx = Some(tx);
170 let subscription: Option<Subscription>;
171
172 match self.try_global::<FeatureFlags>() {
173 Some(feature_flags) => {
174 subscription = None;
175 tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
176 }
177 None => {
178 subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
179 let feature_flags = cx.global::<FeatureFlags>();
180 if let Some(tx) = tx.take() {
181 tx.send(feature_flags.has_flag::<T>()).ok();
182 }
183 }));
184 }
185 }
186
187 WaitForFlag(rx, subscription)
188 }
189
190 fn wait_for_flag_or_timeout<T: FeatureFlag>(&mut self, timeout: Duration) -> Task<bool> {
191 let wait_for_flag = self.wait_for_flag::<T>();
192
193 self.spawn(|_cx| async move {
194 let mut wait_for_flag = wait_for_flag.fuse();
195 let mut timeout = FutureExt::fuse(smol::Timer::after(timeout));
196
197 select_biased! {
198 is_enabled = wait_for_flag => is_enabled,
199 _ = timeout => false,
200 }
201 })
202 }
203}
204
205pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
206
207impl Future for WaitForFlag {
208 type Output = bool;
209
210 fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
211 self.0.poll_unpin(cx).map(|result| {
212 self.1.take();
213 result.unwrap_or(false)
214 })
215 }
216}