edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, Zeta2FeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use extension_host::ExtensionStore;
  7use feature_flags::FeatureFlagAppExt as _;
  8use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
  9use language_model::{
 10    ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
 13use std::collections::HashMap;
 14use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
 15
 16use crate::{
 17    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 18    components::{SettingsInputField, SettingsSectionHeader},
 19};
 20
 21pub struct EditPredictionSetupPage {
 22    settings_window: Entity<SettingsWindow>,
 23    scroll_handle: ScrollHandle,
 24    extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
 25    _registry_subscription: Subscription,
 26}
 27
 28struct ExtensionOAuthProviderView {
 29    provider_name: SharedString,
 30    provider_icon: IconName,
 31    provider_icon_path: Option<SharedString>,
 32    configuration_view: AnyView,
 33}
 34
 35impl EditPredictionSetupPage {
 36    pub fn new(
 37        settings_window: Entity<SettingsWindow>,
 38        window: &mut Window,
 39        cx: &mut Context<Self>,
 40    ) -> Self {
 41        let registry_subscription = cx.subscribe_in(
 42            &LanguageModelRegistry::global(cx),
 43            window,
 44            |this, _, event: &language_model::Event, window, cx| match event {
 45                language_model::Event::AddedProvider(provider_id) => {
 46                    this.maybe_add_extension_oauth_view(provider_id, window, cx);
 47                }
 48                language_model::Event::RemovedProvider(provider_id) => {
 49                    this.extension_oauth_views.remove(provider_id);
 50                }
 51                _ => {}
 52            },
 53        );
 54
 55        let mut this = Self {
 56            settings_window,
 57            scroll_handle: ScrollHandle::new(),
 58            extension_oauth_views: HashMap::default(),
 59            _registry_subscription: registry_subscription,
 60        };
 61        this.build_extension_oauth_views(window, cx);
 62        this
 63    }
 64
 65    fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 66        let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
 67        for provider_id in oauth_provider_ids {
 68            self.maybe_add_extension_oauth_view(&provider_id, window, cx);
 69        }
 70    }
 71
 72    fn maybe_add_extension_oauth_view(
 73        &mut self,
 74        provider_id: &LanguageModelProviderId,
 75        window: &mut Window,
 76        cx: &mut Context<Self>,
 77    ) {
 78        // Check if this provider has OAuth configured in the extension manifest
 79        if !is_extension_oauth_provider(provider_id, cx) {
 80            return;
 81        }
 82
 83        let registry = LanguageModelRegistry::global(cx).read(cx);
 84        let Some(provider) = registry.provider(provider_id) else {
 85            return;
 86        };
 87
 88        let provider_name = provider.name().0.clone();
 89        let provider_icon = provider.icon();
 90        let provider_icon_path = provider.icon_path();
 91        let configuration_view =
 92            provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx);
 93
 94        self.extension_oauth_views.insert(
 95            provider_id.clone(),
 96            ExtensionOAuthProviderView {
 97                provider_name,
 98                provider_icon,
 99                provider_icon_path,
100                configuration_view,
101            },
102        );
103    }
104}
105
106impl Render for EditPredictionSetupPage {
107    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
108        let settings_window = self.settings_window.clone();
109
110        let copilot_extension_installed = ExtensionStore::global(cx)
111            .read(cx)
112            .installed_extensions()
113            .contains_key("copilot-chat");
114
115        let mut providers: Vec<AnyElement> = Vec::new();
116
117        // Built-in Copilot (hidden if copilot-chat extension is installed)
118        if !copilot_extension_installed {
119            providers.push(render_github_copilot_provider(window, cx).into_any_element());
120        }
121
122        // Extension providers with OAuth support
123        for (provider_id, view) in &self.extension_oauth_views {
124            let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
125                Icon::from_external_svg(icon_path.clone())
126                    .size(ui::IconSize::Medium)
127                    .into_any_element()
128            } else {
129                Icon::new(view.provider_icon)
130                    .size(ui::IconSize::Medium)
131                    .into_any_element()
132            };
133
134            providers.push(
135                v_flex()
136                    .id(SharedString::from(provider_id.0.to_string()))
137                    .min_w_0()
138                    .gap_1p5()
139                    .child(
140                        h_flex().gap_2().items_center().child(icon_element).child(
141                            Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
142                        ),
143                    )
144                    .child(view.configuration_view.clone())
145                    .into_any_element(),
146            );
147        }
148
149        // Mercury (feature flagged)
150        if cx.has_flag::<Zeta2FeatureFlag>() {
151            providers.push(
152                render_api_key_provider(
153                    IconName::Inception,
154                    "Mercury",
155                    "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
156                    mercury_api_token(cx),
157                    |_cx| MERCURY_CREDENTIALS_URL,
158                    None,
159                    window,
160                    cx,
161                )
162                .into_any_element(),
163            );
164        }
165
166        // Sweep (feature flagged)
167        if cx.has_flag::<Zeta2FeatureFlag>() {
168            providers.push(
169                render_api_key_provider(
170                    IconName::SweepAi,
171                    "Sweep",
172                    "https://app.sweep.dev/".into(),
173                    sweep_api_token(cx),
174                    |_cx| SWEEP_CREDENTIALS_URL,
175                    None,
176                    window,
177                    cx,
178                )
179                .into_any_element(),
180            );
181        }
182
183        // Codestral
184        providers.push(
185            render_api_key_provider(
186                IconName::AiMistral,
187                "Codestral",
188                "https://console.mistral.ai/codestral".into(),
189                codestral_api_key(cx),
190                |cx| language_models::MistralLanguageModelProvider::api_url(cx),
191                Some(settings_window.update(cx, |settings_window, cx| {
192                    let codestral_settings = codestral_settings();
193                    settings_window
194                        .render_sub_page_items_section(
195                            codestral_settings.iter().enumerate(),
196                            None,
197                            window,
198                            cx,
199                        )
200                        .into_any_element()
201                })),
202                window,
203                cx,
204            )
205            .into_any_element(),
206        );
207
208        div()
209            .size_full()
210            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
211            .child(
212                v_flex()
213                    .id("ep-setup-page")
214                    .min_w_0()
215                    .size_full()
216                    .px_8()
217                    .pb_16()
218                    .overflow_y_scroll()
219                    .track_scroll(&self.scroll_handle)
220                    .children(providers),
221            )
222    }
223}
224
225/// Get extension provider IDs that have OAuth configured.
226fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
227    let extension_store = ExtensionStore::global(cx).read(cx);
228
229    extension_store
230        .installed_extensions()
231        .iter()
232        .flat_map(|(extension_id, entry)| {
233            entry.manifest.language_model_providers.iter().filter_map(
234                move |(provider_id, provider_entry)| {
235                    // Check if this provider has OAuth configured
236                    let has_oauth = provider_entry
237                        .auth
238                        .as_ref()
239                        .is_some_and(|auth| auth.oauth.is_some());
240
241                    if has_oauth {
242                        Some(LanguageModelProviderId(
243                            format!("{}:{}", extension_id, provider_id).into(),
244                        ))
245                    } else {
246                        None
247                    }
248                },
249            )
250        })
251        .collect()
252}
253
254/// Check if a provider ID corresponds to an extension with OAuth configured.
255fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
256    // Extension provider IDs are in the format "extension_id:provider_id"
257    let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
258        return false;
259    };
260
261    let extension_store = ExtensionStore::global(cx).read(cx);
262    let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
263        return false;
264    };
265
266    entry
267        .manifest
268        .language_model_providers
269        .get(local_provider_id)
270        .and_then(|p| p.auth.as_ref())
271        .is_some_and(|auth| auth.oauth.is_some())
272}
273
274fn render_api_key_provider(
275    icon: IconName,
276    title: &'static str,
277    link: SharedString,
278    api_key_state: Entity<ApiKeyState>,
279    current_url: fn(&mut App) -> SharedString,
280    additional_fields: Option<AnyElement>,
281    window: &mut Window,
282    cx: &mut Context<EditPredictionSetupPage>,
283) -> impl IntoElement {
284    let weak_page = cx.weak_entity();
285    _ = window.use_keyed_state(title, cx, |_, cx| {
286        let task = api_key_state.update(cx, |key_state, cx| {
287            key_state.load_if_needed(current_url(cx), |state| state, cx)
288        });
289        cx.spawn(async move |_, cx| {
290            task.await.ok();
291            weak_page
292                .update(cx, |_, cx| {
293                    cx.notify();
294                })
295                .ok();
296        })
297    });
298
299    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
300        (
301            state.has_key(),
302            Some(state.env_var_name().clone()),
303            state.is_from_env_var(),
304        )
305    });
306
307    let write_key = move |api_key: Option<String>, cx: &mut App| {
308        api_key_state
309            .update(cx, |key_state, cx| {
310                let url = current_url(cx);
311                key_state.store(url, api_key, |key_state| key_state, cx)
312            })
313            .detach_and_log_err(cx);
314    };
315
316    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
317    let header = SettingsSectionHeader::new(title)
318        .icon(icon)
319        .no_padding(true);
320    let button_link_label = format!("{} dashboard", title);
321    let description = h_flex()
322        .min_w_0()
323        .gap_0p5()
324        .child(
325            Label::new("Visit the")
326                .size(LabelSize::Small)
327                .color(Color::Muted),
328        )
329        .child(
330            ButtonLink::new(button_link_label, link)
331                .no_icon(true)
332                .label_size(LabelSize::Small)
333                .label_color(Color::Muted),
334        )
335        .child(
336            Label::new("to generate an API key.")
337                .size(LabelSize::Small)
338                .color(Color::Muted),
339        );
340    let configured_card_label = if is_from_env_var {
341        "API Key Set in Environment Variable"
342    } else {
343        "API Key Configured"
344    };
345
346    let container = if has_key {
347        base_container.child(header).child(
348            ConfiguredApiCard::new(configured_card_label)
349                .button_label("Reset Key")
350                .button_tab_index(0)
351                .disabled(is_from_env_var)
352                .when_some(env_var_name, |this, env_var_name| {
353                    this.when(is_from_env_var, |this| {
354                        this.tooltip_label(format!(
355                            "To reset your API key, unset the {} environment variable.",
356                            env_var_name
357                        ))
358                    })
359                })
360                .on_click(move |_, _, cx| {
361                    write_key(None, cx);
362                }),
363        )
364    } else {
365        base_container.child(header).child(
366            h_flex()
367                .pt_2p5()
368                .w_full()
369                .justify_between()
370                .child(
371                    v_flex()
372                        .w_full()
373                        .max_w_1_2()
374                        .child(Label::new("API Key"))
375                        .child(description)
376                        .when_some(env_var_name, |this, env_var_name| {
377                            this.child({
378                                let label = format!(
379                                    "Or set the {} env var and restart Zed.",
380                                    env_var_name.as_ref()
381                                );
382                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
383                            })
384                        }),
385                )
386                .child(
387                    SettingsInputField::new()
388                        .tab_index(0)
389                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
390                        .on_confirm(move |api_key, cx| {
391                            write_key(api_key.filter(|key| !key.is_empty()), cx);
392                        }),
393                ),
394        )
395    };
396
397    container.when_some(additional_fields, |this, additional_fields| {
398        this.child(
399            div()
400                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
401                .px_neg_8()
402                .border_t_1()
403                .border_color(cx.theme().colors().border_variant)
404                .child(additional_fields),
405        )
406    })
407}
408
409fn codestral_settings() -> Box<[SettingsPageItem]> {
410    Box::new([
411        SettingsPageItem::SettingItem(SettingItem {
412            title: "API URL",
413            description: "The API URL to use for Codestral.",
414            field: Box::new(SettingField {
415                pick: |settings| {
416                    settings
417                        .project
418                        .all_languages
419                        .edit_predictions
420                        .as_ref()?
421                        .codestral
422                        .as_ref()?
423                        .api_url
424                        .as_ref()
425                },
426                write: |settings, value| {
427                    settings
428                        .project
429                        .all_languages
430                        .edit_predictions
431                        .get_or_insert_default()
432                        .codestral
433                        .get_or_insert_default()
434                        .api_url = value;
435                },
436                json_path: Some("edit_predictions.codestral.api_url"),
437            }),
438            metadata: Some(Box::new(SettingsFieldMetadata {
439                placeholder: Some(CODESTRAL_API_URL),
440                ..Default::default()
441            })),
442            files: USER,
443        }),
444        SettingsPageItem::SettingItem(SettingItem {
445            title: "Max Tokens",
446            description: "The maximum number of tokens to generate.",
447            field: Box::new(SettingField {
448                pick: |settings| {
449                    settings
450                        .project
451                        .all_languages
452                        .edit_predictions
453                        .as_ref()?
454                        .codestral
455                        .as_ref()?
456                        .max_tokens
457                        .as_ref()
458                },
459                write: |settings, value| {
460                    settings
461                        .project
462                        .all_languages
463                        .edit_predictions
464                        .get_or_insert_default()
465                        .codestral
466                        .get_or_insert_default()
467                        .max_tokens = value;
468                },
469                json_path: Some("edit_predictions.codestral.max_tokens"),
470            }),
471            metadata: None,
472            files: USER,
473        }),
474        SettingsPageItem::SettingItem(SettingItem {
475            title: "Model",
476            description: "The Codestral model id to use.",
477            field: Box::new(SettingField {
478                pick: |settings| {
479                    settings
480                        .project
481                        .all_languages
482                        .edit_predictions
483                        .as_ref()?
484                        .codestral
485                        .as_ref()?
486                        .model
487                        .as_ref()
488                },
489                write: |settings, value| {
490                    settings
491                        .project
492                        .all_languages
493                        .edit_predictions
494                        .get_or_insert_default()
495                        .codestral
496                        .get_or_insert_default()
497                        .model = value;
498                },
499                json_path: Some("edit_predictions.codestral.model"),
500            }),
501            metadata: Some(Box::new(SettingsFieldMetadata {
502                placeholder: Some("codestral-latest"),
503                ..Default::default()
504            })),
505            files: USER,
506        }),
507    ])
508}
509
510pub(crate) fn render_github_copilot_provider(
511    window: &mut Window,
512    cx: &mut App,
513) -> impl IntoElement {
514    let configuration_view = window.use_state(cx, |_, cx| {
515        copilot::ConfigurationView::new(
516            |cx| {
517                copilot::Copilot::global(cx)
518                    .is_some_and(|copilot| copilot.read(cx).is_authenticated())
519            },
520            copilot::ConfigurationMode::EditPrediction,
521            cx,
522        )
523    });
524
525    v_flex()
526        .id("github-copilot")
527        .min_w_0()
528        .gap_1p5()
529        .child(
530            SettingsSectionHeader::new("GitHub Copilot")
531                .icon(IconName::Copilot)
532                .no_padding(true),
533        )
534        .child(configuration_view)
535}