add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: &str,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input.set_text(text, window, cx);
 35        }
 36        input
 37    })
 38}
 39
 40#[derive(Clone, Copy)]
 41pub enum LlmCompatibleProvider {
 42    OpenAi,
 43}
 44
 45impl LlmCompatibleProvider {
 46    fn name(&self) -> &'static str {
 47        match self {
 48            LlmCompatibleProvider::OpenAi => "OpenAI",
 49        }
 50    }
 51
 52    fn api_url(&self) -> &'static str {
 53        match self {
 54            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 55        }
 56    }
 57}
 58
 59struct AddLlmProviderInput {
 60    provider_name: Entity<InputField>,
 61    api_url: Entity<InputField>,
 62    api_key: Entity<InputField>,
 63    models: Vec<ModelInput>,
 64}
 65
 66impl AddLlmProviderInput {
 67    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 68        let provider_name =
 69            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 70        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 71        let api_key = cx.new(|cx| {
 72            InputField::new(
 73                window,
 74                cx,
 75                "000000000000000000000000000000000000000000000000",
 76            )
 77            .label("API Key")
 78            .tab_index(3)
 79            .tab_stop(true)
 80            .masked(true)
 81        });
 82
 83        Self {
 84            provider_name,
 85            api_url,
 86            api_key,
 87            models: vec![ModelInput::new(0, window, cx)],
 88        }
 89    }
 90
 91    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 92        let model_index = self.models.len();
 93        self.models.push(ModelInput::new(model_index, window, cx));
 94    }
 95
 96    fn remove_model(&mut self, index: usize) {
 97        self.models.remove(index);
 98    }
 99}
100
101struct ModelCapabilityToggles {
102    pub supports_tools: ToggleState,
103    pub supports_images: ToggleState,
104    pub supports_parallel_tool_calls: ToggleState,
105    pub supports_prompt_cache_key: ToggleState,
106    pub supports_chat_completions: ToggleState,
107}
108
109struct ModelInput {
110    name: Entity<InputField>,
111    max_completion_tokens: Entity<InputField>,
112    max_output_tokens: Entity<InputField>,
113    max_tokens: Entity<InputField>,
114    capabilities: ModelCapabilityToggles,
115}
116
117impl ModelInput {
118    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
119        let base_tab_index = (3 + (model_index * 4)) as isize;
120
121        let model_name = single_line_input(
122            "Model Name",
123            "e.g. gpt-5, claude-opus-4, gemini-2.5-pro",
124            None,
125            base_tab_index + 1,
126            window,
127            cx,
128        );
129        let max_completion_tokens = single_line_input(
130            "Max Completion Tokens",
131            "200000",
132            Some("200000"),
133            base_tab_index + 2,
134            window,
135            cx,
136        );
137        let max_output_tokens = single_line_input(
138            "Max Output Tokens",
139            "Max Output Tokens",
140            Some("32000"),
141            base_tab_index + 3,
142            window,
143            cx,
144        );
145        let max_tokens = single_line_input(
146            "Max Tokens",
147            "Max Tokens",
148            Some("200000"),
149            base_tab_index + 4,
150            window,
151            cx,
152        );
153
154        let ModelCapabilities {
155            tools,
156            images,
157            parallel_tool_calls,
158            prompt_cache_key,
159            chat_completions,
160        } = ModelCapabilities::default();
161
162        Self {
163            name: model_name,
164            max_completion_tokens,
165            max_output_tokens,
166            max_tokens,
167            capabilities: ModelCapabilityToggles {
168                supports_tools: tools.into(),
169                supports_images: images.into(),
170                supports_parallel_tool_calls: parallel_tool_calls.into(),
171                supports_prompt_cache_key: prompt_cache_key.into(),
172                supports_chat_completions: chat_completions.into(),
173            },
174        }
175    }
176
177    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
178        let name = self.name.read(cx).text(cx);
179        if name.is_empty() {
180            return Err(SharedString::from("Model Name cannot be empty"));
181        }
182        Ok(AvailableModel {
183            name,
184            display_name: None,
185            max_completion_tokens: Some(
186                self.max_completion_tokens
187                    .read(cx)
188                    .text(cx)
189                    .parse::<u64>()
190                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
191            ),
192            max_output_tokens: Some(
193                self.max_output_tokens
194                    .read(cx)
195                    .text(cx)
196                    .parse::<u64>()
197                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
198            ),
199            max_tokens: self
200                .max_tokens
201                .read(cx)
202                .text(cx)
203                .parse::<u64>()
204                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
205            reasoning_effort: None,
206            capabilities: ModelCapabilities {
207                tools: self.capabilities.supports_tools.selected(),
208                images: self.capabilities.supports_images.selected(),
209                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
210                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
211                chat_completions: self.capabilities.supports_chat_completions.selected(),
212            },
213        })
214    }
215}
216
217fn save_provider_to_settings(
218    input: &AddLlmProviderInput,
219    cx: &mut App,
220) -> Task<Result<(), SharedString>> {
221    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
222    if provider_name.is_empty() {
223        return Task::ready(Err("Provider Name cannot be empty".into()));
224    }
225
226    if LanguageModelRegistry::read_global(cx)
227        .providers()
228        .iter()
229        .any(|provider| {
230            provider.id().0.as_ref() == provider_name.as_ref()
231                || provider.name().0.as_ref() == provider_name.as_ref()
232        })
233    {
234        return Task::ready(Err(
235            "Provider Name is already taken by another provider".into()
236        ));
237    }
238
239    let api_url = input.api_url.read(cx).text(cx);
240    if api_url.is_empty() {
241        return Task::ready(Err("API URL cannot be empty".into()));
242    }
243
244    let api_key = input.api_key.read(cx).text(cx);
245    if api_key.is_empty() {
246        return Task::ready(Err("API Key cannot be empty".into()));
247    }
248
249    let mut models = Vec::new();
250    let mut model_names: HashSet<String> = HashSet::default();
251    for model in &input.models {
252        match model.parse(cx) {
253            Ok(model) => {
254                if !model_names.insert(model.name.clone()) {
255                    return Task::ready(Err("Model Names must be unique".into()));
256                }
257                models.push(model)
258            }
259            Err(err) => return Task::ready(Err(err)),
260        }
261    }
262
263    let fs = <dyn Fs>::global(cx);
264    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
265    cx.spawn(async move |cx| {
266        task.await
267            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
268        cx.update(|cx| {
269            update_settings_file(fs, cx, |settings, _cx| {
270                settings
271                    .language_models
272                    .get_or_insert_default()
273                    .openai_compatible
274                    .get_or_insert_default()
275                    .insert(
276                        provider_name,
277                        OpenAiCompatibleSettingsContent {
278                            api_url,
279                            available_models: models,
280                        },
281                    );
282            });
283        });
284        Ok(())
285    })
286}
287
288pub struct AddLlmProviderModal {
289    provider: LlmCompatibleProvider,
290    input: AddLlmProviderInput,
291    scroll_handle: ScrollHandle,
292    focus_handle: FocusHandle,
293    last_error: Option<SharedString>,
294}
295
296impl AddLlmProviderModal {
297    pub fn toggle(
298        provider: LlmCompatibleProvider,
299        workspace: &mut Workspace,
300        window: &mut Window,
301        cx: &mut Context<Workspace>,
302    ) {
303        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
304    }
305
306    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
307        Self {
308            input: AddLlmProviderInput::new(provider, window, cx),
309            provider,
310            last_error: None,
311            focus_handle: cx.focus_handle(),
312            scroll_handle: ScrollHandle::new(),
313        }
314    }
315
316    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
317        let task = save_provider_to_settings(&self.input, cx);
318        cx.spawn(async move |this, cx| {
319            let result = task.await;
320            this.update(cx, |this, cx| match result {
321                Ok(_) => {
322                    cx.emit(DismissEvent);
323                }
324                Err(error) => {
325                    this.last_error = Some(error);
326                    cx.notify();
327                }
328            })
329        })
330        .detach_and_log_err(cx);
331    }
332
333    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
334        cx.emit(DismissEvent);
335    }
336
337    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
338        v_flex()
339            .mt_1()
340            .gap_2()
341            .child(
342                h_flex()
343                    .justify_between()
344                    .child(Label::new("Models").size(LabelSize::Small))
345                    .child(
346                        Button::new("add-model", "Add Model")
347                            .start_icon(
348                                Icon::new(IconName::Plus)
349                                    .size(IconSize::XSmall)
350                                    .color(Color::Muted),
351                            )
352                            .label_size(LabelSize::Small)
353                            .on_click(cx.listener(|this, _, window, cx| {
354                                this.input.add_model(window, cx);
355                                cx.notify();
356                            })),
357                    ),
358            )
359            .children(
360                self.input
361                    .models
362                    .iter()
363                    .enumerate()
364                    .map(|(ix, _)| self.render_model(ix, cx)),
365            )
366    }
367
368    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
369        let has_more_than_one_model = self.input.models.len() > 1;
370        let model = &self.input.models[ix];
371
372        v_flex()
373            .p_2()
374            .gap_2()
375            .rounded_sm()
376            .border_1()
377            .border_dashed()
378            .border_color(cx.theme().colors().border.opacity(0.6))
379            .bg(cx.theme().colors().element_active.opacity(0.15))
380            .child(model.name.clone())
381            .child(
382                h_flex()
383                    .gap_2()
384                    .child(model.max_completion_tokens.clone())
385                    .child(model.max_output_tokens.clone()),
386            )
387            .child(model.max_tokens.clone())
388            .child(
389                v_flex()
390                    .gap_1()
391                    .child(
392                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
393                            .label("Supports tools")
394                            .on_click(cx.listener(move |this, checked, _window, cx| {
395                                this.input.models[ix].capabilities.supports_tools = *checked;
396                                cx.notify();
397                            })),
398                    )
399                    .child(
400                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
401                            .label("Supports images")
402                            .on_click(cx.listener(move |this, checked, _window, cx| {
403                                this.input.models[ix].capabilities.supports_images = *checked;
404                                cx.notify();
405                            })),
406                    )
407                    .child(
408                        Checkbox::new(
409                            ("supports-parallel-tool-calls", ix),
410                            model.capabilities.supports_parallel_tool_calls,
411                        )
412                        .label("Supports parallel_tool_calls")
413                        .on_click(cx.listener(
414                            move |this, checked, _window, cx| {
415                                this.input.models[ix]
416                                    .capabilities
417                                    .supports_parallel_tool_calls = *checked;
418                                cx.notify();
419                            },
420                        )),
421                    )
422                    .child(
423                        Checkbox::new(
424                            ("supports-prompt-cache-key", ix),
425                            model.capabilities.supports_prompt_cache_key,
426                        )
427                        .label("Supports prompt_cache_key")
428                        .on_click(cx.listener(
429                            move |this, checked, _window, cx| {
430                                this.input.models[ix].capabilities.supports_prompt_cache_key =
431                                    *checked;
432                                cx.notify();
433                            },
434                        )),
435                    )
436                    .child(
437                        Checkbox::new(
438                            ("supports-chat-completions", ix),
439                            model.capabilities.supports_chat_completions,
440                        )
441                        .label("Supports /chat/completions")
442                        .on_click(cx.listener(
443                            move |this, checked, _window, cx| {
444                                this.input.models[ix].capabilities.supports_chat_completions =
445                                    *checked;
446                                cx.notify();
447                            },
448                        )),
449                    ),
450            )
451            .when(has_more_than_one_model, |this| {
452                this.child(
453                    Button::new(("remove-model", ix), "Remove Model")
454                        .start_icon(
455                            Icon::new(IconName::Trash)
456                                .size(IconSize::XSmall)
457                                .color(Color::Muted),
458                        )
459                        .label_size(LabelSize::Small)
460                        .style(ButtonStyle::Outlined)
461                        .full_width()
462                        .on_click(cx.listener(move |this, _, _window, cx| {
463                            this.input.remove_model(ix);
464                            cx.notify();
465                        })),
466                )
467            })
468    }
469
470    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
471        window.focus_next(cx);
472    }
473
474    fn on_tab_prev(
475        &mut self,
476        _: &menu::SelectPrevious,
477        window: &mut Window,
478        cx: &mut Context<Self>,
479    ) {
480        window.focus_prev(cx);
481    }
482}
483
484impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
485
486impl Focusable for AddLlmProviderModal {
487    fn focus_handle(&self, _cx: &App) -> FocusHandle {
488        self.focus_handle.clone()
489    }
490}
491
492impl ModalView for AddLlmProviderModal {}
493
494impl Render for AddLlmProviderModal {
495    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
496        let focus_handle = self.focus_handle(cx);
497
498        let window_size = window.viewport_size();
499        let rem_size = window.rem_size();
500        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
501
502        let modal_max_height = if is_large_window {
503            rems_from_px(450.)
504        } else {
505            rems_from_px(200.)
506        };
507
508        v_flex()
509            .id("add-llm-provider-modal")
510            .key_context("AddLlmProviderModal")
511            .w(rems(34.))
512            .elevation_3(cx)
513            .on_action(cx.listener(Self::cancel))
514            .on_action(cx.listener(Self::on_tab))
515            .on_action(cx.listener(Self::on_tab_prev))
516            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
517                this.focus_handle(cx).focus(window, cx);
518            }))
519            .child(
520                Modal::new("configure-context-server", None)
521                    .header(ModalHeader::new().headline("Add LLM Provider").description(
522                        match self.provider {
523                            LlmCompatibleProvider::OpenAi => {
524                                "This provider will use an OpenAI compatible API."
525                            }
526                        },
527                    ))
528                    .when_some(self.last_error.clone(), |this, error| {
529                        this.section(
530                            Section::new().child(
531                                Banner::new()
532                                    .severity(Severity::Warning)
533                                    .child(div().text_xs().child(error)),
534                            ),
535                        )
536                    })
537                    .child(
538                        div()
539                            .size_full()
540                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
541                            .child(
542                                v_flex()
543                                    .id("modal_content")
544                                    .size_full()
545                                    .tab_group()
546                                    .max_h(modal_max_height)
547                                    .pl_3()
548                                    .pr_4()
549                                    .pb_2()
550                                    .gap_2()
551                                    .overflow_y_scroll()
552                                    .track_scroll(&self.scroll_handle)
553                                    .child(self.input.provider_name.clone())
554                                    .child(self.input.api_url.clone())
555                                    .child(self.input.api_key.clone())
556                                    .child(self.render_model_section(cx)),
557                            ),
558                    )
559                    .footer(
560                        ModalFooter::new().end_slot(
561                            h_flex()
562                                .gap_1()
563                                .child(
564                                    Button::new("cancel", "Cancel")
565                                        .key_binding(
566                                            KeyBinding::for_action_in(
567                                                &menu::Cancel,
568                                                &focus_handle,
569                                                cx,
570                                            )
571                                            .map(|kb| kb.size(rems_from_px(12.))),
572                                        )
573                                        .on_click(cx.listener(|this, _event, window, cx| {
574                                            this.cancel(&menu::Cancel, window, cx)
575                                        })),
576                                )
577                                .child(
578                                    Button::new("save-server", "Save Provider")
579                                        .key_binding(
580                                            KeyBinding::for_action_in(
581                                                &menu::Confirm,
582                                                &focus_handle,
583                                                cx,
584                                            )
585                                            .map(|kb| kb.size(rems_from_px(12.))),
586                                        )
587                                        .on_click(cx.listener(|this, _event, window, cx| {
588                                            this.confirm(&menu::Confirm, window, cx)
589                                        })),
590                                ),
591                        ),
592                    ),
593            )
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // Model rows are `(name, max_tokens, max_completion_tokens, max_output_tokens)`.
        let valid_model = ("somemodel", "200000", "200000", "32000");
        let cases: Vec<(&str, &str, &str, Vec<(&str, &str, &str, &str)>, &str)> = vec![
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![valid_model, valid_model],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            let actual =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx)
                    .await;
            assert_eq!(actual, Some(SharedString::from(expected)));
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let registry = LanguageModelRegistry::global(cx);
            registry.update(cx, |registry, cx| {
                let provider = FakeLanguageModelProvider::new(
                    LanguageModelProviderId::new("someprovider"),
                    LanguageModelProviderName::new("Some Provider"),
                );
                registry.register_provider(Arc::new(provider), cx);
            });
        });

        let actual = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![("somemodel", "200000", "200000", "32000")],
            cx,
        )
        .await;
        assert_eq!(
            actual,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = named_model_input(window, cx);
            assert_eq!(
                toggle_states(&model_input.capabilities),
                [
                    ToggleState::Selected,
                    ToggleState::Unselected,
                    ToggleState::Unselected,
                    ToggleState::Unselected,
                    ToggleState::Selected,
                ]
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(
                capability_flags(&parsed_model),
                [true, false, false, false, true]
            );
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input(window, cx);
            set_toggle_states(&mut model_input.capabilities, [ToggleState::Unselected; 5]);

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(capability_flags(&parsed_model), [false; 5]);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input(window, cx);
            set_toggle_states(
                &mut model_input.capabilities,
                [
                    ToggleState::Selected,
                    ToggleState::Unselected,
                    ToggleState::Selected,
                    ToggleState::Unselected,
                    ToggleState::Selected,
                ],
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert_eq!(
                capability_flags(&parsed_model),
                [true, false, true, false, true]
            );
        });
    }

    /// Creates the first model row with its name set to `"somemodel"`.
    fn named_model_input(window: &mut Window, cx: &mut App) -> ModelInput {
        let model_input = ModelInput::new(0, window, cx);
        model_input.name.update(cx, |input, cx| {
            input.set_text("somemodel", window, cx);
        });
        model_input
    }

    /// Toggle states in declaration order: tools, images, parallel tool calls,
    /// prompt cache key, chat completions.
    fn toggle_states(toggles: &ModelCapabilityToggles) -> [ToggleState; 5] {
        [
            toggles.supports_tools,
            toggles.supports_images,
            toggles.supports_parallel_tool_calls,
            toggles.supports_prompt_cache_key,
            toggles.supports_chat_completions,
        ]
    }

    /// Sets every toggle, using the same order as [`toggle_states`].
    fn set_toggle_states(toggles: &mut ModelCapabilityToggles, states: [ToggleState; 5]) {
        let [tools, images, parallel_tool_calls, prompt_cache_key, chat_completions] = states;
        toggles.supports_tools = tools;
        toggles.supports_images = images;
        toggles.supports_parallel_tool_calls = parallel_tool_calls;
        toggles.supports_prompt_cache_key = prompt_cache_key;
        toggles.supports_chat_completions = chat_completions;
    }

    /// Capability flags of a parsed model, in the same order as [`toggle_states`].
    fn capability_flags(model: &AvailableModel) -> [bool; 5] {
        let caps = &model.capabilities;
        [
            caps.tools,
            caps.images,
            caps.parallel_tool_calls,
            caps.prompt_cache_key,
            caps.chat_completions,
        ]
    }

    /// Installs the globals the modal needs and opens a test window.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        // The settings store is installed before the other `init` calls; the
        // original order is preserved here.
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);

            language_model::init(cx);
            editor::init(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let _workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        cx
    }

    /// Fills a fresh form with the given values and returns the save error,
    /// if any.
    ///
    /// Each model row is `(name, max_tokens, max_completion_tokens,
    /// max_output_tokens)`.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            input.update(cx, |input, cx| {
                input.set_text(text, window, cx);
            });
        }

        let save = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            for (field, text) in [
                (&input.provider_name, provider_name),
                (&input.api_url, api_url),
                (&input.api_key, api_key),
            ] {
                set_text(field, text, window, cx);
            }

            for (i, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(i, window, cx));
                }
                let model = &input.models[i];
                for (field, text) in [
                    (&model.name, name),
                    (&model.max_tokens, max_tokens),
                    (&model.max_completion_tokens, max_completion_tokens),
                    (&model.max_output_tokens, max_output_tokens),
                ] {
                    set_text(field, text, window, cx);
                }
            }

            save_provider_to_settings(&input, cx)
        });

        save.await.err()
    }
}