add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: &str,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input.set_text(text, window, cx);
 35        }
 36        input
 37    })
 38}
 39
 40#[derive(Clone, Copy)]
 41pub enum LlmCompatibleProvider {
 42    OpenAi,
 43}
 44
 45impl LlmCompatibleProvider {
 46    fn name(&self) -> &'static str {
 47        match self {
 48            LlmCompatibleProvider::OpenAi => "OpenAI",
 49        }
 50    }
 51
 52    fn api_url(&self) -> &'static str {
 53        match self {
 54            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 55        }
 56    }
 57}
 58
 59struct AddLlmProviderInput {
 60    provider_name: Entity<InputField>,
 61    api_url: Entity<InputField>,
 62    api_key: Entity<InputField>,
 63    models: Vec<ModelInput>,
 64}
 65
 66impl AddLlmProviderInput {
 67    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 68        let provider_name =
 69            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 70        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 71        let api_key = cx.new(|cx| {
 72            InputField::new(
 73                window,
 74                cx,
 75                "000000000000000000000000000000000000000000000000",
 76            )
 77            .label("API Key")
 78            .tab_index(3)
 79            .tab_stop(true)
 80            .masked(true)
 81        });
 82
 83        Self {
 84            provider_name,
 85            api_url,
 86            api_key,
 87            models: vec![ModelInput::new(0, window, cx)],
 88        }
 89    }
 90
 91    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 92        let model_index = self.models.len();
 93        self.models.push(ModelInput::new(model_index, window, cx));
 94    }
 95
 96    fn remove_model(&mut self, index: usize) {
 97        self.models.remove(index);
 98    }
 99}
100
101struct ModelCapabilityToggles {
102    pub supports_tools: ToggleState,
103    pub supports_images: ToggleState,
104    pub supports_parallel_tool_calls: ToggleState,
105    pub supports_prompt_cache_key: ToggleState,
106    pub supports_chat_completions: ToggleState,
107}
108
109struct ModelInput {
110    name: Entity<InputField>,
111    max_completion_tokens: Entity<InputField>,
112    max_output_tokens: Entity<InputField>,
113    max_tokens: Entity<InputField>,
114    capabilities: ModelCapabilityToggles,
115}
116
117impl ModelInput {
118    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
119        let base_tab_index = (3 + (model_index * 4)) as isize;
120
121        let model_name = single_line_input(
122            "Model Name",
123            "e.g. gpt-5, claude-opus-4, gemini-2.5-pro",
124            None,
125            base_tab_index + 1,
126            window,
127            cx,
128        );
129        let max_completion_tokens = single_line_input(
130            "Max Completion Tokens",
131            "200000",
132            Some("200000"),
133            base_tab_index + 2,
134            window,
135            cx,
136        );
137        let max_output_tokens = single_line_input(
138            "Max Output Tokens",
139            "Max Output Tokens",
140            Some("32000"),
141            base_tab_index + 3,
142            window,
143            cx,
144        );
145        let max_tokens = single_line_input(
146            "Max Tokens",
147            "Max Tokens",
148            Some("200000"),
149            base_tab_index + 4,
150            window,
151            cx,
152        );
153
154        let ModelCapabilities {
155            tools,
156            images,
157            parallel_tool_calls,
158            prompt_cache_key,
159            chat_completions,
160            ..
161        } = ModelCapabilities::default();
162
163        Self {
164            name: model_name,
165            max_completion_tokens,
166            max_output_tokens,
167            max_tokens,
168            capabilities: ModelCapabilityToggles {
169                supports_tools: tools.into(),
170                supports_images: images.into(),
171                supports_parallel_tool_calls: parallel_tool_calls.into(),
172                supports_prompt_cache_key: prompt_cache_key.into(),
173                supports_chat_completions: chat_completions.into(),
174            },
175        }
176    }
177
178    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
179        let name = self.name.read(cx).text(cx);
180        if name.is_empty() {
181            return Err(SharedString::from("Model Name cannot be empty"));
182        }
183        Ok(AvailableModel {
184            name,
185            display_name: None,
186            max_completion_tokens: Some(
187                self.max_completion_tokens
188                    .read(cx)
189                    .text(cx)
190                    .parse::<u64>()
191                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
192            ),
193            max_output_tokens: Some(
194                self.max_output_tokens
195                    .read(cx)
196                    .text(cx)
197                    .parse::<u64>()
198                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
199            ),
200            max_tokens: self
201                .max_tokens
202                .read(cx)
203                .text(cx)
204                .parse::<u64>()
205                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
206            reasoning_effort: None,
207            capabilities: ModelCapabilities {
208                tools: self.capabilities.supports_tools.selected(),
209                images: self.capabilities.supports_images.selected(),
210                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
211                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
212                chat_completions: self.capabilities.supports_chat_completions.selected(),
213                interleaved_reasoning: false,
214            },
215        })
216    }
217}
218
219fn save_provider_to_settings(
220    input: &AddLlmProviderInput,
221    cx: &mut App,
222) -> Task<Result<(), SharedString>> {
223    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
224    if provider_name.is_empty() {
225        return Task::ready(Err("Provider Name cannot be empty".into()));
226    }
227
228    if LanguageModelRegistry::read_global(cx)
229        .providers()
230        .iter()
231        .any(|provider| {
232            provider.id().0.as_ref() == provider_name.as_ref()
233                || provider.name().0.as_ref() == provider_name.as_ref()
234        })
235    {
236        return Task::ready(Err(
237            "Provider Name is already taken by another provider".into()
238        ));
239    }
240
241    let api_url = input.api_url.read(cx).text(cx);
242    if api_url.is_empty() {
243        return Task::ready(Err("API URL cannot be empty".into()));
244    }
245
246    let api_key = input.api_key.read(cx).text(cx);
247    if api_key.is_empty() {
248        return Task::ready(Err("API Key cannot be empty".into()));
249    }
250
251    let mut models = Vec::new();
252    let mut model_names: HashSet<String> = HashSet::default();
253    for model in &input.models {
254        match model.parse(cx) {
255            Ok(model) => {
256                if !model_names.insert(model.name.clone()) {
257                    return Task::ready(Err("Model Names must be unique".into()));
258                }
259                models.push(model)
260            }
261            Err(err) => return Task::ready(Err(err)),
262        }
263    }
264
265    let fs = <dyn Fs>::global(cx);
266    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
267    cx.spawn(async move |cx| {
268        task.await
269            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
270        cx.update(|cx| {
271            update_settings_file(fs, cx, |settings, _cx| {
272                settings
273                    .language_models
274                    .get_or_insert_default()
275                    .openai_compatible
276                    .get_or_insert_default()
277                    .insert(
278                        provider_name,
279                        OpenAiCompatibleSettingsContent {
280                            api_url,
281                            available_models: models,
282                        },
283                    );
284            });
285        });
286        Ok(())
287    })
288}
289
290pub struct AddLlmProviderModal {
291    provider: LlmCompatibleProvider,
292    input: AddLlmProviderInput,
293    scroll_handle: ScrollHandle,
294    focus_handle: FocusHandle,
295    last_error: Option<SharedString>,
296}
297
298impl AddLlmProviderModal {
299    pub fn toggle(
300        provider: LlmCompatibleProvider,
301        workspace: &mut Workspace,
302        window: &mut Window,
303        cx: &mut Context<Workspace>,
304    ) {
305        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
306    }
307
308    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
309        Self {
310            input: AddLlmProviderInput::new(provider, window, cx),
311            provider,
312            last_error: None,
313            focus_handle: cx.focus_handle(),
314            scroll_handle: ScrollHandle::new(),
315        }
316    }
317
318    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
319        let task = save_provider_to_settings(&self.input, cx);
320        cx.spawn(async move |this, cx| {
321            let result = task.await;
322            this.update(cx, |this, cx| match result {
323                Ok(_) => {
324                    cx.emit(DismissEvent);
325                }
326                Err(error) => {
327                    this.last_error = Some(error);
328                    cx.notify();
329                }
330            })
331        })
332        .detach_and_log_err(cx);
333    }
334
335    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
336        cx.emit(DismissEvent);
337    }
338
339    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
340        v_flex()
341            .mt_1()
342            .gap_2()
343            .child(
344                h_flex()
345                    .justify_between()
346                    .child(Label::new("Models").size(LabelSize::Small))
347                    .child(
348                        Button::new("add-model", "Add Model")
349                            .start_icon(
350                                Icon::new(IconName::Plus)
351                                    .size(IconSize::XSmall)
352                                    .color(Color::Muted),
353                            )
354                            .label_size(LabelSize::Small)
355                            .on_click(cx.listener(|this, _, window, cx| {
356                                this.input.add_model(window, cx);
357                                cx.notify();
358                            })),
359                    ),
360            )
361            .children(
362                self.input
363                    .models
364                    .iter()
365                    .enumerate()
366                    .map(|(ix, _)| self.render_model(ix, cx)),
367            )
368    }
369
370    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
371        let has_more_than_one_model = self.input.models.len() > 1;
372        let model = &self.input.models[ix];
373
374        v_flex()
375            .p_2()
376            .gap_2()
377            .rounded_sm()
378            .border_1()
379            .border_dashed()
380            .border_color(cx.theme().colors().border.opacity(0.6))
381            .bg(cx.theme().colors().element_active.opacity(0.15))
382            .child(model.name.clone())
383            .child(
384                h_flex()
385                    .gap_2()
386                    .child(model.max_completion_tokens.clone())
387                    .child(model.max_output_tokens.clone()),
388            )
389            .child(model.max_tokens.clone())
390            .child(
391                v_flex()
392                    .gap_1()
393                    .child(
394                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
395                            .label("Supports tools")
396                            .on_click(cx.listener(move |this, checked, _window, cx| {
397                                this.input.models[ix].capabilities.supports_tools = *checked;
398                                cx.notify();
399                            })),
400                    )
401                    .child(
402                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
403                            .label("Supports images")
404                            .on_click(cx.listener(move |this, checked, _window, cx| {
405                                this.input.models[ix].capabilities.supports_images = *checked;
406                                cx.notify();
407                            })),
408                    )
409                    .child(
410                        Checkbox::new(
411                            ("supports-parallel-tool-calls", ix),
412                            model.capabilities.supports_parallel_tool_calls,
413                        )
414                        .label("Supports parallel_tool_calls")
415                        .on_click(cx.listener(
416                            move |this, checked, _window, cx| {
417                                this.input.models[ix]
418                                    .capabilities
419                                    .supports_parallel_tool_calls = *checked;
420                                cx.notify();
421                            },
422                        )),
423                    )
424                    .child(
425                        Checkbox::new(
426                            ("supports-prompt-cache-key", ix),
427                            model.capabilities.supports_prompt_cache_key,
428                        )
429                        .label("Supports prompt_cache_key")
430                        .on_click(cx.listener(
431                            move |this, checked, _window, cx| {
432                                this.input.models[ix].capabilities.supports_prompt_cache_key =
433                                    *checked;
434                                cx.notify();
435                            },
436                        )),
437                    )
438                    .child(
439                        Checkbox::new(
440                            ("supports-chat-completions", ix),
441                            model.capabilities.supports_chat_completions,
442                        )
443                        .label("Supports /chat/completions")
444                        .on_click(cx.listener(
445                            move |this, checked, _window, cx| {
446                                this.input.models[ix].capabilities.supports_chat_completions =
447                                    *checked;
448                                cx.notify();
449                            },
450                        )),
451                    ),
452            )
453            .when(has_more_than_one_model, |this| {
454                this.child(
455                    Button::new(("remove-model", ix), "Remove Model")
456                        .start_icon(
457                            Icon::new(IconName::Trash)
458                                .size(IconSize::XSmall)
459                                .color(Color::Muted),
460                        )
461                        .label_size(LabelSize::Small)
462                        .style(ButtonStyle::Outlined)
463                        .full_width()
464                        .on_click(cx.listener(move |this, _, _window, cx| {
465                            this.input.remove_model(ix);
466                            cx.notify();
467                        })),
468                )
469            })
470    }
471
472    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
473        window.focus_next(cx);
474    }
475
476    fn on_tab_prev(
477        &mut self,
478        _: &menu::SelectPrevious,
479        window: &mut Window,
480        cx: &mut Context<Self>,
481    ) {
482        window.focus_prev(cx);
483    }
484}
485
486impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
487
488impl Focusable for AddLlmProviderModal {
489    fn focus_handle(&self, _cx: &App) -> FocusHandle {
490        self.focus_handle.clone()
491    }
492}
493
494impl ModalView for AddLlmProviderModal {}
495
496impl Render for AddLlmProviderModal {
497    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
498        let focus_handle = self.focus_handle(cx);
499
500        let window_size = window.viewport_size();
501        let rem_size = window.rem_size();
502        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
503
504        let modal_max_height = if is_large_window {
505            rems_from_px(450.)
506        } else {
507            rems_from_px(200.)
508        };
509
510        v_flex()
511            .id("add-llm-provider-modal")
512            .key_context("AddLlmProviderModal")
513            .w(rems(34.))
514            .elevation_3(cx)
515            .on_action(cx.listener(Self::cancel))
516            .on_action(cx.listener(Self::on_tab))
517            .on_action(cx.listener(Self::on_tab_prev))
518            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
519                this.focus_handle(cx).focus(window, cx);
520            }))
521            .child(
522                Modal::new("configure-context-server", None)
523                    .header(ModalHeader::new().headline("Add LLM Provider").description(
524                        match self.provider {
525                            LlmCompatibleProvider::OpenAi => {
526                                "This provider will use an OpenAI compatible API."
527                            }
528                        },
529                    ))
530                    .when_some(self.last_error.clone(), |this, error| {
531                        this.section(
532                            Section::new().child(
533                                Banner::new()
534                                    .severity(Severity::Warning)
535                                    .child(div().text_xs().child(error)),
536                            ),
537                        )
538                    })
539                    .child(
540                        div()
541                            .size_full()
542                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
543                            .child(
544                                v_flex()
545                                    .id("modal_content")
546                                    .size_full()
547                                    .tab_group()
548                                    .max_h(modal_max_height)
549                                    .pl_3()
550                                    .pr_4()
551                                    .pb_2()
552                                    .gap_2()
553                                    .overflow_y_scroll()
554                                    .track_scroll(&self.scroll_handle)
555                                    .child(self.input.provider_name.clone())
556                                    .child(self.input.api_url.clone())
557                                    .child(self.input.api_key.clone())
558                                    .child(self.render_model_section(cx)),
559                            ),
560                    )
561                    .footer(
562                        ModalFooter::new().end_slot(
563                            h_flex()
564                                .gap_1()
565                                .child(
566                                    Button::new("cancel", "Cancel")
567                                        .key_binding(
568                                            KeyBinding::for_action_in(
569                                                &menu::Cancel,
570                                                &focus_handle,
571                                                cx,
572                                            )
573                                            .map(|kb| kb.size(rems_from_px(12.))),
574                                        )
575                                        .on_click(cx.listener(|this, _event, window, cx| {
576                                            this.cancel(&menu::Cancel, window, cx)
577                                        })),
578                                )
579                                .child(
580                                    Button::new("save-server", "Save Provider")
581                                        .key_binding(
582                                            KeyBinding::for_action_in(
583                                                &menu::Confirm,
584                                                &focus_handle,
585                                                cx,
586                                            )
587                                            .map(|kb| kb.size(rems_from_px(12.))),
588                                        )
589                                        .on_click(cx.listener(|this, _event, window, cx| {
590                                            this.confirm(&menu::Confirm, window, cx)
591                                        })),
592                                ),
593                        ),
594                    ),
595            )
596    }
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    /// Model form values: (name, max_tokens, max_completion_tokens, max_output_tokens).
    type ModelFields<'a> = (&'a str, &'a str, &'a str, &'a str);

    const VALID_MODEL: ModelFields<'static> = ("somemodel", "200000", "200000", "32000");

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // (provider name, api url, api key, models, expected error)
        let cases: Vec<(&str, &str, &str, Vec<ModelFields<'_>>, &str)> = vec![
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![VALID_MODEL, VALID_MODEL],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            let actual =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await;
            assert_eq!(
                actual,
                Some(SharedString::from(expected)),
                "provider={provider_name:?} url={api_url:?} key={api_key:?}",
            );
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let registry = LanguageModelRegistry::global(cx);
            registry.update(cx, |registry, cx| {
                let fake_provider = Arc::new(FakeLanguageModelProvider::new(
                    LanguageModelProviderId::new("someprovider"),
                    LanguageModelProviderName::new("Some Provider"),
                ));
                registry.register_provider(fake_provider, cx);
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![VALID_MODEL],
            cx,
        )
        .await;
        assert_eq!(
            error,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            let toggles = &model_input.capabilities;
            assert_eq!(toggles.supports_tools, ToggleState::Selected);
            assert_eq!(toggles.supports_images, ToggleState::Unselected);
            assert_eq!(toggles.supports_parallel_tool_calls, ToggleState::Unselected);
            assert_eq!(toggles.supports_prompt_cache_key, ToggleState::Unselected);
            assert_eq!(toggles.supports_chat_completions, ToggleState::Selected);

            let caps = model_input.parse(cx).unwrap().capabilities;
            assert!(caps.tools);
            assert!(!caps.images);
            assert!(!caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(caps.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Unselected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Unselected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Unselected,
            };

            let caps = model_input.parse(cx).unwrap().capabilities;
            assert!(!caps.tools);
            assert!(!caps.images);
            assert!(!caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(!caps.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Selected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Selected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Selected,
            };

            let parsed = model_input.parse(cx).unwrap();
            assert_eq!(parsed.name, "somemodel");
            let caps = parsed.capabilities;
            assert!(caps.tools);
            assert!(!caps.images);
            assert!(caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(caps.chat_completions);
        });
    }

    /// Installs the settings/theme/language-model/editor globals and a fake
    /// filesystem, then opens a test workspace window and returns its context.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);

            language_model::init(cx);
            editor::init(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let _workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        cx
    }

    /// Fills a fresh form with the given values, runs [`save_provider_to_settings`],
    /// and returns the error the save task resolved with, if any.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<ModelFields<'_>>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn fill(field: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            field.update(cx, |field, cx| {
                field.set_text(text, window, cx);
            });
        }

        let save = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            fill(&input.provider_name, provider_name, window, cx);
            fill(&input.api_url, api_url, window, cx);
            fill(&input.api_key, api_key, window, cx);

            for (ix, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if ix >= input.models.len() {
                    input.models.push(ModelInput::new(ix, window, cx));
                }
                let model = &input.models[ix];
                fill(&model.name, name, window, cx);
                fill(&model.max_tokens, max_tokens, window, cx);
                fill(&model.max_completion_tokens, max_completion_tokens, window, cx);
                fill(&model.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&input, cx)
        });

        save.await.err()
    }
}