add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: impl Into<SharedString>,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input
 35                .editor()
 36                .update(cx, |editor, cx| editor.set_text(text, window, cx));
 37        }
 38        input
 39    })
 40}
 41
 42#[derive(Clone, Copy)]
 43pub enum LlmCompatibleProvider {
 44    OpenAi,
 45}
 46
 47impl LlmCompatibleProvider {
 48    fn name(&self) -> &'static str {
 49        match self {
 50            LlmCompatibleProvider::OpenAi => "OpenAI",
 51        }
 52    }
 53
 54    fn api_url(&self) -> &'static str {
 55        match self {
 56            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 57        }
 58    }
 59}
 60
 61struct AddLlmProviderInput {
 62    provider_name: Entity<InputField>,
 63    api_url: Entity<InputField>,
 64    api_key: Entity<InputField>,
 65    models: Vec<ModelInput>,
 66}
 67
 68impl AddLlmProviderInput {
 69    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 70        let provider_name =
 71            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 72        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 73        let api_key = single_line_input(
 74            "API Key",
 75            "000000000000000000000000000000000000000000000000",
 76            None,
 77            3,
 78            window,
 79            cx,
 80        );
 81
 82        Self {
 83            provider_name,
 84            api_url,
 85            api_key,
 86            models: vec![ModelInput::new(0, window, cx)],
 87        }
 88    }
 89
 90    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 91        let model_index = self.models.len();
 92        self.models.push(ModelInput::new(model_index, window, cx));
 93    }
 94
 95    fn remove_model(&mut self, index: usize) {
 96        self.models.remove(index);
 97    }
 98}
 99
100struct ModelCapabilityToggles {
101    pub supports_tools: ToggleState,
102    pub supports_images: ToggleState,
103    pub supports_parallel_tool_calls: ToggleState,
104    pub supports_prompt_cache_key: ToggleState,
105}
106
107struct ModelInput {
108    name: Entity<InputField>,
109    max_completion_tokens: Entity<InputField>,
110    max_output_tokens: Entity<InputField>,
111    max_tokens: Entity<InputField>,
112    capabilities: ModelCapabilityToggles,
113}
114
115impl ModelInput {
116    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
117        let base_tab_index = (3 + (model_index * 4)) as isize;
118
119        let model_name = single_line_input(
120            "Model Name",
121            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
122            None,
123            base_tab_index + 1,
124            window,
125            cx,
126        );
127        let max_completion_tokens = single_line_input(
128            "Max Completion Tokens",
129            "200000",
130            Some("200000"),
131            base_tab_index + 2,
132            window,
133            cx,
134        );
135        let max_output_tokens = single_line_input(
136            "Max Output Tokens",
137            "Max Output Tokens",
138            Some("32000"),
139            base_tab_index + 3,
140            window,
141            cx,
142        );
143        let max_tokens = single_line_input(
144            "Max Tokens",
145            "Max Tokens",
146            Some("200000"),
147            base_tab_index + 4,
148            window,
149            cx,
150        );
151
152        let ModelCapabilities {
153            tools,
154            images,
155            parallel_tool_calls,
156            prompt_cache_key,
157        } = ModelCapabilities::default();
158
159        Self {
160            name: model_name,
161            max_completion_tokens,
162            max_output_tokens,
163            max_tokens,
164            capabilities: ModelCapabilityToggles {
165                supports_tools: tools.into(),
166                supports_images: images.into(),
167                supports_parallel_tool_calls: parallel_tool_calls.into(),
168                supports_prompt_cache_key: prompt_cache_key.into(),
169            },
170        }
171    }
172
173    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
174        let name = self.name.read(cx).text(cx);
175        if name.is_empty() {
176            return Err(SharedString::from("Model Name cannot be empty"));
177        }
178        Ok(AvailableModel {
179            name,
180            display_name: None,
181            max_completion_tokens: Some(
182                self.max_completion_tokens
183                    .read(cx)
184                    .text(cx)
185                    .parse::<u64>()
186                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
187            ),
188            max_output_tokens: Some(
189                self.max_output_tokens
190                    .read(cx)
191                    .text(cx)
192                    .parse::<u64>()
193                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
194            ),
195            max_tokens: self
196                .max_tokens
197                .read(cx)
198                .text(cx)
199                .parse::<u64>()
200                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
201            capabilities: ModelCapabilities {
202                tools: self.capabilities.supports_tools.selected(),
203                images: self.capabilities.supports_images.selected(),
204                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
205                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
206            },
207        })
208    }
209}
210
211fn save_provider_to_settings(
212    input: &AddLlmProviderInput,
213    cx: &mut App,
214) -> Task<Result<(), SharedString>> {
215    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
216    if provider_name.is_empty() {
217        return Task::ready(Err("Provider Name cannot be empty".into()));
218    }
219
220    if LanguageModelRegistry::read_global(cx)
221        .providers()
222        .iter()
223        .any(|provider| {
224            provider.id().0.as_ref() == provider_name.as_ref()
225                || provider.name().0.as_ref() == provider_name.as_ref()
226        })
227    {
228        return Task::ready(Err(
229            "Provider Name is already taken by another provider".into()
230        ));
231    }
232
233    let api_url = input.api_url.read(cx).text(cx);
234    if api_url.is_empty() {
235        return Task::ready(Err("API URL cannot be empty".into()));
236    }
237
238    let api_key = input.api_key.read(cx).text(cx);
239    if api_key.is_empty() {
240        return Task::ready(Err("API Key cannot be empty".into()));
241    }
242
243    let mut models = Vec::new();
244    let mut model_names: HashSet<String> = HashSet::default();
245    for model in &input.models {
246        match model.parse(cx) {
247            Ok(model) => {
248                if !model_names.insert(model.name.clone()) {
249                    return Task::ready(Err("Model Names must be unique".into()));
250                }
251                models.push(model)
252            }
253            Err(err) => return Task::ready(Err(err)),
254        }
255    }
256
257    let fs = <dyn Fs>::global(cx);
258    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
259    cx.spawn(async move |cx| {
260        task.await
261            .map_err(|_| "Failed to write API key to keychain")?;
262        cx.update(|cx| {
263            update_settings_file(fs, cx, |settings, _cx| {
264                settings
265                    .language_models
266                    .get_or_insert_default()
267                    .openai_compatible
268                    .get_or_insert_default()
269                    .insert(
270                        provider_name,
271                        OpenAiCompatibleSettingsContent {
272                            api_url,
273                            available_models: models,
274                        },
275                    );
276            });
277        })
278        .ok();
279        Ok(())
280    })
281}
282
283pub struct AddLlmProviderModal {
284    provider: LlmCompatibleProvider,
285    input: AddLlmProviderInput,
286    scroll_handle: ScrollHandle,
287    focus_handle: FocusHandle,
288    last_error: Option<SharedString>,
289}
290
291impl AddLlmProviderModal {
292    pub fn toggle(
293        provider: LlmCompatibleProvider,
294        workspace: &mut Workspace,
295        window: &mut Window,
296        cx: &mut Context<Workspace>,
297    ) {
298        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
299    }
300
301    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
302        Self {
303            input: AddLlmProviderInput::new(provider, window, cx),
304            provider,
305            last_error: None,
306            focus_handle: cx.focus_handle(),
307            scroll_handle: ScrollHandle::new(),
308        }
309    }
310
311    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
312        let task = save_provider_to_settings(&self.input, cx);
313        cx.spawn(async move |this, cx| {
314            let result = task.await;
315            this.update(cx, |this, cx| match result {
316                Ok(_) => {
317                    cx.emit(DismissEvent);
318                }
319                Err(error) => {
320                    this.last_error = Some(error);
321                    cx.notify();
322                }
323            })
324        })
325        .detach_and_log_err(cx);
326    }
327
328    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
329        cx.emit(DismissEvent);
330    }
331
332    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
333        v_flex()
334            .mt_1()
335            .gap_2()
336            .child(
337                h_flex()
338                    .justify_between()
339                    .child(Label::new("Models").size(LabelSize::Small))
340                    .child(
341                        Button::new("add-model", "Add Model")
342                            .icon(IconName::Plus)
343                            .icon_position(IconPosition::Start)
344                            .icon_size(IconSize::XSmall)
345                            .icon_color(Color::Muted)
346                            .label_size(LabelSize::Small)
347                            .on_click(cx.listener(|this, _, window, cx| {
348                                this.input.add_model(window, cx);
349                                cx.notify();
350                            })),
351                    ),
352            )
353            .children(
354                self.input
355                    .models
356                    .iter()
357                    .enumerate()
358                    .map(|(ix, _)| self.render_model(ix, cx)),
359            )
360    }
361
362    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
363        let has_more_than_one_model = self.input.models.len() > 1;
364        let model = &self.input.models[ix];
365
366        v_flex()
367            .p_2()
368            .gap_2()
369            .rounded_sm()
370            .border_1()
371            .border_dashed()
372            .border_color(cx.theme().colors().border.opacity(0.6))
373            .bg(cx.theme().colors().element_active.opacity(0.15))
374            .child(model.name.clone())
375            .child(
376                h_flex()
377                    .gap_2()
378                    .child(model.max_completion_tokens.clone())
379                    .child(model.max_output_tokens.clone()),
380            )
381            .child(model.max_tokens.clone())
382            .child(
383                v_flex()
384                    .gap_1()
385                    .child(
386                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
387                            .label("Supports tools")
388                            .on_click(cx.listener(move |this, checked, _window, cx| {
389                                this.input.models[ix].capabilities.supports_tools = *checked;
390                                cx.notify();
391                            })),
392                    )
393                    .child(
394                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
395                            .label("Supports images")
396                            .on_click(cx.listener(move |this, checked, _window, cx| {
397                                this.input.models[ix].capabilities.supports_images = *checked;
398                                cx.notify();
399                            })),
400                    )
401                    .child(
402                        Checkbox::new(
403                            ("supports-parallel-tool-calls", ix),
404                            model.capabilities.supports_parallel_tool_calls,
405                        )
406                        .label("Supports parallel_tool_calls")
407                        .on_click(cx.listener(
408                            move |this, checked, _window, cx| {
409                                this.input.models[ix]
410                                    .capabilities
411                                    .supports_parallel_tool_calls = *checked;
412                                cx.notify();
413                            },
414                        )),
415                    )
416                    .child(
417                        Checkbox::new(
418                            ("supports-prompt-cache-key", ix),
419                            model.capabilities.supports_prompt_cache_key,
420                        )
421                        .label("Supports prompt_cache_key")
422                        .on_click(cx.listener(
423                            move |this, checked, _window, cx| {
424                                this.input.models[ix].capabilities.supports_prompt_cache_key =
425                                    *checked;
426                                cx.notify();
427                            },
428                        )),
429                    ),
430            )
431            .when(has_more_than_one_model, |this| {
432                this.child(
433                    Button::new(("remove-model", ix), "Remove Model")
434                        .icon(IconName::Trash)
435                        .icon_position(IconPosition::Start)
436                        .icon_size(IconSize::XSmall)
437                        .icon_color(Color::Muted)
438                        .label_size(LabelSize::Small)
439                        .style(ButtonStyle::Outlined)
440                        .full_width()
441                        .on_click(cx.listener(move |this, _, _window, cx| {
442                            this.input.remove_model(ix);
443                            cx.notify();
444                        })),
445                )
446            })
447    }
448
449    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, _: &mut Context<Self>) {
450        window.focus_next();
451    }
452
453    fn on_tab_prev(
454        &mut self,
455        _: &menu::SelectPrevious,
456        window: &mut Window,
457        _: &mut Context<Self>,
458    ) {
459        window.focus_prev();
460    }
461}
462
463impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
464
465impl Focusable for AddLlmProviderModal {
466    fn focus_handle(&self, _cx: &App) -> FocusHandle {
467        self.focus_handle.clone()
468    }
469}
470
471impl ModalView for AddLlmProviderModal {}
472
473impl Render for AddLlmProviderModal {
474    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
475        let focus_handle = self.focus_handle(cx);
476
477        let window_size = window.viewport_size();
478        let rem_size = window.rem_size();
479        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
480
481        let modal_max_height = if is_large_window {
482            rems_from_px(450.)
483        } else {
484            rems_from_px(200.)
485        };
486
487        v_flex()
488            .id("add-llm-provider-modal")
489            .key_context("AddLlmProviderModal")
490            .w(rems(34.))
491            .elevation_3(cx)
492            .on_action(cx.listener(Self::cancel))
493            .on_action(cx.listener(Self::on_tab))
494            .on_action(cx.listener(Self::on_tab_prev))
495            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
496                this.focus_handle(cx).focus(window);
497            }))
498            .child(
499                Modal::new("configure-context-server", None)
500                    .header(ModalHeader::new().headline("Add LLM Provider").description(
501                        match self.provider {
502                            LlmCompatibleProvider::OpenAi => {
503                                "This provider will use an OpenAI compatible API."
504                            }
505                        },
506                    ))
507                    .when_some(self.last_error.clone(), |this, error| {
508                        this.section(
509                            Section::new().child(
510                                Banner::new()
511                                    .severity(Severity::Warning)
512                                    .child(div().text_xs().child(error)),
513                            ),
514                        )
515                    })
516                    .child(
517                        div()
518                            .size_full()
519                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
520                            .child(
521                                v_flex()
522                                    .id("modal_content")
523                                    .size_full()
524                                    .tab_group()
525                                    .max_h(modal_max_height)
526                                    .pl_3()
527                                    .pr_4()
528                                    .gap_2()
529                                    .overflow_y_scroll()
530                                    .track_scroll(&self.scroll_handle)
531                                    .child(self.input.provider_name.clone())
532                                    .child(self.input.api_url.clone())
533                                    .child(self.input.api_key.clone())
534                                    .child(self.render_model_section(cx)),
535                            ),
536                    )
537                    .footer(
538                        ModalFooter::new().end_slot(
539                            h_flex()
540                                .gap_1()
541                                .child(
542                                    Button::new("cancel", "Cancel")
543                                        .key_binding(
544                                            KeyBinding::for_action_in(
545                                                &menu::Cancel,
546                                                &focus_handle,
547                                                cx,
548                                            )
549                                            .map(|kb| kb.size(rems_from_px(12.))),
550                                        )
551                                        .on_click(cx.listener(|this, _event, window, cx| {
552                                            this.cancel(&menu::Cancel, window, cx)
553                                        })),
554                                )
555                                .child(
556                                    Button::new("save-server", "Save Provider")
557                                        .key_binding(
558                                            KeyBinding::for_action_in(
559                                                &menu::Confirm,
560                                                &focus_handle,
561                                                cx,
562                                            )
563                                            .map(|kb| kb.size(rems_from_px(12.))),
564                                        )
565                                        .on_click(cx.listener(|this, _event, window, cx| {
566                                            this.confirm(&menu::Confirm, window, cx)
567                                        })),
568                                ),
569                        ),
570                    ),
571            )
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    /// One model row: (name, max_tokens, max_completion_tokens, max_output_tokens).
    type ModelRow<'a> = (&'a str, &'a str, &'a str, &'a str);

    /// A model row that passes every validation check.
    const VALID_MODEL: ModelRow<'static> = ("somemodel", "200000", "200000", "32000");

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // (provider name, api url, api key, models, expected error)
        let cases: Vec<(&str, &str, &str, Vec<ModelRow<'static>>, &'static str)> = vec![
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![VALID_MODEL, VALID_MODEL],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            let error =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await;
            assert_eq!(
                error,
                Some(SharedString::from(expected)),
                "unexpected result for case {expected:?}"
            );
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let fake_provider = Arc::new(FakeLanguageModelProvider::new(
                LanguageModelProviderId::new("someprovider"),
                LanguageModelProviderName::new("Some Provider"),
            ));
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(fake_provider, cx);
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![VALID_MODEL],
            cx,
        )
        .await;
        assert_eq!(
            error,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = named_model_input("somemodel", window, cx);
            let toggles = &model_input.capabilities;
            assert_eq!(toggles.supports_tools, ToggleState::Selected);
            assert_eq!(toggles.supports_images, ToggleState::Unselected);
            assert_eq!(toggles.supports_parallel_tool_calls, ToggleState::Unselected);
            assert_eq!(toggles.supports_prompt_cache_key, ToggleState::Unselected);

            let parsed = model_input.parse(cx).unwrap();
            assert!(parsed.capabilities.tools);
            assert!(!parsed.capabilities.images);
            assert!(!parsed.capabilities.parallel_tool_calls);
            assert!(!parsed.capabilities.prompt_cache_key);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input("somemodel", window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Unselected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Unselected,
                supports_prompt_cache_key: ToggleState::Unselected,
            };

            let parsed = model_input.parse(cx).unwrap();
            assert!(!parsed.capabilities.tools);
            assert!(!parsed.capabilities.images);
            assert!(!parsed.capabilities.parallel_tool_calls);
            assert!(!parsed.capabilities.prompt_cache_key);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input("somemodel", window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Selected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Selected,
                supports_prompt_cache_key: ToggleState::Unselected,
            };

            let parsed = model_input.parse(cx).unwrap();
            assert_eq!(parsed.name, "somemodel");
            assert!(parsed.capabilities.tools);
            assert!(!parsed.capabilities.images);
            assert!(parsed.capabilities.parallel_tool_calls);
            assert!(!parsed.capabilities.prompt_cache_key);
        });
    }

    /// Installs test settings, theme and fake filesystem.
    ///
    /// Returns the visual context of a fresh workspace window.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
        });

        let fake_fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fake_fs.clone(), cx));
        let project = Project::test(fake_fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    /// Replaces the editor text of `input` with `text`.
    fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
        input.update(cx, |input, cx| {
            input.editor().update(cx, |editor, cx| {
                editor.set_text(text, window, cx);
            });
        });
    }

    /// Creates model input #0 with its name field set to `name`.
    fn named_model_input(name: &str, window: &mut Window, cx: &mut App) -> ModelInput {
        let model_input = ModelInput::new(0, window, cx);
        set_text(&model_input.name, name, window, cx);
        model_input
    }

    /// Fills a fresh form with the given values and runs `save_provider_to_settings`.
    ///
    /// Returns the error it produced, if any.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<ModelRow<'_>>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        let save_task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (ix, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with a single model; create extra entries on demand.
                if input.models.len() <= ix {
                    input.models.push(ModelInput::new(ix, window, cx));
                }
                let model = &input.models[ix];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(&model.max_completion_tokens, max_completion_tokens, window, cx);
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&input, cx)
        });

        save_task.await.err()
    }
}