1use edit_prediction::{
2 ApiKeyState, Zeta2FeatureFlag,
3 mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
4 sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
5};
6use extension_host::ExtensionStore;
7use feature_flags::FeatureFlagAppExt as _;
8use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
9use language_model::{
10 ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
11};
12use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
13use std::collections::HashMap;
14use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
15
16use crate::{
17 SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
18 components::{SettingsInputField, SettingsSectionHeader},
19};
20
/// Settings page listing the edit prediction providers the user can configure: GitHub Copilot,
/// extension-provided providers with OAuth, Mercury and Sweep (behind a feature flag), and
/// Codestral.
pub struct EditPredictionSetupPage {
    /// Owning settings window; used to render the Codestral settings sub-items.
    settings_window: Entity<SettingsWindow>,
    /// Scroll state for the page's vertically scrolling provider list.
    scroll_handle: ScrollHandle,
    /// Configuration views for extension-provided language model providers that have OAuth
    /// configured, keyed by registry provider id (`"extension_id:provider_id"`).
    extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
    /// Subscription to `LanguageModelRegistry` events that keeps `extension_oauth_views` in sync
    /// as providers are added or removed. Held only to keep the subscription alive
    /// (presumably dropping it unsubscribes).
    _registry_subscription: Subscription,
}
27
/// Data needed to render one extension-provided language model provider that has OAuth
/// configured in its extension manifest.
struct ExtensionOAuthProviderView {
    /// Display name shown next to the icon in the section heading.
    provider_name: SharedString,
    /// Built-in icon, used when `provider_icon_path` is `None`.
    provider_icon: IconName,
    /// Optional external SVG icon supplied by the provider; preferred over `provider_icon`.
    provider_icon_path: Option<SharedString>,
    /// The provider's own configuration view, created for
    /// `ConfigurationViewTargetAgent::EditPrediction`.
    configuration_view: AnyView,
}
34
35impl EditPredictionSetupPage {
36 pub fn new(
37 settings_window: Entity<SettingsWindow>,
38 window: &mut Window,
39 cx: &mut Context<Self>,
40 ) -> Self {
41 let registry_subscription = cx.subscribe_in(
42 &LanguageModelRegistry::global(cx),
43 window,
44 |this, _, event: &language_model::Event, window, cx| match event {
45 language_model::Event::AddedProvider(provider_id) => {
46 this.maybe_add_extension_oauth_view(provider_id, window, cx);
47 }
48 language_model::Event::RemovedProvider(provider_id) => {
49 this.extension_oauth_views.remove(provider_id);
50 }
51 _ => {}
52 },
53 );
54
55 let mut this = Self {
56 settings_window,
57 scroll_handle: ScrollHandle::new(),
58 extension_oauth_views: HashMap::default(),
59 _registry_subscription: registry_subscription,
60 };
61 this.build_extension_oauth_views(window, cx);
62 this
63 }
64
65 fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
66 let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
67 for provider_id in oauth_provider_ids {
68 self.maybe_add_extension_oauth_view(&provider_id, window, cx);
69 }
70 }
71
72 fn maybe_add_extension_oauth_view(
73 &mut self,
74 provider_id: &LanguageModelProviderId,
75 window: &mut Window,
76 cx: &mut Context<Self>,
77 ) {
78 // Check if this provider has OAuth configured in the extension manifest
79 if !is_extension_oauth_provider(provider_id, cx) {
80 return;
81 }
82
83 let registry = LanguageModelRegistry::global(cx).read(cx);
84 let Some(provider) = registry.provider(provider_id) else {
85 return;
86 };
87
88 let provider_name = provider.name().0;
89 let provider_icon = provider.icon();
90 let provider_icon_path = provider.icon_path();
91 let configuration_view =
92 provider.configuration_view(ConfigurationViewTargetAgent::EditPrediction, window, cx);
93
94 self.extension_oauth_views.insert(
95 provider_id.clone(),
96 ExtensionOAuthProviderView {
97 provider_name,
98 provider_icon,
99 provider_icon_path,
100 configuration_view,
101 },
102 );
103 }
104}
105
106impl Render for EditPredictionSetupPage {
107 fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
108 let settings_window = self.settings_window.clone();
109
110 let copilot_extension_installed = ExtensionStore::global(cx)
111 .read(cx)
112 .installed_extensions()
113 .contains_key("copilot-chat");
114
115 let mut providers: Vec<AnyElement> = Vec::new();
116
117 // Built-in Copilot (hidden if copilot-chat extension is installed)
118 if !copilot_extension_installed {
119 providers.push(render_github_copilot_provider(window, cx).into_any_element());
120 }
121
122 // Extension providers with OAuth support
123 for (provider_id, view) in &self.extension_oauth_views {
124 let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
125 Icon::from_external_svg(icon_path.clone())
126 .size(ui::IconSize::Medium)
127 .into_any_element()
128 } else {
129 Icon::new(view.provider_icon)
130 .size(ui::IconSize::Medium)
131 .into_any_element()
132 };
133
134 providers.push(
135 v_flex()
136 .id(SharedString::from(provider_id.0.to_string()))
137 .min_w_0()
138 .gap_1p5()
139 .child(
140 h_flex().gap_2().items_center().child(icon_element).child(
141 Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
142 ),
143 )
144 .child(view.configuration_view.clone())
145 .into_any_element(),
146 );
147 }
148
149 if cx.has_flag::<Zeta2FeatureFlag>() {
150 providers.push(
151 render_api_key_provider(
152 IconName::Inception,
153 "Mercury",
154 "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
155 mercury_api_token(cx),
156 |_cx| MERCURY_CREDENTIALS_URL,
157 None,
158 window,
159 cx,
160 )
161 .into_any_element(),
162 );
163 }
164
165 if cx.has_flag::<Zeta2FeatureFlag>() {
166 providers.push(
167 render_api_key_provider(
168 IconName::SweepAi,
169 "Sweep",
170 "https://app.sweep.dev/".into(),
171 sweep_api_token(cx),
172 |_cx| SWEEP_CREDENTIALS_URL,
173 None,
174 window,
175 cx,
176 )
177 .into_any_element(),
178 );
179 }
180
181 providers.push(
182 render_api_key_provider(
183 IconName::AiMistral,
184 "Codestral",
185 "https://console.mistral.ai/codestral".into(),
186 codestral_api_key(cx),
187 |cx| language_models::MistralLanguageModelProvider::api_url(cx),
188 Some(settings_window.update(cx, |settings_window, cx| {
189 let codestral_settings = codestral_settings();
190 settings_window
191 .render_sub_page_items_section(
192 codestral_settings.iter().enumerate(),
193 None,
194 window,
195 cx,
196 )
197 .into_any_element()
198 })),
199 window,
200 cx,
201 )
202 .into_any_element(),
203 );
204
205 div()
206 .size_full()
207 .vertical_scrollbar_for(&self.scroll_handle, window, cx)
208 .child(
209 v_flex()
210 .id("ep-setup-page")
211 .min_w_0()
212 .size_full()
213 .px_8()
214 .pb_16()
215 .overflow_y_scroll()
216 .track_scroll(&self.scroll_handle)
217 .children(providers),
218 )
219 }
220}
221
222/// Get extension provider IDs that have OAuth configured.
223fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
224 let extension_store = ExtensionStore::global(cx).read(cx);
225
226 extension_store
227 .installed_extensions()
228 .iter()
229 .flat_map(|(extension_id, entry)| {
230 entry.manifest.language_model_providers.iter().filter_map(
231 move |(provider_id, provider_entry)| {
232 // Check if this provider has OAuth configured
233 let has_oauth = provider_entry
234 .auth
235 .as_ref()
236 .is_some_and(|auth| auth.oauth.is_some());
237
238 if has_oauth {
239 Some(LanguageModelProviderId(
240 format!("{}:{}", extension_id, provider_id).into(),
241 ))
242 } else {
243 None
244 }
245 },
246 )
247 })
248 .collect()
249}
250
251/// Check if a provider ID corresponds to an extension with OAuth configured.
252fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
253 // Extension provider IDs are in the format "extension_id:provider_id"
254 let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
255 return false;
256 };
257
258 let extension_store = ExtensionStore::global(cx).read(cx);
259 let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
260 return false;
261 };
262
263 entry
264 .manifest
265 .language_model_providers
266 .get(local_provider_id)
267 .and_then(|p| p.auth.as_ref())
268 .is_some_and(|auth| auth.oauth.is_some())
269}
270
271fn render_api_key_provider(
272 icon: IconName,
273 title: &'static str,
274 link: SharedString,
275 api_key_state: Entity<ApiKeyState>,
276 current_url: fn(&mut App) -> SharedString,
277 additional_fields: Option<AnyElement>,
278 window: &mut Window,
279 cx: &mut Context<EditPredictionSetupPage>,
280) -> impl IntoElement {
281 let weak_page = cx.weak_entity();
282 _ = window.use_keyed_state(title, cx, |_, cx| {
283 let task = api_key_state.update(cx, |key_state, cx| {
284 key_state.load_if_needed(current_url(cx), |state| state, cx)
285 });
286 cx.spawn(async move |_, cx| {
287 task.await.ok();
288 weak_page
289 .update(cx, |_, cx| {
290 cx.notify();
291 })
292 .ok();
293 })
294 });
295
296 let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
297 (
298 state.has_key(),
299 Some(state.env_var_name().clone()),
300 state.is_from_env_var(),
301 )
302 });
303
304 let write_key = move |api_key: Option<String>, cx: &mut App| {
305 api_key_state
306 .update(cx, |key_state, cx| {
307 let url = current_url(cx);
308 key_state.store(url, api_key, |key_state| key_state, cx)
309 })
310 .detach_and_log_err(cx);
311 };
312
313 let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
314 let header = SettingsSectionHeader::new(title)
315 .icon(icon)
316 .no_padding(true);
317 let button_link_label = format!("{} dashboard", title);
318 let description = h_flex()
319 .min_w_0()
320 .gap_0p5()
321 .child(
322 Label::new("Visit the")
323 .size(LabelSize::Small)
324 .color(Color::Muted),
325 )
326 .child(
327 ButtonLink::new(button_link_label, link)
328 .no_icon(true)
329 .label_size(LabelSize::Small)
330 .label_color(Color::Muted),
331 )
332 .child(
333 Label::new("to generate an API key.")
334 .size(LabelSize::Small)
335 .color(Color::Muted),
336 );
337 let configured_card_label = if is_from_env_var {
338 "API Key Set in Environment Variable"
339 } else {
340 "API Key Configured"
341 };
342
343 let container = if has_key {
344 base_container.child(header).child(
345 ConfiguredApiCard::new(configured_card_label)
346 .button_label("Reset Key")
347 .button_tab_index(0)
348 .disabled(is_from_env_var)
349 .when_some(env_var_name, |this, env_var_name| {
350 this.when(is_from_env_var, |this| {
351 this.tooltip_label(format!(
352 "To reset your API key, unset the {} environment variable.",
353 env_var_name
354 ))
355 })
356 })
357 .on_click(move |_, _, cx| {
358 write_key(None, cx);
359 }),
360 )
361 } else {
362 base_container.child(header).child(
363 h_flex()
364 .pt_2p5()
365 .w_full()
366 .justify_between()
367 .child(
368 v_flex()
369 .w_full()
370 .max_w_1_2()
371 .child(Label::new("API Key"))
372 .child(description)
373 .when_some(env_var_name, |this, env_var_name| {
374 this.child({
375 let label = format!(
376 "Or set the {} env var and restart Zed.",
377 env_var_name.as_ref()
378 );
379 Label::new(label).size(LabelSize::Small).color(Color::Muted)
380 })
381 }),
382 )
383 .child(
384 SettingsInputField::new()
385 .tab_index(0)
386 .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
387 .on_confirm(move |api_key, cx| {
388 write_key(api_key.filter(|key| !key.is_empty()), cx);
389 }),
390 ),
391 )
392 };
393
394 container.when_some(additional_fields, |this, additional_fields| {
395 this.child(
396 div()
397 .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
398 .px_neg_8()
399 .border_t_1()
400 .border_color(cx.theme().colors().border_variant)
401 .child(additional_fields),
402 )
403 })
404}
405
/// Builds the setting items rendered beneath the Codestral API key section: API URL,
/// Max Tokens, and Model.
///
/// Each item reads from and writes to `edit_predictions.codestral.<field>` (see each
/// `json_path`). `files: USER` presumably limits these items to the user settings file —
/// NOTE(review): confirm against the `USER` constant's definition.
fn codestral_settings() -> Box<[SettingsPageItem]> {
    Box::new([
        SettingsPageItem::SettingItem(SettingItem {
            title: "API URL",
            description: "The API URL to use for Codestral.",
            field: Box::new(SettingField {
                // `None` when `edit_predictions`, `codestral`, or the field itself is unset.
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .api_url
                        .as_ref()
                },
                // Creates any missing parent objects with their defaults before assigning.
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .api_url = value;
                },
                json_path: Some("edit_predictions.codestral.api_url"),
            }),
            // Placeholder text shows `CODESTRAL_API_URL` when no URL is set.
            metadata: Some(Box::new(SettingsFieldMetadata {
                placeholder: Some(CODESTRAL_API_URL),
                ..Default::default()
            })),
            files: USER,
        }),
        SettingsPageItem::SettingItem(SettingItem {
            title: "Max Tokens",
            description: "The maximum number of tokens to generate.",
            field: Box::new(SettingField {
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .max_tokens
                        .as_ref()
                },
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .max_tokens = value;
                },
                json_path: Some("edit_predictions.codestral.max_tokens"),
            }),
            metadata: None,
            files: USER,
        }),
        SettingsPageItem::SettingItem(SettingItem {
            title: "Model",
            description: "The Codestral model id to use.",
            field: Box::new(SettingField {
                pick: |settings| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .as_ref()?
                        .codestral
                        .as_ref()?
                        .model
                        .as_ref()
                },
                write: |settings, value| {
                    settings
                        .project
                        .all_languages
                        .edit_predictions
                        .get_or_insert_default()
                        .codestral
                        .get_or_insert_default()
                        .model = value;
                },
                json_path: Some("edit_predictions.codestral.model"),
            }),
            // Placeholder text shows "codestral-latest" when no model is set.
            metadata: Some(Box::new(SettingsFieldMetadata {
                placeholder: Some("codestral-latest"),
                ..Default::default()
            })),
            files: USER,
        }),
    ])
}
506
507pub(crate) fn render_github_copilot_provider(
508 window: &mut Window,
509 cx: &mut App,
510) -> impl IntoElement {
511 let configuration_view = window.use_state(cx, |_, cx| {
512 copilot::ConfigurationView::new(
513 |cx| {
514 copilot::Copilot::global(cx)
515 .is_some_and(|copilot| copilot.read(cx).is_authenticated())
516 },
517 copilot::ConfigurationMode::EditPrediction,
518 cx,
519 )
520 });
521
522 v_flex()
523 .id("github-copilot")
524 .min_w_0()
525 .gap_1p5()
526 .child(
527 SettingsSectionHeader::new("GitHub Copilot")
528 .icon(IconName::Copilot)
529 .no_padding(true),
530 )
531 .child(configuration_view)
532}