feature_flags.rs

  1use futures::{channel::oneshot, FutureExt as _};
  2use gpui::{App, Context, Global, Subscription, Window};
  3use std::{future::Future, pin::Pin, task::Poll};
  4
  /// Feature flag state for the running application, registered as a gpui global.
  ///
  /// The `Default` value has no flags enabled and a non-staff user.
  #[derive(Default)]
  struct FeatureFlags {
      /// Names of the flags that have been explicitly enabled.
      flags: Vec<String>,
      /// Whether the current user is treated as a member of staff.
      staff: bool,
  }
 10
 11impl FeatureFlags {
 12    fn has_flag<T: FeatureFlag>(&self) -> bool {
 13        if self.staff && T::enabled_for_staff() {
 14            return true;
 15        }
 16
 17        self.flags.iter().any(|f| f.as_str() == T::NAME)
 18    }
 19}
 20
 21impl Global for FeatureFlags {}
 22
 /// A type-level handle for a single feature flag.
 ///
 /// To create a feature flag, implement this trait on a trivial type and pass that type
 /// as the generic parameter when calling [`FeatureFlagAppExt::has_flag`].
 ///
 /// Feature flags are enabled for members of Zed staff by default. To disable this
 /// behavior so you can test a flag in its disabled state, set ZED_DISABLE_STAFF=1 in
 /// your environment, which will force Zed to treat the current user as non-staff.
 pub trait FeatureFlag {
     /// Name that is compared against the list of explicitly enabled flags.
     const NAME: &'static str;

     /// Returns whether this feature flag is enabled for Zed staff.
     ///
     /// The provided implementation returns `true`; override it to return `false` for a
     /// flag that staff should only get when it is explicitly enabled.
     fn enabled_for_staff() -> bool {
         true
     }
 }
 37
 38pub struct Assistant2FeatureFlag;
 39
 40impl FeatureFlag for Assistant2FeatureFlag {
 41    const NAME: &'static str = "assistant2";
 42}
 43
 44pub struct PredictEditsFeatureFlag;
 45impl FeatureFlag for PredictEditsFeatureFlag {
 46    const NAME: &'static str = "predict-edits";
 47}
 48
 49pub struct PredictEditsRateCompletionsFeatureFlag;
 50impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
 51    const NAME: &'static str = "predict-edits-rate-completions";
 52}
 53
 54pub struct GitUiFeatureFlag;
 55impl FeatureFlag for GitUiFeatureFlag {
 56    const NAME: &'static str = "git-ui";
 57}
 58
 59pub struct Remoting {}
 60impl FeatureFlag for Remoting {
 61    const NAME: &'static str = "remoting";
 62}
 63
 64pub struct LanguageModels {}
 65impl FeatureFlag for LanguageModels {
 66    const NAME: &'static str = "language-models";
 67}
 68
 69pub struct LlmClosedBeta {}
 70impl FeatureFlag for LlmClosedBeta {
 71    const NAME: &'static str = "llm-closed-beta";
 72}
 73
 74pub struct ZedPro {}
 75impl FeatureFlag for ZedPro {
 76    const NAME: &'static str = "zed-pro";
 77}
 78
 79pub struct NotebookFeatureFlag;
 80
 81impl FeatureFlag for NotebookFeatureFlag {
 82    const NAME: &'static str = "notebooks";
 83}
 84
 85pub struct AutoCommand {}
 86impl FeatureFlag for AutoCommand {
 87    const NAME: &'static str = "auto-command";
 88
 89    fn enabled_for_staff() -> bool {
 90        false
 91    }
 92}
 93
 94pub trait FeatureFlagViewExt<V: 'static> {
 95    fn observe_flag<T: FeatureFlag, F>(&mut self, window: &Window, callback: F) -> Subscription
 96    where
 97        F: Fn(bool, &mut V, &mut Window, &mut Context<V>) + Send + Sync + 'static;
 98}
 99
100impl<V> FeatureFlagViewExt<V> for Context<'_, V>
101where
102    V: 'static,
103{
104    fn observe_flag<T: FeatureFlag, F>(&mut self, window: &Window, callback: F) -> Subscription
105    where
106        F: Fn(bool, &mut V, &mut Window, &mut Context<V>) + 'static,
107    {
108        self.observe_global_in::<FeatureFlags>(window, move |v, window, cx| {
109            let feature_flags = cx.global::<FeatureFlags>();
110            callback(feature_flags.has_flag::<T>(), v, window, cx);
111        })
112    }
113}
114
115pub trait FeatureFlagAppExt {
116    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
117    fn update_flags(&mut self, staff: bool, flags: Vec<String>);
118    fn set_staff(&mut self, staff: bool);
119    fn has_flag<T: FeatureFlag>(&self) -> bool;
120    fn is_staff(&self) -> bool;
121
122    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
123    where
124        F: FnMut(bool, &mut App) + 'static;
125}
126
127impl FeatureFlagAppExt for App {
128    fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
129        let feature_flags = self.default_global::<FeatureFlags>();
130        feature_flags.staff = staff;
131        feature_flags.flags = flags;
132    }
133
134    fn set_staff(&mut self, staff: bool) {
135        let feature_flags = self.default_global::<FeatureFlags>();
136        feature_flags.staff = staff;
137    }
138
139    fn has_flag<T: FeatureFlag>(&self) -> bool {
140        self.try_global::<FeatureFlags>()
141            .map(|flags| flags.has_flag::<T>())
142            .unwrap_or(false)
143    }
144
145    fn is_staff(&self) -> bool {
146        self.try_global::<FeatureFlags>()
147            .map(|flags| flags.staff)
148            .unwrap_or(false)
149    }
150
151    fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
152    where
153        F: FnMut(bool, &mut App) + 'static,
154    {
155        self.observe_global::<FeatureFlags>(move |cx| {
156            let feature_flags = cx.global::<FeatureFlags>();
157            callback(feature_flags.has_flag::<T>(), cx);
158        })
159    }
160
161    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
162        let (tx, rx) = oneshot::channel::<bool>();
163        let mut tx = Some(tx);
164        let subscription: Option<Subscription>;
165
166        match self.try_global::<FeatureFlags>() {
167            Some(feature_flags) => {
168                subscription = None;
169                tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
170            }
171            None => {
172                subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
173                    let feature_flags = cx.global::<FeatureFlags>();
174                    if let Some(tx) = tx.take() {
175                        tx.send(feature_flags.has_flag::<T>()).ok();
176                    }
177                }));
178            }
179        }
180
181        WaitForFlag(rx, subscription)
182    }
183}
184
185pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
186
187impl Future for WaitForFlag {
188    type Output = bool;
189
190    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
191        self.0.poll_unpin(cx).map(|result| {
192            self.1.take();
193            result.unwrap_or(false)
194        })
195    }
196}