edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, Zeta2FeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use feature_flags::FeatureFlagAppExt as _;
  7use gpui::{Entity, ScrollHandle, prelude::*};
  8use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
  9use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
 10
 11const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 12const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 13
 14use crate::{
 15    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 16    components::{SettingsInputField, SettingsSectionHeader},
 17};
 18
 19pub struct EditPredictionSetupPage {
 20    settings_window: Entity<SettingsWindow>,
 21    scroll_handle: ScrollHandle,
 22}
 23
 24impl EditPredictionSetupPage {
 25    pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
 26        Self {
 27            settings_window,
 28            scroll_handle: ScrollHandle::new(),
 29        }
 30    }
 31}
 32
 33impl Render for EditPredictionSetupPage {
 34    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 35        let settings_window = self.settings_window.clone();
 36
 37        let providers = [
 38            Some(render_github_copilot_provider(window, cx).into_any_element()),
 39            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
 40                render_api_key_provider(
 41                    IconName::Inception,
 42                    "Mercury",
 43                    "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 44                    mercury_api_token(cx),
 45                    |_cx| MERCURY_CREDENTIALS_URL,
 46                    None,
 47                    window,
 48                    cx,
 49                )
 50                .into_any_element()
 51            }),
 52            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
 53                render_api_key_provider(
 54                    IconName::SweepAi,
 55                    "Sweep",
 56                    "https://app.sweep.dev/".into(),
 57                    sweep_api_token(cx),
 58                    |_cx| SWEEP_CREDENTIALS_URL,
 59                    None,
 60                    window,
 61                    cx,
 62                )
 63                .into_any_element()
 64            }),
 65            Some(
 66                render_api_key_provider(
 67                    IconName::AiMistral,
 68                    "Codestral",
 69                    "https://console.mistral.ai/codestral".into(),
 70                    codestral_api_key(cx),
 71                    |cx| language_models::MistralLanguageModelProvider::api_url(cx),
 72                    Some(settings_window.update(cx, |settings_window, cx| {
 73                        let codestral_settings = codestral_settings();
 74                        settings_window
 75                            .render_sub_page_items_section(
 76                                codestral_settings.iter().enumerate(),
 77                                None,
 78                                window,
 79                                cx,
 80                            )
 81                            .into_any_element()
 82                    })),
 83                    window,
 84                    cx,
 85                )
 86                .into_any_element(),
 87            ),
 88            Some(render_ollama_provider(settings_window.clone(), window, cx).into_any_element()),
 89        ];
 90
 91        div()
 92            .size_full()
 93            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
 94            .child(
 95                v_flex()
 96                    .id("ep-setup-page")
 97                    .min_w_0()
 98                    .size_full()
 99                    .px_8()
100                    .pb_16()
101                    .overflow_y_scroll()
102                    .track_scroll(&self.scroll_handle)
103                    .children(providers.into_iter().flatten()),
104            )
105    }
106}
107
108fn render_api_key_provider(
109    icon: IconName,
110    title: &'static str,
111    link: SharedString,
112    api_key_state: Entity<ApiKeyState>,
113    current_url: fn(&mut App) -> SharedString,
114    additional_fields: Option<AnyElement>,
115    window: &mut Window,
116    cx: &mut Context<EditPredictionSetupPage>,
117) -> impl IntoElement {
118    let weak_page = cx.weak_entity();
119    _ = window.use_keyed_state(title, cx, |_, cx| {
120        let task = api_key_state.update(cx, |key_state, cx| {
121            key_state.load_if_needed(current_url(cx), |state| state, cx)
122        });
123        cx.spawn(async move |_, cx| {
124            task.await.ok();
125            weak_page
126                .update(cx, |_, cx| {
127                    cx.notify();
128                })
129                .ok();
130        })
131    });
132
133    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
134        (
135            state.has_key(),
136            Some(state.env_var_name().clone()),
137            state.is_from_env_var(),
138        )
139    });
140
141    let write_key = move |api_key: Option<String>, cx: &mut App| {
142        api_key_state
143            .update(cx, |key_state, cx| {
144                let url = current_url(cx);
145                key_state.store(url, api_key, |key_state| key_state, cx)
146            })
147            .detach_and_log_err(cx);
148    };
149
150    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
151    let header = SettingsSectionHeader::new(title)
152        .icon(icon)
153        .no_padding(true);
154    let button_link_label = format!("{} dashboard", title);
155    let description = h_flex()
156        .min_w_0()
157        .gap_0p5()
158        .child(
159            Label::new("Visit the")
160                .size(LabelSize::Small)
161                .color(Color::Muted),
162        )
163        .child(
164            ButtonLink::new(button_link_label, link)
165                .no_icon(true)
166                .label_size(LabelSize::Small)
167                .label_color(Color::Muted),
168        )
169        .child(
170            Label::new("to generate an API key.")
171                .size(LabelSize::Small)
172                .color(Color::Muted),
173        );
174    let configured_card_label = if is_from_env_var {
175        "API Key Set in Environment Variable"
176    } else {
177        "API Key Configured"
178    };
179
180    let container = if has_key {
181        base_container.child(header).child(
182            ConfiguredApiCard::new(configured_card_label)
183                .button_label("Reset Key")
184                .button_tab_index(0)
185                .disabled(is_from_env_var)
186                .when_some(env_var_name, |this, env_var_name| {
187                    this.when(is_from_env_var, |this| {
188                        this.tooltip_label(format!(
189                            "To reset your API key, unset the {} environment variable.",
190                            env_var_name
191                        ))
192                    })
193                })
194                .on_click(move |_, _, cx| {
195                    write_key(None, cx);
196                }),
197        )
198    } else {
199        base_container.child(header).child(
200            h_flex()
201                .pt_2p5()
202                .w_full()
203                .justify_between()
204                .child(
205                    v_flex()
206                        .w_full()
207                        .max_w_1_2()
208                        .child(Label::new("API Key"))
209                        .child(description)
210                        .when_some(env_var_name, |this, env_var_name| {
211                            this.child({
212                                let label = format!(
213                                    "Or set the {} env var and restart Zed.",
214                                    env_var_name.as_ref()
215                                );
216                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
217                            })
218                        }),
219                )
220                .child(
221                    SettingsInputField::new()
222                        .tab_index(0)
223                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
224                        .on_confirm(move |api_key, cx| {
225                            write_key(api_key.filter(|key| !key.is_empty()), cx);
226                        }),
227                ),
228        )
229    };
230
231    container.when_some(additional_fields, |this, additional_fields| {
232        this.child(
233            div()
234                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
235                .px_neg_8()
236                .border_t_1()
237                .border_color(cx.theme().colors().border_variant)
238                .child(additional_fields),
239        )
240    })
241}
242
243fn render_ollama_provider(
244    settings_window: Entity<SettingsWindow>,
245    window: &mut Window,
246    cx: &mut Context<EditPredictionSetupPage>,
247) -> impl IntoElement {
248    let ollama_settings = ollama_settings();
249    let additional_fields = settings_window.update(cx, |settings_window, cx| {
250        settings_window
251            .render_sub_page_items_section(ollama_settings.iter().enumerate(), None, window, cx)
252            .into_any_element()
253    });
254
255    v_flex()
256        .id("ollama")
257        .min_w_0()
258        .pt_8()
259        .gap_1p5()
260        .child(
261            SettingsSectionHeader::new("Ollama")
262                .icon(IconName::ZedPredict)
263                .no_padding(true),
264        )
265        .child(
266            Label::new("Configure the local Ollama server and model used for edit predictions.")
267                .size(LabelSize::Small)
268                .color(Color::Muted),
269        )
270        .child(additional_fields)
271}
272
273fn ollama_settings() -> Box<[SettingsPageItem]> {
274    Box::new([
275        SettingsPageItem::SettingItem(SettingItem {
276            title: "API URL",
277            description: "The base URL of your Ollama server.",
278            field: Box::new(SettingField {
279                pick: |settings| {
280                    settings
281                        .project
282                        .all_languages
283                        .edit_predictions
284                        .as_ref()?
285                        .ollama
286                        .as_ref()?
287                        .api_url
288                        .as_ref()
289                },
290                write: |settings, value| {
291                    settings
292                        .project
293                        .all_languages
294                        .edit_predictions
295                        .get_or_insert_default()
296                        .ollama
297                        .get_or_insert_default()
298                        .api_url = value;
299                },
300                json_path: Some("edit_predictions.ollama.api_url"),
301            }),
302            metadata: Some(Box::new(SettingsFieldMetadata {
303                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
304                ..Default::default()
305            })),
306            files: USER,
307        }),
308        SettingsPageItem::SettingItem(SettingItem {
309            title: "Model",
310            description: "The Ollama model to use for edit predictions.",
311            field: Box::new(SettingField {
312                pick: |settings| {
313                    settings
314                        .project
315                        .all_languages
316                        .edit_predictions
317                        .as_ref()?
318                        .ollama
319                        .as_ref()?
320                        .model
321                        .as_ref()
322                },
323                write: |settings, value| {
324                    settings
325                        .project
326                        .all_languages
327                        .edit_predictions
328                        .get_or_insert_default()
329                        .ollama
330                        .get_or_insert_default()
331                        .model = value;
332                },
333                json_path: Some("edit_predictions.ollama.model"),
334            }),
335            metadata: Some(Box::new(SettingsFieldMetadata {
336                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
337                ..Default::default()
338            })),
339            files: USER,
340        }),
341    ])
342}
343
344fn codestral_settings() -> Box<[SettingsPageItem]> {
345    Box::new([
346        SettingsPageItem::SettingItem(SettingItem {
347            title: "API URL",
348            description: "The API URL to use for Codestral.",
349            field: Box::new(SettingField {
350                pick: |settings| {
351                    settings
352                        .project
353                        .all_languages
354                        .edit_predictions
355                        .as_ref()?
356                        .codestral
357                        .as_ref()?
358                        .api_url
359                        .as_ref()
360                },
361                write: |settings, value| {
362                    settings
363                        .project
364                        .all_languages
365                        .edit_predictions
366                        .get_or_insert_default()
367                        .codestral
368                        .get_or_insert_default()
369                        .api_url = value;
370                },
371                json_path: Some("edit_predictions.codestral.api_url"),
372            }),
373            metadata: Some(Box::new(SettingsFieldMetadata {
374                placeholder: Some(CODESTRAL_API_URL),
375                ..Default::default()
376            })),
377            files: USER,
378        }),
379        SettingsPageItem::SettingItem(SettingItem {
380            title: "Max Tokens",
381            description: "The maximum number of tokens to generate.",
382            field: Box::new(SettingField {
383                pick: |settings| {
384                    settings
385                        .project
386                        .all_languages
387                        .edit_predictions
388                        .as_ref()?
389                        .codestral
390                        .as_ref()?
391                        .max_tokens
392                        .as_ref()
393                },
394                write: |settings, value| {
395                    settings
396                        .project
397                        .all_languages
398                        .edit_predictions
399                        .get_or_insert_default()
400                        .codestral
401                        .get_or_insert_default()
402                        .max_tokens = value;
403                },
404                json_path: Some("edit_predictions.codestral.max_tokens"),
405            }),
406            metadata: None,
407            files: USER,
408        }),
409        SettingsPageItem::SettingItem(SettingItem {
410            title: "Model",
411            description: "The Codestral model id to use.",
412            field: Box::new(SettingField {
413                pick: |settings| {
414                    settings
415                        .project
416                        .all_languages
417                        .edit_predictions
418                        .as_ref()?
419                        .codestral
420                        .as_ref()?
421                        .model
422                        .as_ref()
423                },
424                write: |settings, value| {
425                    settings
426                        .project
427                        .all_languages
428                        .edit_predictions
429                        .get_or_insert_default()
430                        .codestral
431                        .get_or_insert_default()
432                        .model = value;
433                },
434                json_path: Some("edit_predictions.codestral.model"),
435            }),
436            metadata: Some(Box::new(SettingsFieldMetadata {
437                placeholder: Some("codestral-latest"),
438                ..Default::default()
439            })),
440            files: USER,
441        }),
442    ])
443}
444
445pub(crate) fn render_github_copilot_provider(
446    window: &mut Window,
447    cx: &mut App,
448) -> impl IntoElement {
449    let configuration_view = window.use_state(cx, |_, cx| {
450        copilot::ConfigurationView::new(
451            |cx| {
452                copilot::Copilot::global(cx)
453                    .is_some_and(|copilot| copilot.read(cx).is_authenticated())
454            },
455            copilot::ConfigurationMode::EditPrediction,
456            cx,
457        )
458    });
459
460    v_flex()
461        .id("github-copilot")
462        .min_w_0()
463        .gap_1p5()
464        .child(
465            SettingsSectionHeader::new("GitHub Copilot")
466                .icon(IconName::Copilot)
467                .no_padding(true),
468        )
469        .child(configuration_view)
470}