edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, Zeta2FeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use extension_host::ExtensionStore;
  7use feature_flags::FeatureFlagAppExt as _;
  8use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
  9use language_model::{
 10    ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
 13use std::collections::HashMap;
 14use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
 15
 16use crate::{
 17    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 18    components::{SettingsInputField, SettingsSectionHeader},
 19};
 20
/// Settings page that lists the available edit prediction providers and lets the user
/// configure each of them.
pub struct EditPredictionSetupPage {
    /// The settings window hosting this page; used to render shared settings items such as
    /// the extra Codestral settings section.
    settings_window: Entity<SettingsWindow>,
    /// Scroll state for the page's vertically scrolling provider list.
    scroll_handle: ScrollHandle,
    /// Configuration views for extension-provided language model providers whose manifest
    /// declares OAuth, keyed by provider ID (`"<extension_id>:<provider_id>"`).
    extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
    /// Subscription to `LanguageModelRegistry` events created in `new`; held so it stays
    /// active for the page's lifetime (presumably dropping it unsubscribes).
    _registry_subscription: Subscription,
}
 27
/// Data captured from an extension-provided OAuth language model provider, used to render its
/// section on the edit prediction setup page.
struct ExtensionOAuthProviderView {
    /// Display name shown in the provider's headline.
    provider_name: SharedString,
    /// Built-in icon, used only when `provider_icon_path` is `None`.
    provider_icon: IconName,
    /// Optional path to an external SVG icon; preferred over `provider_icon` when present.
    provider_icon_path: Option<SharedString>,
    /// The provider's own configuration view, created for the edit prediction target.
    configuration_view: AnyView,
}
 34
 35impl EditPredictionSetupPage {
 36    pub fn new(
 37        settings_window: Entity<SettingsWindow>,
 38        window: &mut Window,
 39        cx: &mut Context<Self>,
 40    ) -> Self {
 41        let registry_subscription = cx.subscribe_in(
 42            &LanguageModelRegistry::global(cx),
 43            window,
 44            |this, _, event: &language_model::Event, window, cx| match event {
 45                language_model::Event::AddedProvider(provider_id) => {
 46                    this.maybe_add_extension_oauth_view(provider_id, window, cx);
 47                }
 48                language_model::Event::RemovedProvider(provider_id) => {
 49                    this.extension_oauth_views.remove(provider_id);
 50                }
 51                _ => {}
 52            },
 53        );
 54
 55        let mut this = Self {
 56            settings_window,
 57            scroll_handle: ScrollHandle::new(),
 58            extension_oauth_views: HashMap::default(),
 59            _registry_subscription: registry_subscription,
 60        };
 61        this.build_extension_oauth_views(window, cx);
 62        this
 63    }
 64
 65    fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 66        let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
 67        for provider_id in oauth_provider_ids {
 68            self.maybe_add_extension_oauth_view(&provider_id, window, cx);
 69        }
 70    }
 71
 72    fn maybe_add_extension_oauth_view(
 73        &mut self,
 74        provider_id: &LanguageModelProviderId,
 75        window: &mut Window,
 76        cx: &mut Context<Self>,
 77    ) {
 78        // Check if this provider has OAuth configured in the extension manifest
 79        if !is_extension_oauth_provider(provider_id, cx) {
 80            return;
 81        }
 82
 83        let registry = LanguageModelRegistry::global(cx).read(cx);
 84        let Some(provider) = registry.provider(provider_id) else {
 85            return;
 86        };
 87
 88        let provider_name = provider.name().0;
 89        let provider_icon = provider.icon();
 90        let provider_icon_path = provider.icon_path();
 91        let configuration_view =
 92            provider.configuration_view(ConfigurationViewTargetAgent::EditPrediction, window, cx);
 93
 94        self.extension_oauth_views.insert(
 95            provider_id.clone(),
 96            ExtensionOAuthProviderView {
 97                provider_name,
 98                provider_icon,
 99                provider_icon_path,
100                configuration_view,
101            },
102        );
103    }
104}
105
106impl Render for EditPredictionSetupPage {
107    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
108        let settings_window = self.settings_window.clone();
109
110        let copilot_extension_installed = ExtensionStore::global(cx)
111            .read(cx)
112            .installed_extensions()
113            .contains_key("copilot-chat");
114
115        let mut providers: Vec<AnyElement> = Vec::new();
116
117        // Built-in Copilot (hidden if copilot-chat extension is installed)
118        if !copilot_extension_installed {
119            providers.push(render_github_copilot_provider(window, cx).into_any_element());
120        }
121
122        // Extension providers with OAuth support
123        for (provider_id, view) in &self.extension_oauth_views {
124            let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
125                Icon::from_external_svg(icon_path.clone())
126                    .size(ui::IconSize::Medium)
127                    .into_any_element()
128            } else {
129                Icon::new(view.provider_icon)
130                    .size(ui::IconSize::Medium)
131                    .into_any_element()
132            };
133
134            providers.push(
135                v_flex()
136                    .id(SharedString::from(provider_id.0.to_string()))
137                    .min_w_0()
138                    .gap_1p5()
139                    .child(
140                        h_flex().gap_2().items_center().child(icon_element).child(
141                            Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
142                        ),
143                    )
144                    .child(view.configuration_view.clone())
145                    .into_any_element(),
146            );
147        }
148
149        if cx.has_flag::<Zeta2FeatureFlag>() {
150            providers.push(
151                render_api_key_provider(
152                    IconName::Inception,
153                    "Mercury",
154                    "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
155                    mercury_api_token(cx),
156                    |_cx| MERCURY_CREDENTIALS_URL,
157                    None,
158                    window,
159                    cx,
160                )
161                .into_any_element(),
162            );
163        }
164
165        if cx.has_flag::<Zeta2FeatureFlag>() {
166            providers.push(
167                render_api_key_provider(
168                    IconName::SweepAi,
169                    "Sweep",
170                    "https://app.sweep.dev/".into(),
171                    sweep_api_token(cx),
172                    |_cx| SWEEP_CREDENTIALS_URL,
173                    None,
174                    window,
175                    cx,
176                )
177                .into_any_element(),
178            );
179        }
180
181        providers.push(
182            render_api_key_provider(
183                IconName::AiMistral,
184                "Codestral",
185                "https://console.mistral.ai/codestral".into(),
186                codestral_api_key(cx),
187                |cx| language_models::MistralLanguageModelProvider::api_url(cx),
188                Some(settings_window.update(cx, |settings_window, cx| {
189                    let codestral_settings = codestral_settings();
190                    settings_window
191                        .render_sub_page_items_section(
192                            codestral_settings.iter().enumerate(),
193                            None,
194                            window,
195                            cx,
196                        )
197                        .into_any_element()
198                })),
199                window,
200                cx,
201            )
202            .into_any_element(),
203        );
204
205        div()
206            .size_full()
207            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
208            .child(
209                v_flex()
210                    .id("ep-setup-page")
211                    .min_w_0()
212                    .size_full()
213                    .px_8()
214                    .pb_16()
215                    .overflow_y_scroll()
216                    .track_scroll(&self.scroll_handle)
217                    .children(providers),
218            )
219    }
220}
221
222/// Get extension provider IDs that have OAuth configured.
223fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
224    let extension_store = ExtensionStore::global(cx).read(cx);
225
226    extension_store
227        .installed_extensions()
228        .iter()
229        .flat_map(|(extension_id, entry)| {
230            entry.manifest.language_model_providers.iter().filter_map(
231                move |(provider_id, provider_entry)| {
232                    // Check if this provider has OAuth configured
233                    let has_oauth = provider_entry
234                        .auth
235                        .as_ref()
236                        .is_some_and(|auth| auth.oauth.is_some());
237
238                    if has_oauth {
239                        Some(LanguageModelProviderId(
240                            format!("{}:{}", extension_id, provider_id).into(),
241                        ))
242                    } else {
243                        None
244                    }
245                },
246            )
247        })
248        .collect()
249}
250
251/// Check if a provider ID corresponds to an extension with OAuth configured.
252fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
253    // Extension provider IDs are in the format "extension_id:provider_id"
254    let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
255        return false;
256    };
257
258    let extension_store = ExtensionStore::global(cx).read(cx);
259    let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
260        return false;
261    };
262
263    entry
264        .manifest
265        .language_model_providers
266        .get(local_provider_id)
267        .and_then(|p| p.auth.as_ref())
268        .is_some_and(|auth| auth.oauth.is_some())
269}
270
271fn render_api_key_provider(
272    icon: IconName,
273    title: &'static str,
274    link: SharedString,
275    api_key_state: Entity<ApiKeyState>,
276    current_url: fn(&mut App) -> SharedString,
277    additional_fields: Option<AnyElement>,
278    window: &mut Window,
279    cx: &mut Context<EditPredictionSetupPage>,
280) -> impl IntoElement {
281    let weak_page = cx.weak_entity();
282    _ = window.use_keyed_state(title, cx, |_, cx| {
283        let task = api_key_state.update(cx, |key_state, cx| {
284            key_state.load_if_needed(current_url(cx), |state| state, cx)
285        });
286        cx.spawn(async move |_, cx| {
287            task.await.ok();
288            weak_page
289                .update(cx, |_, cx| {
290                    cx.notify();
291                })
292                .ok();
293        })
294    });
295
296    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
297        (
298            state.has_key(),
299            Some(state.env_var_name().clone()),
300            state.is_from_env_var(),
301        )
302    });
303
304    let write_key = move |api_key: Option<String>, cx: &mut App| {
305        api_key_state
306            .update(cx, |key_state, cx| {
307                let url = current_url(cx);
308                key_state.store(url, api_key, |key_state| key_state, cx)
309            })
310            .detach_and_log_err(cx);
311    };
312
313    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
314    let header = SettingsSectionHeader::new(title)
315        .icon(icon)
316        .no_padding(true);
317    let button_link_label = format!("{} dashboard", title);
318    let description = h_flex()
319        .min_w_0()
320        .gap_0p5()
321        .child(
322            Label::new("Visit the")
323                .size(LabelSize::Small)
324                .color(Color::Muted),
325        )
326        .child(
327            ButtonLink::new(button_link_label, link)
328                .no_icon(true)
329                .label_size(LabelSize::Small)
330                .label_color(Color::Muted),
331        )
332        .child(
333            Label::new("to generate an API key.")
334                .size(LabelSize::Small)
335                .color(Color::Muted),
336        );
337    let configured_card_label = if is_from_env_var {
338        "API Key Set in Environment Variable"
339    } else {
340        "API Key Configured"
341    };
342
343    let container = if has_key {
344        base_container.child(header).child(
345            ConfiguredApiCard::new(configured_card_label)
346                .button_label("Reset Key")
347                .button_tab_index(0)
348                .disabled(is_from_env_var)
349                .when_some(env_var_name, |this, env_var_name| {
350                    this.when(is_from_env_var, |this| {
351                        this.tooltip_label(format!(
352                            "To reset your API key, unset the {} environment variable.",
353                            env_var_name
354                        ))
355                    })
356                })
357                .on_click(move |_, _, cx| {
358                    write_key(None, cx);
359                }),
360        )
361    } else {
362        base_container.child(header).child(
363            h_flex()
364                .pt_2p5()
365                .w_full()
366                .justify_between()
367                .child(
368                    v_flex()
369                        .w_full()
370                        .max_w_1_2()
371                        .child(Label::new("API Key"))
372                        .child(description)
373                        .when_some(env_var_name, |this, env_var_name| {
374                            this.child({
375                                let label = format!(
376                                    "Or set the {} env var and restart Zed.",
377                                    env_var_name.as_ref()
378                                );
379                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
380                            })
381                        }),
382                )
383                .child(
384                    SettingsInputField::new()
385                        .tab_index(0)
386                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
387                        .on_confirm(move |api_key, cx| {
388                            write_key(api_key.filter(|key| !key.is_empty()), cx);
389                        }),
390                ),
391        )
392    };
393
394    container.when_some(additional_fields, |this, additional_fields| {
395        this.child(
396            div()
397                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
398                .px_neg_8()
399                .border_t_1()
400                .border_color(cx.theme().colors().border_variant)
401                .child(additional_fields),
402        )
403    })
404}
405
/// Builds the extra settings items rendered beneath the Codestral API key section: API URL,
/// max tokens, and model.
///
/// Every item targets `edit_predictions.codestral.*` in the user settings file (`files: USER`).
/// Each `pick` returns `None` when `edit_predictions` or `codestral` is absent. Each `write`
/// creates those intermediate objects with their defaults before assigning the value.
fn codestral_settings() -> Box<[SettingsPageItem]> {
    Box::new([
        // `edit_predictions.codestral.api_url`; the placeholder shows `CODESTRAL_API_URL`
        // (presumably the endpoint used when the setting is unset).
        SettingsPageItem::SettingItem(SettingItem {
            title: "API URL",
            description: "The API URL to use for Codestral.",
            field: Box::new(SettingField {
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .api_url
                        .as_ref()
                },
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .api_url = value;
                },
                json_path: Some("edit_predictions.codestral.api_url"),
            }),
            metadata: Some(Box::new(SettingsFieldMetadata {
                placeholder: Some(CODESTRAL_API_URL),
                ..Default::default()
            })),
            files: USER,
        }),
        // `edit_predictions.codestral.max_tokens`; no placeholder metadata.
        SettingsPageItem::SettingItem(SettingItem {
            title: "Max Tokens",
            description: "The maximum number of tokens to generate.",
            field: Box::new(SettingField {
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .max_tokens
                        .as_ref()
                },
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .max_tokens = value;
                },
                json_path: Some("edit_predictions.codestral.max_tokens"),
            }),
            metadata: None,
            files: USER,
        }),
        // `edit_predictions.codestral.model`; the placeholder suggests "codestral-latest".
        SettingsPageItem::SettingItem(SettingItem {
            title: "Model",
            description: "The Codestral model id to use.",
            field: Box::new(SettingField {
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .model
                        .as_ref()
                },
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .model = value;
                },
                json_path: Some("edit_predictions.codestral.model"),
            }),
            metadata: Some(Box::new(SettingsFieldMetadata {
                placeholder: Some("codestral-latest"),
                ..Default::default()
            })),
            files: USER,
        }),
    ])
}
506
507pub(crate) fn render_github_copilot_provider(
508    window: &mut Window,
509    cx: &mut App,
510) -> impl IntoElement {
511    let configuration_view = window.use_state(cx, |_, cx| {
512        copilot::ConfigurationView::new(
513            |cx| {
514                copilot::Copilot::global(cx)
515                    .is_some_and(|copilot| copilot.read(cx).is_authenticated())
516            },
517            copilot::ConfigurationMode::EditPrediction,
518            cx,
519        )
520    });
521
522    v_flex()
523        .id("github-copilot")
524        .min_w_0()
525        .gap_1p5()
526        .child(
527            SettingsSectionHeader::new("GitHub Copilot")
528                .icon(IconName::Copilot)
529                .no_padding(true),
530        )
531        .child(configuration_view)
532}