1use std::sync::Arc;
2
3use crate::ZedPredictModal;
4use chrono::Utc;
5use client::{Client, UserStore};
6use feature_flags::{FeatureFlagAppExt as _, PredictEditsFeatureFlag};
7use fs::Fs;
8use gpui::{Entity, Subscription, WeakEntity};
9use language::language_settings::{all_language_settings, InlineCompletionProvider};
10use settings::SettingsStore;
11use ui::{prelude::*, ButtonLike, Tooltip};
12use util::ResultExt;
13use workspace::Workspace;
14
15/// Prompts user to try AI inline prediction feature
16pub struct ZedPredictBanner {
17 workspace: WeakEntity<Workspace>,
18 user_store: Entity<UserStore>,
19 client: Arc<Client>,
20 fs: Arc<dyn Fs>,
21 dismissed: bool,
22 _subscription: Subscription,
23}
24
25impl ZedPredictBanner {
26 pub fn new(
27 workspace: WeakEntity<Workspace>,
28 user_store: Entity<UserStore>,
29 client: Arc<Client>,
30 fs: Arc<dyn Fs>,
31 cx: &mut Context<Self>,
32 ) -> Self {
33 Self {
34 workspace,
35 user_store,
36 client,
37 fs,
38 dismissed: get_dismissed(),
39 _subscription: cx.observe_global::<SettingsStore>(Self::handle_settings_changed),
40 }
41 }
42
43 fn should_show(&self, cx: &mut App) -> bool {
44 if !cx.has_flag::<PredictEditsFeatureFlag>() || self.dismissed {
45 return false;
46 }
47
48 let provider = all_language_settings(None, cx).inline_completions.provider;
49
50 match provider {
51 InlineCompletionProvider::None
52 | InlineCompletionProvider::Copilot
53 | InlineCompletionProvider::Supermaven => true,
54 InlineCompletionProvider::Zed => false,
55 }
56 }
57
58 fn handle_settings_changed(&mut self, cx: &mut Context<Self>) {
59 if self.dismissed {
60 return;
61 }
62
63 let provider = all_language_settings(None, cx).inline_completions.provider;
64
65 match provider {
66 InlineCompletionProvider::None
67 | InlineCompletionProvider::Copilot
68 | InlineCompletionProvider::Supermaven => {}
69 InlineCompletionProvider::Zed => {
70 self.dismiss(cx);
71 }
72 }
73 }
74
75 fn dismiss(&mut self, cx: &mut Context<Self>) {
76 persist_dismissed(cx);
77 self.dismissed = true;
78 cx.notify();
79 }
80}
81
/// Key-value store key under which the banner's dismissal timestamp is saved.
/// The banner counts as dismissed whenever any value is stored under it.
const DISMISSED_AT_KEY: &str = "zed_predict_banner_dismissed_at";
83
84pub(crate) fn get_dismissed() -> bool {
85 db::kvp::KEY_VALUE_STORE
86 .read_kvp(DISMISSED_AT_KEY)
87 .log_err()
88 .map_or(false, |dismissed| dismissed.is_some())
89}
90
91pub(crate) fn persist_dismissed(cx: &mut App) {
92 cx.spawn(|_| {
93 let time = Utc::now().to_rfc3339();
94 db::kvp::KEY_VALUE_STORE.write_kvp(DISMISSED_AT_KEY.into(), time)
95 })
96 .detach_and_log_err(cx);
97}
98
99impl Render for ZedPredictBanner {
100 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
101 if !self.should_show(cx) {
102 return div();
103 }
104
105 let border_color = cx.theme().colors().editor_foreground.opacity(0.3);
106 let banner = h_flex()
107 .rounded_md()
108 .border_1()
109 .border_color(border_color)
110 .child(
111 ButtonLike::new("try-zed-predict")
112 .child(
113 h_flex()
114 .h_full()
115 .items_center()
116 .gap_1p5()
117 .child(Icon::new(IconName::ZedPredict).size(IconSize::Small))
118 .child(
119 h_flex()
120 .gap_0p5()
121 .child(
122 Label::new("Introducing:")
123 .size(LabelSize::Small)
124 .color(Color::Muted),
125 )
126 .child(Label::new("Edit Prediction").size(LabelSize::Small)),
127 ),
128 )
129 .on_click({
130 let workspace = self.workspace.clone();
131 let user_store = self.user_store.clone();
132 let client = self.client.clone();
133 let fs = self.fs.clone();
134 move |_, window, cx| {
135 let Some(workspace) = workspace.upgrade() else {
136 return;
137 };
138 ZedPredictModal::toggle(
139 workspace,
140 user_store.clone(),
141 client.clone(),
142 fs.clone(),
143 window,
144 cx,
145 );
146 }
147 }),
148 )
149 .child(
150 div().border_l_1().border_color(border_color).child(
151 IconButton::new("close", IconName::Close)
152 .icon_size(IconSize::Indicator)
153 .on_click(cx.listener(|this, _, _window, cx| this.dismiss(cx)))
154 .tooltip(|window, cx| {
155 Tooltip::with_meta(
156 "Close Announcement Banner",
157 None,
158 "It won't show again for this feature",
159 window,
160 cx,
161 )
162 }),
163 ),
164 );
165
166 div().pr_1().child(banner)
167 }
168}