1use edit_prediction::{
2 ApiKeyState, Zeta2FeatureFlag,
3 mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
4 sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
5};
6use extension_host::ExtensionStore;
7use feature_flags::FeatureFlagAppExt as _;
8use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
9use language_model::{
10 ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
11};
12use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
13use std::collections::HashMap;
14use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
15
16use crate::{
17 SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
18 components::{SettingsInputField, SettingsSectionHeader},
19};
20
/// Settings page that lists the edit prediction providers a user can configure.
pub struct EditPredictionSetupPage {
    /// Handle to the settings window; `render` uses it to build the Codestral
    /// sub-page items.
    settings_window: Entity<SettingsWindow>,
    /// Scroll handle shared by the provider list and its vertical scrollbar.
    scroll_handle: ScrollHandle,
    /// Views for extension language model providers that declare OAuth in their
    /// manifest, keyed by provider id.
    extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
    /// Subscription to `LanguageModelRegistry` events, held for the lifetime of
    /// the page (presumably dropping it would cancel the subscription).
    _registry_subscription: Subscription,
}
27
/// Everything `render` needs to show one extension provider that has OAuth configured.
struct ExtensionOAuthProviderView {
    /// Display name shown as the section headline.
    provider_name: SharedString,
    /// Built-in icon, used when `provider_icon_path` is `None`.
    provider_icon: IconName,
    /// Path to an external SVG icon; takes precedence over `provider_icon` when set.
    provider_icon_path: Option<SharedString>,
    /// The provider's own configuration view, rendered below the headline.
    configuration_view: AnyView,
}
34
35impl EditPredictionSetupPage {
36 pub fn new(
37 settings_window: Entity<SettingsWindow>,
38 window: &mut Window,
39 cx: &mut Context<Self>,
40 ) -> Self {
41 let registry_subscription = cx.subscribe_in(
42 &LanguageModelRegistry::global(cx),
43 window,
44 |this, _, event: &language_model::Event, window, cx| match event {
45 language_model::Event::AddedProvider(provider_id) => {
46 this.maybe_add_extension_oauth_view(provider_id, window, cx);
47 }
48 language_model::Event::RemovedProvider(provider_id) => {
49 this.extension_oauth_views.remove(provider_id);
50 }
51 _ => {}
52 },
53 );
54
55 let mut this = Self {
56 settings_window,
57 scroll_handle: ScrollHandle::new(),
58 extension_oauth_views: HashMap::default(),
59 _registry_subscription: registry_subscription,
60 };
61 this.build_extension_oauth_views(window, cx);
62 this
63 }
64
65 fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
66 let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
67 for provider_id in oauth_provider_ids {
68 self.maybe_add_extension_oauth_view(&provider_id, window, cx);
69 }
70 }
71
72 fn maybe_add_extension_oauth_view(
73 &mut self,
74 provider_id: &LanguageModelProviderId,
75 window: &mut Window,
76 cx: &mut Context<Self>,
77 ) {
78 // Check if this provider has OAuth configured in the extension manifest
79 if !is_extension_oauth_provider(provider_id, cx) {
80 return;
81 }
82
83 let registry = LanguageModelRegistry::global(cx).read(cx);
84 let Some(provider) = registry.provider(provider_id) else {
85 return;
86 };
87
88 let provider_name = provider.name().0.clone();
89 let provider_icon = provider.icon();
90 let provider_icon_path = provider.icon_path();
91 let configuration_view =
92 provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx);
93
94 self.extension_oauth_views.insert(
95 provider_id.clone(),
96 ExtensionOAuthProviderView {
97 provider_name,
98 provider_icon,
99 provider_icon_path,
100 configuration_view,
101 },
102 );
103 }
104}
105
106impl Render for EditPredictionSetupPage {
107 fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
108 let settings_window = self.settings_window.clone();
109
110 let copilot_extension_installed = ExtensionStore::global(cx)
111 .read(cx)
112 .installed_extensions()
113 .contains_key("copilot-chat");
114
115 let mut providers: Vec<AnyElement> = Vec::new();
116
117 // Built-in Copilot (hidden if copilot-chat extension is installed)
118 if !copilot_extension_installed {
119 providers.push(render_github_copilot_provider(window, cx).into_any_element());
120 }
121
122 // Extension providers with OAuth support
123 for (provider_id, view) in &self.extension_oauth_views {
124 let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
125 Icon::from_external_svg(icon_path.clone())
126 .size(ui::IconSize::Medium)
127 .into_any_element()
128 } else {
129 Icon::new(view.provider_icon)
130 .size(ui::IconSize::Medium)
131 .into_any_element()
132 };
133
134 providers.push(
135 v_flex()
136 .id(SharedString::from(provider_id.0.to_string()))
137 .min_w_0()
138 .gap_1p5()
139 .child(
140 h_flex().gap_2().items_center().child(icon_element).child(
141 Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
142 ),
143 )
144 .child(view.configuration_view.clone())
145 .into_any_element(),
146 );
147 }
148
149 // Mercury (feature flagged)
150 if cx.has_flag::<Zeta2FeatureFlag>() {
151 providers.push(
152 render_api_key_provider(
153 IconName::Inception,
154 "Mercury",
155 "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
156 mercury_api_token(cx),
157 |_cx| MERCURY_CREDENTIALS_URL,
158 None,
159 window,
160 cx,
161 )
162 .into_any_element(),
163 );
164 }
165
166 // Sweep (feature flagged)
167 if cx.has_flag::<Zeta2FeatureFlag>() {
168 providers.push(
169 render_api_key_provider(
170 IconName::SweepAi,
171 "Sweep",
172 "https://app.sweep.dev/".into(),
173 sweep_api_token(cx),
174 |_cx| SWEEP_CREDENTIALS_URL,
175 None,
176 window,
177 cx,
178 )
179 .into_any_element(),
180 );
181 }
182
183 // Codestral
184 providers.push(
185 render_api_key_provider(
186 IconName::AiMistral,
187 "Codestral",
188 "https://console.mistral.ai/codestral".into(),
189 codestral_api_key(cx),
190 |cx| language_models::MistralLanguageModelProvider::api_url(cx),
191 Some(settings_window.update(cx, |settings_window, cx| {
192 let codestral_settings = codestral_settings();
193 settings_window
194 .render_sub_page_items_section(
195 codestral_settings.iter().enumerate(),
196 None,
197 window,
198 cx,
199 )
200 .into_any_element()
201 })),
202 window,
203 cx,
204 )
205 .into_any_element(),
206 );
207
208 div()
209 .size_full()
210 .vertical_scrollbar_for(&self.scroll_handle, window, cx)
211 .child(
212 v_flex()
213 .id("ep-setup-page")
214 .min_w_0()
215 .size_full()
216 .px_8()
217 .pb_16()
218 .overflow_y_scroll()
219 .track_scroll(&self.scroll_handle)
220 .children(providers),
221 )
222 }
223}
224
225/// Get extension provider IDs that have OAuth configured.
226fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
227 let extension_store = ExtensionStore::global(cx).read(cx);
228
229 extension_store
230 .installed_extensions()
231 .iter()
232 .flat_map(|(extension_id, entry)| {
233 entry.manifest.language_model_providers.iter().filter_map(
234 move |(provider_id, provider_entry)| {
235 // Check if this provider has OAuth configured
236 let has_oauth = provider_entry
237 .auth
238 .as_ref()
239 .is_some_and(|auth| auth.oauth.is_some());
240
241 if has_oauth {
242 Some(LanguageModelProviderId(
243 format!("{}:{}", extension_id, provider_id).into(),
244 ))
245 } else {
246 None
247 }
248 },
249 )
250 })
251 .collect()
252}
253
254/// Check if a provider ID corresponds to an extension with OAuth configured.
255fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
256 // Extension provider IDs are in the format "extension_id:provider_id"
257 let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
258 return false;
259 };
260
261 let extension_store = ExtensionStore::global(cx).read(cx);
262 let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
263 return false;
264 };
265
266 entry
267 .manifest
268 .language_model_providers
269 .get(local_provider_id)
270 .and_then(|p| p.auth.as_ref())
271 .is_some_and(|auth| auth.oauth.is_some())
272}
273
/// Renders the configuration section for a provider that authenticates with an API key.
///
/// When a key is present the section shows a "configured" card with a "Reset Key"
/// button; otherwise it shows an input field for entering the key together with
/// a link to the provider's dashboard.
///
/// - `icon`, `title`: shown in the section header. `title` is also used as the
///   element id, as the key for the keyed state below, and in the
///   "`{title}` dashboard" link label.
/// - `link`: target of the dashboard link (only shown when no key is set).
/// - `api_key_state`: read to decide which UI to show; updated to load, store,
///   or clear the key.
/// - `current_url`: produces the URL passed to `load_if_needed` and `store`.
/// - `additional_fields`: optional element appended below the section, under a
///   top border.
fn render_api_key_provider(
    icon: IconName,
    title: &'static str,
    link: SharedString,
    api_key_state: Entity<ApiKeyState>,
    current_url: fn(&mut App) -> SharedString,
    additional_fields: Option<AnyElement>,
    window: &mut Window,
    cx: &mut Context<EditPredictionSetupPage>,
) -> impl IntoElement {
    let weak_page = cx.weak_entity();
    // State keyed by `title` whose initializer starts loading the key and, once
    // the load task finishes, notifies the page so it re-renders. The returned
    // handle is intentionally discarded.
    // NOTE(review): presumably the initializer only runs the first time this key
    // is seen — confirm against `use_keyed_state`.
    _ = window.use_keyed_state(title, cx, |_, cx| {
        let task = api_key_state.update(cx, |key_state, cx| {
            key_state.load_if_needed(current_url(cx), |state| state, cx)
        });
        cx.spawn(async move |_, cx| {
            // Load errors are ignored; the page is notified either way.
            task.await.ok();
            weak_page
                .update(cx, |_, cx| {
                    cx.notify();
                })
                .ok();
        })
    });

    // `env_var_name` is always `Some` here, so the `when_some` calls below always run.
    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
        (
            state.has_key(),
            Some(state.env_var_name().clone()),
            state.is_from_env_var(),
        )
    });

    // Stores (`Some`) or clears (`None`) the key for the current URL; errors
    // from the detached task are logged.
    let write_key = move |api_key: Option<String>, cx: &mut App| {
        api_key_state
            .update(cx, |key_state, cx| {
                let url = current_url(cx);
                key_state.store(url, api_key, |key_state| key_state, cx)
            })
            .detach_and_log_err(cx);
    };

    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
    let header = SettingsSectionHeader::new(title)
        .icon(icon)
        .no_padding(true);
    let button_link_label = format!("{} dashboard", title);
    // "Visit the <title> dashboard to generate an API key."
    let description = h_flex()
        .min_w_0()
        .gap_0p5()
        .child(
            Label::new("Visit the")
                .size(LabelSize::Small)
                .color(Color::Muted),
        )
        .child(
            ButtonLink::new(button_link_label, link)
                .no_icon(true)
                .label_size(LabelSize::Small)
                .label_color(Color::Muted),
        )
        .child(
            Label::new("to generate an API key.")
                .size(LabelSize::Small)
                .color(Color::Muted),
        );
    let configured_card_label = if is_from_env_var {
        "API Key Set in Environment Variable"
    } else {
        "API Key Configured"
    };

    let container = if has_key {
        // Key present: show the configured card. Reset is disabled when the key
        // comes from the environment, with a tooltip explaining how to unset it.
        base_container.child(header).child(
            ConfiguredApiCard::new(configured_card_label)
                .button_label("Reset Key")
                .button_tab_index(0)
                .disabled(is_from_env_var)
                .when_some(env_var_name, |this, env_var_name| {
                    this.when(is_from_env_var, |this| {
                        this.tooltip_label(format!(
                            "To reset your API key, unset the {} environment variable.",
                            env_var_name
                        ))
                    })
                })
                .on_click(move |_, _, cx| {
                    write_key(None, cx);
                }),
        )
    } else {
        // No key yet: show instructions on the left and the input field on the right.
        base_container.child(header).child(
            h_flex()
                .pt_2p5()
                .w_full()
                .justify_between()
                .child(
                    v_flex()
                        .w_full()
                        .max_w_1_2()
                        .child(Label::new("API Key"))
                        .child(description)
                        .when_some(env_var_name, |this, env_var_name| {
                            this.child({
                                let label = format!(
                                    "Or set the {} env var and restart Zed.",
                                    env_var_name.as_ref()
                                );
                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
                            })
                        }),
                )
                .child(
                    SettingsInputField::new()
                        .tab_index(0)
                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
                        .on_confirm(move |api_key, cx| {
                            // An empty submission is treated as "no key".
                            write_key(api_key.filter(|key| !key.is_empty()), cx);
                        }),
                ),
        )
    };

    // Optional extra fields go below a top border; the top margin depends on
    // which of the two layouts above was used.
    container.when_some(additional_fields, |this, additional_fields| {
        this.child(
            div()
                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
                .px_neg_8()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .child(additional_fields),
        )
    })
}
408
409fn codestral_settings() -> Box<[SettingsPageItem]> {
410 Box::new([
411 SettingsPageItem::SettingItem(SettingItem {
412 title: "API URL",
413 description: "The API URL to use for Codestral.",
414 field: Box::new(SettingField {
415 pick: |settings| {
416 settings
417 .project
418 .all_languages
419 .edit_predictions
420 .as_ref()?
421 .codestral
422 .as_ref()?
423 .api_url
424 .as_ref()
425 },
426 write: |settings, value| {
427 settings
428 .project
429 .all_languages
430 .edit_predictions
431 .get_or_insert_default()
432 .codestral
433 .get_or_insert_default()
434 .api_url = value;
435 },
436 json_path: Some("edit_predictions.codestral.api_url"),
437 }),
438 metadata: Some(Box::new(SettingsFieldMetadata {
439 placeholder: Some(CODESTRAL_API_URL),
440 ..Default::default()
441 })),
442 files: USER,
443 }),
444 SettingsPageItem::SettingItem(SettingItem {
445 title: "Max Tokens",
446 description: "The maximum number of tokens to generate.",
447 field: Box::new(SettingField {
448 pick: |settings| {
449 settings
450 .project
451 .all_languages
452 .edit_predictions
453 .as_ref()?
454 .codestral
455 .as_ref()?
456 .max_tokens
457 .as_ref()
458 },
459 write: |settings, value| {
460 settings
461 .project
462 .all_languages
463 .edit_predictions
464 .get_or_insert_default()
465 .codestral
466 .get_or_insert_default()
467 .max_tokens = value;
468 },
469 json_path: Some("edit_predictions.codestral.max_tokens"),
470 }),
471 metadata: None,
472 files: USER,
473 }),
474 SettingsPageItem::SettingItem(SettingItem {
475 title: "Model",
476 description: "The Codestral model id to use.",
477 field: Box::new(SettingField {
478 pick: |settings| {
479 settings
480 .project
481 .all_languages
482 .edit_predictions
483 .as_ref()?
484 .codestral
485 .as_ref()?
486 .model
487 .as_ref()
488 },
489 write: |settings, value| {
490 settings
491 .project
492 .all_languages
493 .edit_predictions
494 .get_or_insert_default()
495 .codestral
496 .get_or_insert_default()
497 .model = value;
498 },
499 json_path: Some("edit_predictions.codestral.model"),
500 }),
501 metadata: Some(Box::new(SettingsFieldMetadata {
502 placeholder: Some("codestral-latest"),
503 ..Default::default()
504 })),
505 files: USER,
506 }),
507 ])
508}
509
510pub(crate) fn render_github_copilot_provider(
511 window: &mut Window,
512 cx: &mut App,
513) -> impl IntoElement {
514 let configuration_view = window.use_state(cx, |_, cx| {
515 copilot::ConfigurationView::new(
516 |cx| {
517 copilot::Copilot::global(cx)
518 .is_some_and(|copilot| copilot.read(cx).is_authenticated())
519 },
520 copilot::ConfigurationMode::EditPrediction,
521 cx,
522 )
523 });
524
525 v_flex()
526 .id("github-copilot")
527 .min_w_0()
528 .gap_1p5()
529 .child(
530 SettingsSectionHeader::new("GitHub Copilot")
531 .icon(IconName::Copilot)
532 .no_padding(true),
533 )
534 .child(configuration_view)
535}