add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
  7use language_model::LanguageModelRegistry;
  8use language_models::{
  9    AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
 10    provider::open_ai_compatible::AvailableModel,
 11};
 12use settings::update_settings_file;
 13use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
 14use ui_input::SingleLineInput;
 15use workspace::{ModalView, Workspace};
 16
 17#[derive(Clone, Copy)]
 18pub enum LlmCompatibleProvider {
 19    OpenAi,
 20}
 21
 22impl LlmCompatibleProvider {
 23    fn name(&self) -> &'static str {
 24        match self {
 25            LlmCompatibleProvider::OpenAi => "OpenAI",
 26        }
 27    }
 28
 29    fn api_url(&self) -> &'static str {
 30        match self {
 31            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 32        }
 33    }
 34}
 35
 36struct AddLlmProviderInput {
 37    provider_name: Entity<SingleLineInput>,
 38    api_url: Entity<SingleLineInput>,
 39    api_key: Entity<SingleLineInput>,
 40    models: Vec<ModelInput>,
 41}
 42
 43impl AddLlmProviderInput {
 44    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 45        let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
 46        let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
 47        let api_key = single_line_input(
 48            "API Key",
 49            "000000000000000000000000000000000000000000000000",
 50            None,
 51            window,
 52            cx,
 53        );
 54
 55        Self {
 56            provider_name,
 57            api_url,
 58            api_key,
 59            models: vec![ModelInput::new(window, cx)],
 60        }
 61    }
 62
 63    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 64        self.models.push(ModelInput::new(window, cx));
 65    }
 66
 67    fn remove_model(&mut self, index: usize) {
 68        self.models.remove(index);
 69    }
 70}
 71
 72struct ModelInput {
 73    name: Entity<SingleLineInput>,
 74    max_completion_tokens: Entity<SingleLineInput>,
 75    max_output_tokens: Entity<SingleLineInput>,
 76    max_tokens: Entity<SingleLineInput>,
 77}
 78
 79impl ModelInput {
 80    fn new(window: &mut Window, cx: &mut App) -> Self {
 81        let model_name = single_line_input(
 82            "Model Name",
 83            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
 84            None,
 85            window,
 86            cx,
 87        );
 88        let max_completion_tokens = single_line_input(
 89            "Max Completion Tokens",
 90            "200000",
 91            Some("200000"),
 92            window,
 93            cx,
 94        );
 95        let max_output_tokens = single_line_input(
 96            "Max Output Tokens",
 97            "Max Output Tokens",
 98            Some("32000"),
 99            window,
100            cx,
101        );
102        let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
103        Self {
104            name: model_name,
105            max_completion_tokens,
106            max_output_tokens,
107            max_tokens,
108        }
109    }
110
111    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
112        let name = self.name.read(cx).text(cx);
113        if name.is_empty() {
114            return Err(SharedString::from("Model Name cannot be empty"));
115        }
116        Ok(AvailableModel {
117            name,
118            display_name: None,
119            max_completion_tokens: Some(
120                self.max_completion_tokens
121                    .read(cx)
122                    .text(cx)
123                    .parse::<u64>()
124                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
125            ),
126            max_output_tokens: Some(
127                self.max_output_tokens
128                    .read(cx)
129                    .text(cx)
130                    .parse::<u64>()
131                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
132            ),
133            max_tokens: self
134                .max_tokens
135                .read(cx)
136                .text(cx)
137                .parse::<u64>()
138                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
139        })
140    }
141}
142
143fn single_line_input(
144    label: impl Into<SharedString>,
145    placeholder: impl Into<SharedString>,
146    text: Option<&str>,
147    window: &mut Window,
148    cx: &mut App,
149) -> Entity<SingleLineInput> {
150    cx.new(|cx| {
151        let input = SingleLineInput::new(window, cx, placeholder).label(label);
152        if let Some(text) = text {
153            input
154                .editor()
155                .update(cx, |editor, cx| editor.set_text(text, window, cx));
156        }
157        input
158    })
159}
160
161fn save_provider_to_settings(
162    input: &AddLlmProviderInput,
163    cx: &mut App,
164) -> Task<Result<(), SharedString>> {
165    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
166    if provider_name.is_empty() {
167        return Task::ready(Err("Provider Name cannot be empty".into()));
168    }
169
170    if LanguageModelRegistry::read_global(cx)
171        .providers()
172        .iter()
173        .any(|provider| {
174            provider.id().0.as_ref() == provider_name.as_ref()
175                || provider.name().0.as_ref() == provider_name.as_ref()
176        })
177    {
178        return Task::ready(Err(
179            "Provider Name is already taken by another provider".into()
180        ));
181    }
182
183    let api_url = input.api_url.read(cx).text(cx);
184    if api_url.is_empty() {
185        return Task::ready(Err("API URL cannot be empty".into()));
186    }
187
188    let api_key = input.api_key.read(cx).text(cx);
189    if api_key.is_empty() {
190        return Task::ready(Err("API Key cannot be empty".into()));
191    }
192
193    let mut models = Vec::new();
194    let mut model_names: HashSet<String> = HashSet::default();
195    for model in &input.models {
196        match model.parse(cx) {
197            Ok(model) => {
198                if !model_names.insert(model.name.clone()) {
199                    return Task::ready(Err("Model Names must be unique".into()));
200                }
201                models.push(model)
202            }
203            Err(err) => return Task::ready(Err(err)),
204        }
205    }
206
207    let fs = <dyn Fs>::global(cx);
208    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
209    cx.spawn(async move |cx| {
210        task.await
211            .map_err(|_| "Failed to write API key to keychain")?;
212        cx.update(|cx| {
213            update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
214                settings.openai_compatible.get_or_insert_default().insert(
215                    provider_name,
216                    OpenAiCompatibleSettingsContent {
217                        api_url,
218                        available_models: models,
219                    },
220                );
221            });
222        })
223        .ok();
224        Ok(())
225    })
226}
227
228pub struct AddLlmProviderModal {
229    provider: LlmCompatibleProvider,
230    input: AddLlmProviderInput,
231    focus_handle: FocusHandle,
232    last_error: Option<SharedString>,
233}
234
235impl AddLlmProviderModal {
236    pub fn toggle(
237        provider: LlmCompatibleProvider,
238        workspace: &mut Workspace,
239        window: &mut Window,
240        cx: &mut Context<Workspace>,
241    ) {
242        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
243    }
244
245    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
246        Self {
247            input: AddLlmProviderInput::new(provider, window, cx),
248            provider,
249            last_error: None,
250            focus_handle: cx.focus_handle(),
251        }
252    }
253
254    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
255        let task = save_provider_to_settings(&self.input, cx);
256        cx.spawn(async move |this, cx| {
257            let result = task.await;
258            this.update(cx, |this, cx| match result {
259                Ok(_) => {
260                    cx.emit(DismissEvent);
261                }
262                Err(error) => {
263                    this.last_error = Some(error);
264                    cx.notify();
265                }
266            })
267        })
268        .detach_and_log_err(cx);
269    }
270
271    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
272        cx.emit(DismissEvent);
273    }
274
275    fn render_section(&self) -> Section {
276        Section::new()
277            .child(self.input.provider_name.clone())
278            .child(self.input.api_url.clone())
279            .child(self.input.api_key.clone())
280    }
281
282    fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
283        Section::new().child(
284            v_flex()
285                .gap_2()
286                .child(
287                    h_flex()
288                        .justify_between()
289                        .child(Label::new("Models").size(LabelSize::Small))
290                        .child(
291                            Button::new("add-model", "Add Model")
292                                .icon(IconName::Plus)
293                                .icon_position(IconPosition::Start)
294                                .icon_size(IconSize::XSmall)
295                                .icon_color(Color::Muted)
296                                .label_size(LabelSize::Small)
297                                .on_click(cx.listener(|this, _, window, cx| {
298                                    this.input.add_model(window, cx);
299                                    cx.notify();
300                                })),
301                        ),
302                )
303                .children(
304                    self.input
305                        .models
306                        .iter()
307                        .enumerate()
308                        .map(|(ix, _)| self.render_model(ix, cx)),
309                ),
310        )
311    }
312
313    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
314        let has_more_than_one_model = self.input.models.len() > 1;
315        let model = &self.input.models[ix];
316
317        v_flex()
318            .p_2()
319            .gap_2()
320            .rounded_sm()
321            .border_1()
322            .border_dashed()
323            .border_color(cx.theme().colors().border.opacity(0.6))
324            .bg(cx.theme().colors().element_active.opacity(0.15))
325            .child(model.name.clone())
326            .child(
327                h_flex()
328                    .gap_2()
329                    .child(model.max_completion_tokens.clone())
330                    .child(model.max_output_tokens.clone()),
331            )
332            .child(model.max_tokens.clone())
333            .when(has_more_than_one_model, |this| {
334                this.child(
335                    Button::new(("remove-model", ix), "Remove Model")
336                        .icon(IconName::Trash)
337                        .icon_position(IconPosition::Start)
338                        .icon_size(IconSize::XSmall)
339                        .icon_color(Color::Muted)
340                        .label_size(LabelSize::Small)
341                        .style(ButtonStyle::Outlined)
342                        .full_width()
343                        .on_click(cx.listener(move |this, _, _window, cx| {
344                            this.input.remove_model(ix);
345                            cx.notify();
346                        })),
347                )
348            })
349    }
350}
351
352impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
353
354impl Focusable for AddLlmProviderModal {
355    fn focus_handle(&self, _cx: &App) -> FocusHandle {
356        self.focus_handle.clone()
357    }
358}
359
360impl ModalView for AddLlmProviderModal {}
361
362impl Render for AddLlmProviderModal {
363    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
364        let focus_handle = self.focus_handle(cx);
365
366        div()
367            .id("add-llm-provider-modal")
368            .key_context("AddLlmProviderModal")
369            .w(rems(34.))
370            .elevation_3(cx)
371            .on_action(cx.listener(Self::cancel))
372            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
373                this.focus_handle(cx).focus(window);
374            }))
375            .child(
376                Modal::new("configure-context-server", None)
377                    .header(ModalHeader::new().headline("Add LLM Provider").description(
378                        match self.provider {
379                            LlmCompatibleProvider::OpenAi => {
380                                "This provider will use an OpenAI compatible API."
381                            }
382                        },
383                    ))
384                    .when_some(self.last_error.clone(), |this, error| {
385                        this.section(
386                            Section::new().child(
387                                Banner::new()
388                                    .severity(ui::Severity::Warning)
389                                    .child(div().text_xs().child(error)),
390                            ),
391                        )
392                    })
393                    .child(
394                        v_flex()
395                            .id("modal_content")
396                            .max_h_128()
397                            .overflow_y_scroll()
398                            .gap_2()
399                            .child(self.render_section())
400                            .child(self.render_model_section(cx)),
401                    )
402                    .footer(
403                        ModalFooter::new().end_slot(
404                            h_flex()
405                                .gap_1()
406                                .child(
407                                    Button::new("cancel", "Cancel")
408                                        .key_binding(
409                                            KeyBinding::for_action_in(
410                                                &menu::Cancel,
411                                                &focus_handle,
412                                                window,
413                                                cx,
414                                            )
415                                            .map(|kb| kb.size(rems_from_px(12.))),
416                                        )
417                                        .on_click(cx.listener(|this, _event, window, cx| {
418                                            this.cancel(&menu::Cancel, window, cx)
419                                        })),
420                                )
421                                .child(
422                                    Button::new("save-server", "Save Provider")
423                                        .key_binding(
424                                            KeyBinding::for_action_in(
425                                                &menu::Confirm,
426                                                &focus_handle,
427                                                window,
428                                                cx,
429                                            )
430                                            .map(|kb| kb.size(rems_from_px(12.))),
431                                        )
432                                        .on_click(cx.listener(|this, _event, window, cx| {
433                                            this.confirm(&menu::Confirm, window, cx)
434                                        })),
435                                ),
436                        ),
437                    ),
438            )
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use editor::EditorSettings;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language::language_settings;
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::{Settings as _, SettingsStore};
    use util::path;

    /// Each invalid field is reported with its own message.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    /// A provider name may match neither the id nor the display name of an
    /// already-registered provider.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    ),
                    cx,
                );
            });
        });

        // Conflicts with the registered provider's id.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );

        // Conflicts with the registered provider's display name.
        assert_eq!(
            save_provider_validation_errors(
                "Some Provider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    /// Initializes the settings, theme, and language-model globals the modal
    /// relies on, installs a fake file system, and opens a test workspace window.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            workspace::init_settings(cx);
            Project::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            language_settings::init(cx);
            EditorSettings::register(cx);
            language_model::init_settings(cx);
            language_models::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    /// Fills an [`AddLlmProviderInput`] with the given values and runs
    /// [`save_provider_to_settings`]. Returns the error it produced, if any.
    ///
    /// Each model tuple is `(name, max_tokens, max_completion_tokens,
    /// max_output_tokens)`.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(
            input: &Entity<SingleLineInput>,
            text: &str,
            window: &mut Window,
            cx: &mut App,
        ) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The input starts with one model row; add rows as needed.
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(window, cx));
                }
                let model = &input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}