add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
  7use language_model::LanguageModelRegistry;
  8use language_models::{
  9    AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
 10    provider::open_ai_compatible::AvailableModel,
 11};
 12use settings::update_settings_file;
 13use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
 14use ui_input::SingleLineInput;
 15use workspace::{ModalView, Workspace};
 16
 17#[derive(Clone, Copy)]
 18pub enum LlmCompatibleProvider {
 19    OpenAi,
 20}
 21
 22impl LlmCompatibleProvider {
 23    fn name(&self) -> &'static str {
 24        match self {
 25            LlmCompatibleProvider::OpenAi => "OpenAI",
 26        }
 27    }
 28
 29    fn api_url(&self) -> &'static str {
 30        match self {
 31            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 32        }
 33    }
 34}
 35
 36struct AddLlmProviderInput {
 37    provider_name: Entity<SingleLineInput>,
 38    api_url: Entity<SingleLineInput>,
 39    api_key: Entity<SingleLineInput>,
 40    models: Vec<ModelInput>,
 41}
 42
 43impl AddLlmProviderInput {
 44    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 45        let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
 46        let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
 47        let api_key = single_line_input(
 48            "API Key",
 49            "000000000000000000000000000000000000000000000000",
 50            None,
 51            window,
 52            cx,
 53        );
 54
 55        Self {
 56            provider_name,
 57            api_url,
 58            api_key,
 59            models: vec![ModelInput::new(window, cx)],
 60        }
 61    }
 62
 63    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 64        self.models.push(ModelInput::new(window, cx));
 65    }
 66
 67    fn remove_model(&mut self, index: usize) {
 68        self.models.remove(index);
 69    }
 70}
 71
 72struct ModelInput {
 73    name: Entity<SingleLineInput>,
 74    max_completion_tokens: Entity<SingleLineInput>,
 75    max_output_tokens: Entity<SingleLineInput>,
 76    max_tokens: Entity<SingleLineInput>,
 77}
 78
 79impl ModelInput {
 80    fn new(window: &mut Window, cx: &mut App) -> Self {
 81        let model_name = single_line_input(
 82            "Model Name",
 83            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
 84            None,
 85            window,
 86            cx,
 87        );
 88        let max_completion_tokens = single_line_input(
 89            "Max Completion Tokens",
 90            "200000",
 91            Some("200000"),
 92            window,
 93            cx,
 94        );
 95        let max_output_tokens = single_line_input(
 96            "Max Output Tokens",
 97            "Max Output Tokens",
 98            Some("32000"),
 99            window,
100            cx,
101        );
102        let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
103        Self {
104            name: model_name,
105            max_completion_tokens,
106            max_output_tokens,
107            max_tokens,
108        }
109    }
110
111    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
112        let name = self.name.read(cx).text(cx);
113        if name.is_empty() {
114            return Err(SharedString::from("Model Name cannot be empty"));
115        }
116        Ok(AvailableModel {
117            name,
118            display_name: None,
119            max_completion_tokens: Some(
120                self.max_completion_tokens
121                    .read(cx)
122                    .text(cx)
123                    .parse::<u64>()
124                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
125            ),
126            max_output_tokens: Some(
127                self.max_output_tokens
128                    .read(cx)
129                    .text(cx)
130                    .parse::<u64>()
131                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
132            ),
133            max_tokens: self
134                .max_tokens
135                .read(cx)
136                .text(cx)
137                .parse::<u64>()
138                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
139        })
140    }
141}
142
143fn single_line_input(
144    label: impl Into<SharedString>,
145    placeholder: impl Into<SharedString>,
146    text: Option<&str>,
147    window: &mut Window,
148    cx: &mut App,
149) -> Entity<SingleLineInput> {
150    cx.new(|cx| {
151        let input = SingleLineInput::new(window, cx, placeholder).label(label);
152        if let Some(text) = text {
153            input
154                .editor()
155                .update(cx, |editor, cx| editor.set_text(text, window, cx));
156        }
157        input
158    })
159}
160
161fn save_provider_to_settings(
162    input: &AddLlmProviderInput,
163    cx: &mut App,
164) -> Task<Result<(), SharedString>> {
165    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
166    if provider_name.is_empty() {
167        return Task::ready(Err("Provider Name cannot be empty".into()));
168    }
169
170    if LanguageModelRegistry::read_global(cx)
171        .providers()
172        .iter()
173        .any(|provider| {
174            provider.id().0.as_ref() == provider_name.as_ref()
175                || provider.name().0.as_ref() == provider_name.as_ref()
176        })
177    {
178        return Task::ready(Err(
179            "Provider Name is already taken by another provider".into()
180        ));
181    }
182
183    let api_url = input.api_url.read(cx).text(cx);
184    if api_url.is_empty() {
185        return Task::ready(Err("API URL cannot be empty".into()));
186    }
187
188    let api_key = input.api_key.read(cx).text(cx);
189    if api_key.is_empty() {
190        return Task::ready(Err("API Key cannot be empty".into()));
191    }
192
193    let mut models = Vec::new();
194    let mut model_names: HashSet<String> = HashSet::default();
195    for model in &input.models {
196        match model.parse(cx) {
197            Ok(model) => {
198                if !model_names.insert(model.name.clone()) {
199                    return Task::ready(Err("Model Names must be unique".into()));
200                }
201                models.push(model)
202            }
203            Err(err) => return Task::ready(Err(err)),
204        }
205    }
206
207    let fs = <dyn Fs>::global(cx);
208    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
209    cx.spawn(async move |cx| {
210        task.await
211            .map_err(|_| "Failed to write API key to keychain")?;
212        cx.update(|cx| {
213            update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
214                settings.openai_compatible.get_or_insert_default().insert(
215                    provider_name,
216                    OpenAiCompatibleSettingsContent {
217                        api_url,
218                        available_models: models,
219                    },
220                );
221            });
222        })
223        .ok();
224        Ok(())
225    })
226}
227
228pub struct AddLlmProviderModal {
229    provider: LlmCompatibleProvider,
230    input: AddLlmProviderInput,
231    focus_handle: FocusHandle,
232    last_error: Option<SharedString>,
233}
234
235impl AddLlmProviderModal {
236    pub fn toggle(
237        provider: LlmCompatibleProvider,
238        workspace: &mut Workspace,
239        window: &mut Window,
240        cx: &mut Context<Workspace>,
241    ) {
242        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
243    }
244
245    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
246        Self {
247            input: AddLlmProviderInput::new(provider, window, cx),
248            provider,
249            last_error: None,
250            focus_handle: cx.focus_handle(),
251        }
252    }
253
254    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
255        let task = save_provider_to_settings(&self.input, cx);
256        cx.spawn(async move |this, cx| {
257            let result = task.await;
258            this.update(cx, |this, cx| match result {
259                Ok(_) => {
260                    cx.emit(DismissEvent);
261                }
262                Err(error) => {
263                    this.last_error = Some(error);
264                    cx.notify();
265                }
266            })
267        })
268        .detach_and_log_err(cx);
269    }
270
271    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
272        cx.emit(DismissEvent);
273    }
274
275    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
276        v_flex()
277            .mt_1()
278            .gap_2()
279            .child(
280                h_flex()
281                    .justify_between()
282                    .child(Label::new("Models").size(LabelSize::Small))
283                    .child(
284                        Button::new("add-model", "Add Model")
285                            .icon(IconName::Plus)
286                            .icon_position(IconPosition::Start)
287                            .icon_size(IconSize::XSmall)
288                            .icon_color(Color::Muted)
289                            .label_size(LabelSize::Small)
290                            .on_click(cx.listener(|this, _, window, cx| {
291                                this.input.add_model(window, cx);
292                                cx.notify();
293                            })),
294                    ),
295            )
296            .children(
297                self.input
298                    .models
299                    .iter()
300                    .enumerate()
301                    .map(|(ix, _)| self.render_model(ix, cx)),
302            )
303    }
304
305    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
306        let has_more_than_one_model = self.input.models.len() > 1;
307        let model = &self.input.models[ix];
308
309        v_flex()
310            .p_2()
311            .gap_2()
312            .rounded_sm()
313            .border_1()
314            .border_dashed()
315            .border_color(cx.theme().colors().border.opacity(0.6))
316            .bg(cx.theme().colors().element_active.opacity(0.15))
317            .child(model.name.clone())
318            .child(
319                h_flex()
320                    .gap_2()
321                    .child(model.max_completion_tokens.clone())
322                    .child(model.max_output_tokens.clone()),
323            )
324            .child(model.max_tokens.clone())
325            .when(has_more_than_one_model, |this| {
326                this.child(
327                    Button::new(("remove-model", ix), "Remove Model")
328                        .icon(IconName::Trash)
329                        .icon_position(IconPosition::Start)
330                        .icon_size(IconSize::XSmall)
331                        .icon_color(Color::Muted)
332                        .label_size(LabelSize::Small)
333                        .style(ButtonStyle::Outlined)
334                        .full_width()
335                        .on_click(cx.listener(move |this, _, _window, cx| {
336                            this.input.remove_model(ix);
337                            cx.notify();
338                        })),
339                )
340            })
341    }
342}
343
344impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
345
346impl Focusable for AddLlmProviderModal {
347    fn focus_handle(&self, _cx: &App) -> FocusHandle {
348        self.focus_handle.clone()
349    }
350}
351
352impl ModalView for AddLlmProviderModal {}
353
354impl Render for AddLlmProviderModal {
355    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
356        let focus_handle = self.focus_handle(cx);
357
358        div()
359            .id("add-llm-provider-modal")
360            .key_context("AddLlmProviderModal")
361            .w(rems(34.))
362            .elevation_3(cx)
363            .on_action(cx.listener(Self::cancel))
364            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
365                this.focus_handle(cx).focus(window);
366            }))
367            .child(
368                Modal::new("configure-context-server", None)
369                    .header(ModalHeader::new().headline("Add LLM Provider").description(
370                        match self.provider {
371                            LlmCompatibleProvider::OpenAi => {
372                                "This provider will use an OpenAI compatible API."
373                            }
374                        },
375                    ))
376                    .when_some(self.last_error.clone(), |this, error| {
377                        this.section(
378                            Section::new().child(
379                                Banner::new()
380                                    .severity(ui::Severity::Warning)
381                                    .child(div().text_xs().child(error)),
382                            ),
383                        )
384                    })
385                    .child(
386                        v_flex()
387                            .id("modal_content")
388                            .size_full()
389                            .max_h_128()
390                            .overflow_y_scroll()
391                            .px(DynamicSpacing::Base12.rems(cx))
392                            .gap(DynamicSpacing::Base04.rems(cx))
393                            .child(self.input.provider_name.clone())
394                            .child(self.input.api_url.clone())
395                            .child(self.input.api_key.clone())
396                            .child(self.render_model_section(cx)),
397                    )
398                    .footer(
399                        ModalFooter::new().end_slot(
400                            h_flex()
401                                .gap_1()
402                                .child(
403                                    Button::new("cancel", "Cancel")
404                                        .key_binding(
405                                            KeyBinding::for_action_in(
406                                                &menu::Cancel,
407                                                &focus_handle,
408                                                window,
409                                                cx,
410                                            )
411                                            .map(|kb| kb.size(rems_from_px(12.))),
412                                        )
413                                        .on_click(cx.listener(|this, _event, window, cx| {
414                                            this.cancel(&menu::Cancel, window, cx)
415                                        })),
416                                )
417                                .child(
418                                    Button::new("save-server", "Save Provider")
419                                        .key_binding(
420                                            KeyBinding::for_action_in(
421                                                &menu::Confirm,
422                                                &focus_handle,
423                                                window,
424                                                cx,
425                                            )
426                                            .map(|kb| kb.size(rems_from_px(12.))),
427                                        )
428                                        .on_click(cx.listener(|this, _event, window, cx| {
429                                            this.confirm(&menu::Confirm, window, cx)
430                                        })),
431                                ),
432                        ),
433                    ),
434            )
435    }
436}
437
// Tests for the validation in `save_provider_to_settings`. They build an
// `AddLlmProviderInput` directly (without opening the modal) and check which
// error message, if any, the returned task resolves to.
#[cfg(test)]
mod tests {
    use super::*;
    use editor::EditorSettings;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language::language_settings;
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::{Settings as _, SettingsStore};
    use util::path;

    /// Each blank required field, each non-numeric token limit, and a duplicated
    /// model name produce their own error message.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // NOTE: model tuples passed to `save_provider_validation_errors` are
        // (name, max_tokens, max_completion_tokens, max_output_tokens).
        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        // Second tuple element is max_tokens.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        // Third tuple element is max_completion_tokens.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        // Fourth tuple element is max_output_tokens.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    /// A provider name that matches the id of an already registered provider is
    /// rejected.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    ),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    /// Sets up the test app:
    /// - registers the settings needed by the editor, workspace, project and
    ///   language model crates;
    /// - installs a `FakeFs` as the global filesystem;
    /// - opens a test workspace window.
    ///
    /// Returns the window's visual context.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            workspace::init_settings(cx);
            Project::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            language_settings::init(cx);
            EditorSettings::register(cx);
            language_model::init_settings(cx);
            language_models::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    /// Fills a fresh form with the given values and runs
    /// `save_provider_to_settings` on it.
    ///
    /// Each `models` tuple is `(name, max_tokens, max_completion_tokens,
    /// max_output_tokens)`. The first tuple overwrites the form's initial model
    /// entry; later tuples add new entries.
    ///
    /// Returns the error the task resolved to, or `None` if it succeeded.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        // Replaces the text of a single-line input's editor.
        fn set_text(
            input: &Entity<SingleLineInput>,
            text: &str,
            window: &mut Window,
            cx: &mut App,
        ) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one model entry; add more as needed.
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(window, cx));
                }
                let model = &mut input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}