feature_flags.rs

  1use futures::{channel::oneshot, FutureExt as _};
  2use gpui::{AppContext, Global, Subscription, ViewContext};
  3use std::{
  4    future::Future,
  5    pin::Pin,
  6    task::{Context, Poll},
  7};
  8
  9#[derive(Default)]
 10struct FeatureFlags {
 11    flags: Vec<String>,
 12    staff: bool,
 13}
 14
 15impl FeatureFlags {
 16    fn has_flag<T: FeatureFlag>(&self) -> bool {
 17        if self.staff && T::enabled_for_staff() {
 18            return true;
 19        }
 20
 21        self.flags.iter().any(|f| f.as_str() == T::NAME)
 22    }
 23}
 24
 25impl Global for FeatureFlags {}
 26
 27/// To create a feature flag, implement this trait on a trivial type and use it as
 28/// a generic parameter when called [`FeatureFlagAppExt::has_flag`].
 29///
 30/// Feature flags are enabled for members of Zed staff by default. To disable this behavior
 31/// so you can test flags being disabled, set ZED_DISABLE_STAFF=1 in your environment,
 32/// which will force Zed to treat the current user as non-staff.
 33pub trait FeatureFlag {
 34    const NAME: &'static str;
 35
 36    /// Returns whether this feature flag is enabled for Zed staff.
 37    fn enabled_for_staff() -> bool {
 38        true
 39    }
 40}
 41
 42pub struct Assistant2FeatureFlag;
 43
 44impl FeatureFlag for Assistant2FeatureFlag {
 45    const NAME: &'static str = "assistant2";
 46}
 47
 48pub struct ToolUseFeatureFlag;
 49
 50impl FeatureFlag for ToolUseFeatureFlag {
 51    const NAME: &'static str = "assistant-tool-use";
 52
 53    fn enabled_for_staff() -> bool {
 54        false
 55    }
 56}
 57
 58pub struct PredictEditsFeatureFlag;
 59impl FeatureFlag for PredictEditsFeatureFlag {
 60    const NAME: &'static str = "predict-edits";
 61}
 62
 63pub struct GitUiFeatureFlag;
 64impl FeatureFlag for GitUiFeatureFlag {
 65    const NAME: &'static str = "git-ui";
 66}
 67
 68pub struct Remoting {}
 69impl FeatureFlag for Remoting {
 70    const NAME: &'static str = "remoting";
 71}
 72
 73pub struct LanguageModels {}
 74impl FeatureFlag for LanguageModels {
 75    const NAME: &'static str = "language-models";
 76}
 77
 78pub struct LlmClosedBeta {}
 79impl FeatureFlag for LlmClosedBeta {
 80    const NAME: &'static str = "llm-closed-beta";
 81}
 82
 83pub struct ZedPro {}
 84impl FeatureFlag for ZedPro {
 85    const NAME: &'static str = "zed-pro";
 86}
 87
 88pub struct NotebookFeatureFlag;
 89
 90impl FeatureFlag for NotebookFeatureFlag {
 91    const NAME: &'static str = "notebooks";
 92}
 93
 94pub struct AutoCommand {}
 95impl FeatureFlag for AutoCommand {
 96    const NAME: &'static str = "auto-command";
 97
 98    fn enabled_for_staff() -> bool {
 99        false
100    }
101}
102
103pub trait FeatureFlagViewExt<V: 'static> {
104    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
105    where
106        F: Fn(bool, &mut V, &mut ViewContext<V>) + Send + Sync + 'static;
107}
108
109impl<V> FeatureFlagViewExt<V> for ViewContext<'_, V>
110where
111    V: 'static,
112{
113    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
114    where
115        F: Fn(bool, &mut V, &mut ViewContext<V>) + 'static,
116    {
117        self.observe_global::<FeatureFlags>(move |v, cx| {
118            let feature_flags = cx.global::<FeatureFlags>();
119            callback(feature_flags.has_flag::<T>(), v, cx);
120        })
121    }
122}
123
124pub trait FeatureFlagAppExt {
125    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
126    fn update_flags(&mut self, staff: bool, flags: Vec<String>);
127    fn set_staff(&mut self, staff: bool);
128    fn has_flag<T: FeatureFlag>(&self) -> bool;
129    fn is_staff(&self) -> bool;
130
131    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
132    where
133        F: FnMut(bool, &mut AppContext) + 'static;
134}
135
136impl FeatureFlagAppExt for AppContext {
137    fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
138        let feature_flags = self.default_global::<FeatureFlags>();
139        feature_flags.staff = staff;
140        feature_flags.flags = flags;
141    }
142
143    fn set_staff(&mut self, staff: bool) {
144        let feature_flags = self.default_global::<FeatureFlags>();
145        feature_flags.staff = staff;
146    }
147
148    fn has_flag<T: FeatureFlag>(&self) -> bool {
149        self.try_global::<FeatureFlags>()
150            .map(|flags| flags.has_flag::<T>())
151            .unwrap_or(false)
152    }
153
154    fn is_staff(&self) -> bool {
155        self.try_global::<FeatureFlags>()
156            .map(|flags| flags.staff)
157            .unwrap_or(false)
158    }
159
160    fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
161    where
162        F: FnMut(bool, &mut AppContext) + 'static,
163    {
164        self.observe_global::<FeatureFlags>(move |cx| {
165            let feature_flags = cx.global::<FeatureFlags>();
166            callback(feature_flags.has_flag::<T>(), cx);
167        })
168    }
169
170    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
171        let (tx, rx) = oneshot::channel::<bool>();
172        let mut tx = Some(tx);
173        let subscription: Option<Subscription>;
174
175        match self.try_global::<FeatureFlags>() {
176            Some(feature_flags) => {
177                subscription = None;
178                tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
179            }
180            None => {
181                subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
182                    let feature_flags = cx.global::<FeatureFlags>();
183                    if let Some(tx) = tx.take() {
184                        tx.send(feature_flags.has_flag::<T>()).ok();
185                    }
186                }));
187            }
188        }
189
190        WaitForFlag(rx, subscription)
191    }
192}
193
194pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
195
196impl Future for WaitForFlag {
197    type Output = bool;
198
199    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
200        self.0.poll_unpin(cx).map(|result| {
201            self.1.take();
202            result.unwrap_or(false)
203        })
204    }
205}