edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, Zeta2FeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use extension_host::ExtensionStore;
  7use feature_flags::FeatureFlagAppExt as _;
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
 10use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
 11
 12use crate::{
 13    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 14    components::{SettingsInputField, SettingsSectionHeader},
 15};
 16
 17pub struct EditPredictionSetupPage {
 18    settings_window: Entity<SettingsWindow>,
 19    scroll_handle: ScrollHandle,
 20}
 21
 22impl EditPredictionSetupPage {
 23    pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
 24        Self {
 25            settings_window,
 26            scroll_handle: ScrollHandle::new(),
 27        }
 28    }
 29}
 30
 31impl Render for EditPredictionSetupPage {
 32    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 33        let settings_window = self.settings_window.clone();
 34
 35        let copilot_extension_installed = ExtensionStore::global(cx)
 36            .read(cx)
 37            .installed_extensions()
 38            .contains_key("copilot-chat");
 39
 40        let providers = [
 41            (!copilot_extension_installed)
 42                .then(|| render_github_copilot_provider(window, cx).into_any_element()),
 43            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
 44                render_api_key_provider(
 45                    IconName::Inception,
 46                    "Mercury",
 47                    "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 48                    mercury_api_token(cx),
 49                    |_cx| MERCURY_CREDENTIALS_URL,
 50                    None,
 51                    window,
 52                    cx,
 53                )
 54                .into_any_element()
 55            }),
 56            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
 57                render_api_key_provider(
 58                    IconName::SweepAi,
 59                    "Sweep",
 60                    "https://app.sweep.dev/".into(),
 61                    sweep_api_token(cx),
 62                    |_cx| SWEEP_CREDENTIALS_URL,
 63                    None,
 64                    window,
 65                    cx,
 66                )
 67                .into_any_element()
 68            }),
 69            Some(
 70                render_api_key_provider(
 71                    IconName::AiMistral,
 72                    "Codestral",
 73                    "https://console.mistral.ai/codestral".into(),
 74                    codestral_api_key(cx),
 75                    |cx| language_models::MistralLanguageModelProvider::api_url(cx),
 76                    Some(settings_window.update(cx, |settings_window, cx| {
 77                        let codestral_settings = codestral_settings();
 78                        settings_window
 79                            .render_sub_page_items_section(
 80                                codestral_settings.iter().enumerate(),
 81                                None,
 82                                window,
 83                                cx,
 84                            )
 85                            .into_any_element()
 86                    })),
 87                    window,
 88                    cx,
 89                )
 90                .into_any_element(),
 91            ),
 92        ];
 93
 94        div()
 95            .size_full()
 96            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
 97            .child(
 98                v_flex()
 99                    .id("ep-setup-page")
100                    .min_w_0()
101                    .size_full()
102                    .px_8()
103                    .pb_16()
104                    .overflow_y_scroll()
105                    .track_scroll(&self.scroll_handle)
106                    .children(providers.into_iter().flatten()),
107            )
108    }
109}
110
111fn render_api_key_provider(
112    icon: IconName,
113    title: &'static str,
114    link: SharedString,
115    api_key_state: Entity<ApiKeyState>,
116    current_url: fn(&mut App) -> SharedString,
117    additional_fields: Option<AnyElement>,
118    window: &mut Window,
119    cx: &mut Context<EditPredictionSetupPage>,
120) -> impl IntoElement {
121    let weak_page = cx.weak_entity();
122    _ = window.use_keyed_state(title, cx, |_, cx| {
123        let task = api_key_state.update(cx, |key_state, cx| {
124            key_state.load_if_needed(current_url(cx), |state| state, cx)
125        });
126        cx.spawn(async move |_, cx| {
127            task.await.ok();
128            weak_page
129                .update(cx, |_, cx| {
130                    cx.notify();
131                })
132                .ok();
133        })
134    });
135
136    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
137        (
138            state.has_key(),
139            Some(state.env_var_name().clone()),
140            state.is_from_env_var(),
141        )
142    });
143
144    let write_key = move |api_key: Option<String>, cx: &mut App| {
145        api_key_state
146            .update(cx, |key_state, cx| {
147                let url = current_url(cx);
148                key_state.store(url, api_key, |key_state| key_state, cx)
149            })
150            .detach_and_log_err(cx);
151    };
152
153    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
154    let header = SettingsSectionHeader::new(title)
155        .icon(icon)
156        .no_padding(true);
157    let button_link_label = format!("{} dashboard", title);
158    let description = h_flex()
159        .min_w_0()
160        .gap_0p5()
161        .child(
162            Label::new("Visit the")
163                .size(LabelSize::Small)
164                .color(Color::Muted),
165        )
166        .child(
167            ButtonLink::new(button_link_label, link)
168                .no_icon(true)
169                .label_size(LabelSize::Small)
170                .label_color(Color::Muted),
171        )
172        .child(
173            Label::new("to generate an API key.")
174                .size(LabelSize::Small)
175                .color(Color::Muted),
176        );
177    let configured_card_label = if is_from_env_var {
178        "API Key Set in Environment Variable"
179    } else {
180        "API Key Configured"
181    };
182
183    let container = if has_key {
184        base_container.child(header).child(
185            ConfiguredApiCard::new(configured_card_label)
186                .button_label("Reset Key")
187                .button_tab_index(0)
188                .disabled(is_from_env_var)
189                .when_some(env_var_name, |this, env_var_name| {
190                    this.when(is_from_env_var, |this| {
191                        this.tooltip_label(format!(
192                            "To reset your API key, unset the {} environment variable.",
193                            env_var_name
194                        ))
195                    })
196                })
197                .on_click(move |_, _, cx| {
198                    write_key(None, cx);
199                }),
200        )
201    } else {
202        base_container.child(header).child(
203            h_flex()
204                .pt_2p5()
205                .w_full()
206                .justify_between()
207                .child(
208                    v_flex()
209                        .w_full()
210                        .max_w_1_2()
211                        .child(Label::new("API Key"))
212                        .child(description)
213                        .when_some(env_var_name, |this, env_var_name| {
214                            this.child({
215                                let label = format!(
216                                    "Or set the {} env var and restart Zed.",
217                                    env_var_name.as_ref()
218                                );
219                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
220                            })
221                        }),
222                )
223                .child(
224                    SettingsInputField::new()
225                        .tab_index(0)
226                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
227                        .on_confirm(move |api_key, cx| {
228                            write_key(api_key.filter(|key| !key.is_empty()), cx);
229                        }),
230                ),
231        )
232    };
233
234    container.when_some(additional_fields, |this, additional_fields| {
235        this.child(
236            div()
237                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
238                .px_neg_8()
239                .border_t_1()
240                .border_color(cx.theme().colors().border_variant)
241                .child(additional_fields),
242        )
243    })
244}
245
246fn codestral_settings() -> Box<[SettingsPageItem]> {
247    Box::new([
248        SettingsPageItem::SettingItem(SettingItem {
249            title: "API URL",
250            description: "The API URL to use for Codestral.",
251            field: Box::new(SettingField {
252                pick: |settings| {
253                    settings
254                        .project
255                        .all_languages
256                        .edit_predictions
257                        .as_ref()?
258                        .codestral
259                        .as_ref()?
260                        .api_url
261                        .as_ref()
262                },
263                write: |settings, value| {
264                    settings
265                        .project
266                        .all_languages
267                        .edit_predictions
268                        .get_or_insert_default()
269                        .codestral
270                        .get_or_insert_default()
271                        .api_url = value;
272                },
273                json_path: Some("edit_predictions.codestral.api_url"),
274            }),
275            metadata: Some(Box::new(SettingsFieldMetadata {
276                placeholder: Some(CODESTRAL_API_URL),
277                ..Default::default()
278            })),
279            files: USER,
280        }),
281        SettingsPageItem::SettingItem(SettingItem {
282            title: "Max Tokens",
283            description: "The maximum number of tokens to generate.",
284            field: Box::new(SettingField {
285                pick: |settings| {
286                    settings
287                        .project
288                        .all_languages
289                        .edit_predictions
290                        .as_ref()?
291                        .codestral
292                        .as_ref()?
293                        .max_tokens
294                        .as_ref()
295                },
296                write: |settings, value| {
297                    settings
298                        .project
299                        .all_languages
300                        .edit_predictions
301                        .get_or_insert_default()
302                        .codestral
303                        .get_or_insert_default()
304                        .max_tokens = value;
305                },
306                json_path: Some("edit_predictions.codestral.max_tokens"),
307            }),
308            metadata: None,
309            files: USER,
310        }),
311        SettingsPageItem::SettingItem(SettingItem {
312            title: "Model",
313            description: "The Codestral model id to use.",
314            field: Box::new(SettingField {
315                pick: |settings| {
316                    settings
317                        .project
318                        .all_languages
319                        .edit_predictions
320                        .as_ref()?
321                        .codestral
322                        .as_ref()?
323                        .model
324                        .as_ref()
325                },
326                write: |settings, value| {
327                    settings
328                        .project
329                        .all_languages
330                        .edit_predictions
331                        .get_or_insert_default()
332                        .codestral
333                        .get_or_insert_default()
334                        .model = value;
335                },
336                json_path: Some("edit_predictions.codestral.model"),
337            }),
338            metadata: Some(Box::new(SettingsFieldMetadata {
339                placeholder: Some("codestral-latest"),
340                ..Default::default()
341            })),
342            files: USER,
343        }),
344    ])
345}
346
347pub(crate) fn render_github_copilot_provider(
348    window: &mut Window,
349    cx: &mut App,
350) -> impl IntoElement {
351    let configuration_view = window.use_state(cx, |_, cx| {
352        copilot::ConfigurationView::new(
353            |cx| {
354                copilot::Copilot::global(cx)
355                    .is_some_and(|copilot| copilot.read(cx).is_authenticated())
356            },
357            copilot::ConfigurationMode::EditPrediction,
358            cx,
359        )
360    });
361
362    v_flex()
363        .id("github-copilot")
364        .min_w_0()
365        .gap_1p5()
366        .child(
367            SettingsSectionHeader::new("GitHub Copilot")
368                .icon(IconName::Copilot)
369                .no_padding(true),
370        )
371        .child(configuration_view)
372}