add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: &str,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input.set_text(text, window, cx);
 35        }
 36        input
 37    })
 38}
 39
 40#[derive(Clone, Copy)]
 41pub enum LlmCompatibleProvider {
 42    OpenAi,
 43}
 44
 45impl LlmCompatibleProvider {
 46    fn name(&self) -> &'static str {
 47        match self {
 48            LlmCompatibleProvider::OpenAi => "OpenAI",
 49        }
 50    }
 51
 52    fn api_url(&self) -> &'static str {
 53        match self {
 54            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 55        }
 56    }
 57}
 58
 59struct AddLlmProviderInput {
 60    provider_name: Entity<InputField>,
 61    api_url: Entity<InputField>,
 62    api_key: Entity<InputField>,
 63    models: Vec<ModelInput>,
 64}
 65
 66impl AddLlmProviderInput {
 67    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 68        let provider_name =
 69            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 70        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 71        let api_key = single_line_input(
 72            "API Key",
 73            "000000000000000000000000000000000000000000000000",
 74            None,
 75            3,
 76            window,
 77            cx,
 78        );
 79
 80        Self {
 81            provider_name,
 82            api_url,
 83            api_key,
 84            models: vec![ModelInput::new(0, window, cx)],
 85        }
 86    }
 87
 88    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 89        let model_index = self.models.len();
 90        self.models.push(ModelInput::new(model_index, window, cx));
 91    }
 92
 93    fn remove_model(&mut self, index: usize) {
 94        self.models.remove(index);
 95    }
 96}
 97
 98struct ModelCapabilityToggles {
 99    pub supports_tools: ToggleState,
100    pub supports_images: ToggleState,
101    pub supports_parallel_tool_calls: ToggleState,
102    pub supports_prompt_cache_key: ToggleState,
103    pub supports_chat_completions: ToggleState,
104}
105
106struct ModelInput {
107    name: Entity<InputField>,
108    max_completion_tokens: Entity<InputField>,
109    max_output_tokens: Entity<InputField>,
110    max_tokens: Entity<InputField>,
111    capabilities: ModelCapabilityToggles,
112}
113
114impl ModelInput {
115    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
116        let base_tab_index = (3 + (model_index * 4)) as isize;
117
118        let model_name = single_line_input(
119            "Model Name",
120            "e.g. gpt-5, claude-opus-4, gemini-2.5-pro",
121            None,
122            base_tab_index + 1,
123            window,
124            cx,
125        );
126        let max_completion_tokens = single_line_input(
127            "Max Completion Tokens",
128            "200000",
129            Some("200000"),
130            base_tab_index + 2,
131            window,
132            cx,
133        );
134        let max_output_tokens = single_line_input(
135            "Max Output Tokens",
136            "Max Output Tokens",
137            Some("32000"),
138            base_tab_index + 3,
139            window,
140            cx,
141        );
142        let max_tokens = single_line_input(
143            "Max Tokens",
144            "Max Tokens",
145            Some("200000"),
146            base_tab_index + 4,
147            window,
148            cx,
149        );
150
151        let ModelCapabilities {
152            tools,
153            images,
154            parallel_tool_calls,
155            prompt_cache_key,
156            chat_completions,
157        } = ModelCapabilities::default();
158
159        Self {
160            name: model_name,
161            max_completion_tokens,
162            max_output_tokens,
163            max_tokens,
164            capabilities: ModelCapabilityToggles {
165                supports_tools: tools.into(),
166                supports_images: images.into(),
167                supports_parallel_tool_calls: parallel_tool_calls.into(),
168                supports_prompt_cache_key: prompt_cache_key.into(),
169                supports_chat_completions: chat_completions.into(),
170            },
171        }
172    }
173
174    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
175        let name = self.name.read(cx).text(cx);
176        if name.is_empty() {
177            return Err(SharedString::from("Model Name cannot be empty"));
178        }
179        Ok(AvailableModel {
180            name,
181            display_name: None,
182            max_completion_tokens: Some(
183                self.max_completion_tokens
184                    .read(cx)
185                    .text(cx)
186                    .parse::<u64>()
187                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
188            ),
189            max_output_tokens: Some(
190                self.max_output_tokens
191                    .read(cx)
192                    .text(cx)
193                    .parse::<u64>()
194                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
195            ),
196            max_tokens: self
197                .max_tokens
198                .read(cx)
199                .text(cx)
200                .parse::<u64>()
201                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
202            capabilities: ModelCapabilities {
203                tools: self.capabilities.supports_tools.selected(),
204                images: self.capabilities.supports_images.selected(),
205                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
206                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
207                chat_completions: self.capabilities.supports_chat_completions.selected(),
208            },
209        })
210    }
211}
212
213fn save_provider_to_settings(
214    input: &AddLlmProviderInput,
215    cx: &mut App,
216) -> Task<Result<(), SharedString>> {
217    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
218    if provider_name.is_empty() {
219        return Task::ready(Err("Provider Name cannot be empty".into()));
220    }
221
222    if LanguageModelRegistry::read_global(cx)
223        .providers()
224        .iter()
225        .any(|provider| {
226            provider.id().0.as_ref() == provider_name.as_ref()
227                || provider.name().0.as_ref() == provider_name.as_ref()
228        })
229    {
230        return Task::ready(Err(
231            "Provider Name is already taken by another provider".into()
232        ));
233    }
234
235    let api_url = input.api_url.read(cx).text(cx);
236    if api_url.is_empty() {
237        return Task::ready(Err("API URL cannot be empty".into()));
238    }
239
240    let api_key = input.api_key.read(cx).text(cx);
241    if api_key.is_empty() {
242        return Task::ready(Err("API Key cannot be empty".into()));
243    }
244
245    let mut models = Vec::new();
246    let mut model_names: HashSet<String> = HashSet::default();
247    for model in &input.models {
248        match model.parse(cx) {
249            Ok(model) => {
250                if !model_names.insert(model.name.clone()) {
251                    return Task::ready(Err("Model Names must be unique".into()));
252                }
253                models.push(model)
254            }
255            Err(err) => return Task::ready(Err(err)),
256        }
257    }
258
259    let fs = <dyn Fs>::global(cx);
260    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
261    cx.spawn(async move |cx| {
262        task.await
263            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
264        cx.update(|cx| {
265            update_settings_file(fs, cx, |settings, _cx| {
266                settings
267                    .language_models
268                    .get_or_insert_default()
269                    .openai_compatible
270                    .get_or_insert_default()
271                    .insert(
272                        provider_name,
273                        OpenAiCompatibleSettingsContent {
274                            api_url,
275                            available_models: models,
276                        },
277                    );
278            });
279        });
280        Ok(())
281    })
282}
283
284pub struct AddLlmProviderModal {
285    provider: LlmCompatibleProvider,
286    input: AddLlmProviderInput,
287    scroll_handle: ScrollHandle,
288    focus_handle: FocusHandle,
289    last_error: Option<SharedString>,
290}
291
292impl AddLlmProviderModal {
293    pub fn toggle(
294        provider: LlmCompatibleProvider,
295        workspace: &mut Workspace,
296        window: &mut Window,
297        cx: &mut Context<Workspace>,
298    ) {
299        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
300    }
301
302    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
303        Self {
304            input: AddLlmProviderInput::new(provider, window, cx),
305            provider,
306            last_error: None,
307            focus_handle: cx.focus_handle(),
308            scroll_handle: ScrollHandle::new(),
309        }
310    }
311
312    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
313        let task = save_provider_to_settings(&self.input, cx);
314        cx.spawn(async move |this, cx| {
315            let result = task.await;
316            this.update(cx, |this, cx| match result {
317                Ok(_) => {
318                    cx.emit(DismissEvent);
319                }
320                Err(error) => {
321                    this.last_error = Some(error);
322                    cx.notify();
323                }
324            })
325        })
326        .detach_and_log_err(cx);
327    }
328
329    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
330        cx.emit(DismissEvent);
331    }
332
333    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
334        v_flex()
335            .mt_1()
336            .gap_2()
337            .child(
338                h_flex()
339                    .justify_between()
340                    .child(Label::new("Models").size(LabelSize::Small))
341                    .child(
342                        Button::new("add-model", "Add Model")
343                            .start_icon(
344                                Icon::new(IconName::Plus)
345                                    .size(IconSize::XSmall)
346                                    .color(Color::Muted),
347                            )
348                            .label_size(LabelSize::Small)
349                            .on_click(cx.listener(|this, _, window, cx| {
350                                this.input.add_model(window, cx);
351                                cx.notify();
352                            })),
353                    ),
354            )
355            .children(
356                self.input
357                    .models
358                    .iter()
359                    .enumerate()
360                    .map(|(ix, _)| self.render_model(ix, cx)),
361            )
362    }
363
364    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
365        let has_more_than_one_model = self.input.models.len() > 1;
366        let model = &self.input.models[ix];
367
368        v_flex()
369            .p_2()
370            .gap_2()
371            .rounded_sm()
372            .border_1()
373            .border_dashed()
374            .border_color(cx.theme().colors().border.opacity(0.6))
375            .bg(cx.theme().colors().element_active.opacity(0.15))
376            .child(model.name.clone())
377            .child(
378                h_flex()
379                    .gap_2()
380                    .child(model.max_completion_tokens.clone())
381                    .child(model.max_output_tokens.clone()),
382            )
383            .child(model.max_tokens.clone())
384            .child(
385                v_flex()
386                    .gap_1()
387                    .child(
388                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
389                            .label("Supports tools")
390                            .on_click(cx.listener(move |this, checked, _window, cx| {
391                                this.input.models[ix].capabilities.supports_tools = *checked;
392                                cx.notify();
393                            })),
394                    )
395                    .child(
396                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
397                            .label("Supports images")
398                            .on_click(cx.listener(move |this, checked, _window, cx| {
399                                this.input.models[ix].capabilities.supports_images = *checked;
400                                cx.notify();
401                            })),
402                    )
403                    .child(
404                        Checkbox::new(
405                            ("supports-parallel-tool-calls", ix),
406                            model.capabilities.supports_parallel_tool_calls,
407                        )
408                        .label("Supports parallel_tool_calls")
409                        .on_click(cx.listener(
410                            move |this, checked, _window, cx| {
411                                this.input.models[ix]
412                                    .capabilities
413                                    .supports_parallel_tool_calls = *checked;
414                                cx.notify();
415                            },
416                        )),
417                    )
418                    .child(
419                        Checkbox::new(
420                            ("supports-prompt-cache-key", ix),
421                            model.capabilities.supports_prompt_cache_key,
422                        )
423                        .label("Supports prompt_cache_key")
424                        .on_click(cx.listener(
425                            move |this, checked, _window, cx| {
426                                this.input.models[ix].capabilities.supports_prompt_cache_key =
427                                    *checked;
428                                cx.notify();
429                            },
430                        )),
431                    )
432                    .child(
433                        Checkbox::new(
434                            ("supports-chat-completions", ix),
435                            model.capabilities.supports_chat_completions,
436                        )
437                        .label("Supports /chat/completions")
438                        .on_click(cx.listener(
439                            move |this, checked, _window, cx| {
440                                this.input.models[ix].capabilities.supports_chat_completions =
441                                    *checked;
442                                cx.notify();
443                            },
444                        )),
445                    ),
446            )
447            .when(has_more_than_one_model, |this| {
448                this.child(
449                    Button::new(("remove-model", ix), "Remove Model")
450                        .start_icon(
451                            Icon::new(IconName::Trash)
452                                .size(IconSize::XSmall)
453                                .color(Color::Muted),
454                        )
455                        .label_size(LabelSize::Small)
456                        .style(ButtonStyle::Outlined)
457                        .full_width()
458                        .on_click(cx.listener(move |this, _, _window, cx| {
459                            this.input.remove_model(ix);
460                            cx.notify();
461                        })),
462                )
463            })
464    }
465
466    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
467        window.focus_next(cx);
468    }
469
470    fn on_tab_prev(
471        &mut self,
472        _: &menu::SelectPrevious,
473        window: &mut Window,
474        cx: &mut Context<Self>,
475    ) {
476        window.focus_prev(cx);
477    }
478}
479
480impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
481
482impl Focusable for AddLlmProviderModal {
483    fn focus_handle(&self, _cx: &App) -> FocusHandle {
484        self.focus_handle.clone()
485    }
486}
487
488impl ModalView for AddLlmProviderModal {}
489
490impl Render for AddLlmProviderModal {
491    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
492        let focus_handle = self.focus_handle(cx);
493
494        let window_size = window.viewport_size();
495        let rem_size = window.rem_size();
496        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
497
498        let modal_max_height = if is_large_window {
499            rems_from_px(450.)
500        } else {
501            rems_from_px(200.)
502        };
503
504        v_flex()
505            .id("add-llm-provider-modal")
506            .key_context("AddLlmProviderModal")
507            .w(rems(34.))
508            .elevation_3(cx)
509            .on_action(cx.listener(Self::cancel))
510            .on_action(cx.listener(Self::on_tab))
511            .on_action(cx.listener(Self::on_tab_prev))
512            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
513                this.focus_handle(cx).focus(window, cx);
514            }))
515            .child(
516                Modal::new("configure-context-server", None)
517                    .header(ModalHeader::new().headline("Add LLM Provider").description(
518                        match self.provider {
519                            LlmCompatibleProvider::OpenAi => {
520                                "This provider will use an OpenAI compatible API."
521                            }
522                        },
523                    ))
524                    .when_some(self.last_error.clone(), |this, error| {
525                        this.section(
526                            Section::new().child(
527                                Banner::new()
528                                    .severity(Severity::Warning)
529                                    .child(div().text_xs().child(error)),
530                            ),
531                        )
532                    })
533                    .child(
534                        div()
535                            .size_full()
536                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
537                            .child(
538                                v_flex()
539                                    .id("modal_content")
540                                    .size_full()
541                                    .tab_group()
542                                    .max_h(modal_max_height)
543                                    .pl_3()
544                                    .pr_4()
545                                    .pb_2()
546                                    .gap_2()
547                                    .overflow_y_scroll()
548                                    .track_scroll(&self.scroll_handle)
549                                    .child(self.input.provider_name.clone())
550                                    .child(self.input.api_url.clone())
551                                    .child(self.input.api_key.clone())
552                                    .child(self.render_model_section(cx)),
553                            ),
554                    )
555                    .footer(
556                        ModalFooter::new().end_slot(
557                            h_flex()
558                                .gap_1()
559                                .child(
560                                    Button::new("cancel", "Cancel")
561                                        .key_binding(
562                                            KeyBinding::for_action_in(
563                                                &menu::Cancel,
564                                                &focus_handle,
565                                                cx,
566                                            )
567                                            .map(|kb| kb.size(rems_from_px(12.))),
568                                        )
569                                        .on_click(cx.listener(|this, _event, window, cx| {
570                                            this.cancel(&menu::Cancel, window, cx)
571                                        })),
572                                )
573                                .child(
574                                    Button::new("save-server", "Save Provider")
575                                        .key_binding(
576                                            KeyBinding::for_action_in(
577                                                &menu::Confirm,
578                                                &focus_handle,
579                                                cx,
580                                            )
581                                            .map(|kb| kb.size(rems_from_px(12.))),
582                                        )
583                                        .on_click(cx.listener(|this, _event, window, cx| {
584                                            this.confirm(&menu::Confirm, window, cx)
585                                        })),
586                                ),
587                        ),
588                    ),
589            )
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        const VALID_MODEL: (&str, &str, &str, &str) = ("somemodel", "200000", "200000", "32000");

        // (provider name, api url, api key,
        //  models as (name, max_tokens, max_completion_tokens, max_output_tokens),
        //  expected error)
        let cases = [
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![VALID_MODEL, VALID_MODEL],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            assert_eq!(
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await,
                Some(expected.into()),
                "expected validation error {expected:?}",
            );
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let fake_provider = Arc::new(FakeLanguageModelProvider::new(
                LanguageModelProviderId::new("someprovider"),
                LanguageModelProviderName::new("Some Provider"),
            ));
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(fake_provider, cx);
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![("somemodel", "200000", "200000", "32000")],
            cx,
        )
        .await;
        assert_eq!(
            error,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = model_input_named_somemodel(window, cx);

            let toggles = &model_input.capabilities;
            assert_eq!(toggles.supports_tools, ToggleState::Selected);
            assert_eq!(toggles.supports_images, ToggleState::Unselected);
            assert_eq!(toggles.supports_parallel_tool_calls, ToggleState::Unselected);
            assert_eq!(toggles.supports_prompt_cache_key, ToggleState::Unselected);
            assert_eq!(toggles.supports_chat_completions, ToggleState::Selected);

            let parsed = model_input.parse(cx).unwrap().capabilities;
            assert!(parsed.tools);
            assert!(!parsed.images);
            assert!(!parsed.parallel_tool_calls);
            assert!(!parsed.prompt_cache_key);
            assert!(parsed.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = model_input_named_somemodel(window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Unselected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Unselected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Unselected,
            };

            let parsed = model_input.parse(cx).unwrap().capabilities;
            assert!(!parsed.tools);
            assert!(!parsed.images);
            assert!(!parsed.parallel_tool_calls);
            assert!(!parsed.prompt_cache_key);
            assert!(!parsed.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = model_input_named_somemodel(window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Selected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Selected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Selected,
            };

            let parsed = model_input.parse(cx).unwrap();
            assert_eq!(parsed.name, "somemodel");
            let caps = parsed.capabilities;
            assert!(caps.tools);
            assert!(!caps.images);
            assert!(caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(caps.chat_completions);
        });
    }

    /// Creates the first model entry of a form with its name set to "somemodel".
    fn model_input_named_somemodel(window: &mut Window, cx: &mut App) -> ModelInput {
        let model_input = ModelInput::new(0, window, cx);
        model_input.name.update(cx, |field, cx| {
            field.set_text("somemodel", window, cx);
        });
        model_input
    }

    /// Installs test settings, theme, fake fs and a workspace window; returns its context.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
            editor::init(cx);
        });

        let fake_fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fake_fs.clone(), cx));
        let project = Project::test(fake_fs, [path!("/dir").as_ref()], cx).await;
        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let _workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        cx
    }

    /// Fills a fresh form with the given values, runs `save_provider_to_settings` and returns
    /// the validation error, if any.
    ///
    /// Each model tuple is `(name, max_tokens, max_completion_tokens, max_output_tokens)`.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn fill(field: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            field.update(cx, |field, cx| {
                field.set_text(text, window, cx);
            });
        }

        let save = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            fill(&input.provider_name, provider_name, window, cx);
            fill(&input.api_url, api_url, window, cx);
            fill(&input.api_key, api_key, window, cx);

            for (ix, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one model; grow it to match the requested count.
                if ix >= input.models.len() {
                    input.models.push(ModelInput::new(ix, window, cx));
                }
                let model = &input.models[ix];
                fill(&model.name, name, window, cx);
                fill(&model.max_tokens, max_tokens, window, cx);
                fill(&model.max_completion_tokens, max_completion_tokens, window, cx);
                fill(&model.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&input, cx)
        });

        save.await.err()
    }
}