1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashSet;
5use fs::Fs;
6use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
7use language_model::LanguageModelRegistry;
8use language_models::{
9 AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
10 provider::open_ai_compatible::AvailableModel,
11};
12use settings::update_settings_file;
13use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
14use ui_input::SingleLineInput;
15use workspace::{ModalView, Workspace};
16
/// Flavors of provider API that the "Add LLM Provider" modal can configure.
///
/// Each variant supplies the defaults used to pre-fill the form (see
/// `LlmCompatibleProvider::name` and `LlmCompatibleProvider::api_url`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmCompatibleProvider {
    /// A provider speaking the OpenAI-compatible API.
    OpenAi,
}
21
22impl LlmCompatibleProvider {
23 fn name(&self) -> &'static str {
24 match self {
25 LlmCompatibleProvider::OpenAi => "OpenAI",
26 }
27 }
28
29 fn api_url(&self) -> &'static str {
30 match self {
31 LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
32 }
33 }
34}
35
/// Input fields backing the "Add LLM Provider" form.
struct AddLlmProviderInput {
    /// Name the provider is saved under in settings; it must not match the id or
    /// name of an already-registered provider.
    provider_name: Entity<SingleLineInput>,
    /// Base URL of the API; also used as the key when the API key is written
    /// with `write_credentials`.
    api_url: Entity<SingleLineInput>,
    /// API key, stored via `write_credentials` rather than in the settings file.
    api_key: Entity<SingleLineInput>,
    /// One entry per model to register. The form starts with one entry and the
    /// UI only offers removal while more than one exists.
    models: Vec<ModelInput>,
}
42
43impl AddLlmProviderInput {
44 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
45 let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
46 let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
47 let api_key = single_line_input(
48 "API Key",
49 "000000000000000000000000000000000000000000000000",
50 None,
51 window,
52 cx,
53 );
54
55 Self {
56 provider_name,
57 api_url,
58 api_key,
59 models: vec![ModelInput::new(window, cx)],
60 }
61 }
62
63 fn add_model(&mut self, window: &mut Window, cx: &mut App) {
64 self.models.push(ModelInput::new(window, cx));
65 }
66
67 fn remove_model(&mut self, index: usize) {
68 self.models.remove(index);
69 }
70}
71
/// Input fields describing a single model offered by the provider.
struct ModelInput {
    /// Model name; required and must be unique within the form.
    name: Entity<SingleLineInput>,
    /// Parsed as `u64`; pre-filled with "200000".
    max_completion_tokens: Entity<SingleLineInput>,
    /// Parsed as `u64`; pre-filled with "32000".
    max_output_tokens: Entity<SingleLineInput>,
    /// Parsed as `u64`; pre-filled with "200000".
    max_tokens: Entity<SingleLineInput>,
}
78
79impl ModelInput {
80 fn new(window: &mut Window, cx: &mut App) -> Self {
81 let model_name = single_line_input(
82 "Model Name",
83 "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
84 None,
85 window,
86 cx,
87 );
88 let max_completion_tokens = single_line_input(
89 "Max Completion Tokens",
90 "200000",
91 Some("200000"),
92 window,
93 cx,
94 );
95 let max_output_tokens = single_line_input(
96 "Max Output Tokens",
97 "Max Output Tokens",
98 Some("32000"),
99 window,
100 cx,
101 );
102 let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
103 Self {
104 name: model_name,
105 max_completion_tokens,
106 max_output_tokens,
107 max_tokens,
108 }
109 }
110
111 fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
112 let name = self.name.read(cx).text(cx);
113 if name.is_empty() {
114 return Err(SharedString::from("Model Name cannot be empty"));
115 }
116 Ok(AvailableModel {
117 name,
118 display_name: None,
119 max_completion_tokens: Some(
120 self.max_completion_tokens
121 .read(cx)
122 .text(cx)
123 .parse::<u64>()
124 .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
125 ),
126 max_output_tokens: Some(
127 self.max_output_tokens
128 .read(cx)
129 .text(cx)
130 .parse::<u64>()
131 .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
132 ),
133 max_tokens: self
134 .max_tokens
135 .read(cx)
136 .text(cx)
137 .parse::<u64>()
138 .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
139 })
140 }
141}
142
143fn single_line_input(
144 label: impl Into<SharedString>,
145 placeholder: impl Into<SharedString>,
146 text: Option<&str>,
147 window: &mut Window,
148 cx: &mut App,
149) -> Entity<SingleLineInput> {
150 cx.new(|cx| {
151 let input = SingleLineInput::new(window, cx, placeholder).label(label);
152 if let Some(text) = text {
153 input
154 .editor()
155 .update(cx, |editor, cx| editor.set_text(text, window, cx));
156 }
157 input
158 })
159}
160
161fn save_provider_to_settings(
162 input: &AddLlmProviderInput,
163 cx: &mut App,
164) -> Task<Result<(), SharedString>> {
165 let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
166 if provider_name.is_empty() {
167 return Task::ready(Err("Provider Name cannot be empty".into()));
168 }
169
170 if LanguageModelRegistry::read_global(cx)
171 .providers()
172 .iter()
173 .any(|provider| {
174 provider.id().0.as_ref() == provider_name.as_ref()
175 || provider.name().0.as_ref() == provider_name.as_ref()
176 })
177 {
178 return Task::ready(Err(
179 "Provider Name is already taken by another provider".into()
180 ));
181 }
182
183 let api_url = input.api_url.read(cx).text(cx);
184 if api_url.is_empty() {
185 return Task::ready(Err("API URL cannot be empty".into()));
186 }
187
188 let api_key = input.api_key.read(cx).text(cx);
189 if api_key.is_empty() {
190 return Task::ready(Err("API Key cannot be empty".into()));
191 }
192
193 let mut models = Vec::new();
194 let mut model_names: HashSet<String> = HashSet::default();
195 for model in &input.models {
196 match model.parse(cx) {
197 Ok(model) => {
198 if !model_names.insert(model.name.clone()) {
199 return Task::ready(Err("Model Names must be unique".into()));
200 }
201 models.push(model)
202 }
203 Err(err) => return Task::ready(Err(err)),
204 }
205 }
206
207 let fs = <dyn Fs>::global(cx);
208 let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
209 cx.spawn(async move |cx| {
210 task.await
211 .map_err(|_| "Failed to write API key to keychain")?;
212 cx.update(|cx| {
213 update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
214 settings.openai_compatible.get_or_insert_default().insert(
215 provider_name,
216 OpenAiCompatibleSettingsContent {
217 api_url,
218 available_models: models,
219 },
220 );
221 });
222 })
223 .ok();
224 Ok(())
225 })
226}
227
/// Workspace modal that collects and saves a new OpenAI-compatible LLM provider.
pub struct AddLlmProviderModal {
    /// Provider flavor; selects the description shown in the modal header.
    provider: LlmCompatibleProvider,
    /// Form state (provider fields plus one entry per model).
    input: AddLlmProviderInput,
    /// Focus handle of the modal itself; keybindings shown on the footer
    /// buttons are looked up against it.
    focus_handle: FocusHandle,
    /// Most recent validation or save error, rendered in a warning banner.
    last_error: Option<SharedString>,
}
234
235impl AddLlmProviderModal {
236 pub fn toggle(
237 provider: LlmCompatibleProvider,
238 workspace: &mut Workspace,
239 window: &mut Window,
240 cx: &mut Context<Workspace>,
241 ) {
242 workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
243 }
244
245 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
246 Self {
247 input: AddLlmProviderInput::new(provider, window, cx),
248 provider,
249 last_error: None,
250 focus_handle: cx.focus_handle(),
251 }
252 }
253
254 fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
255 let task = save_provider_to_settings(&self.input, cx);
256 cx.spawn(async move |this, cx| {
257 let result = task.await;
258 this.update(cx, |this, cx| match result {
259 Ok(_) => {
260 cx.emit(DismissEvent);
261 }
262 Err(error) => {
263 this.last_error = Some(error);
264 cx.notify();
265 }
266 })
267 })
268 .detach_and_log_err(cx);
269 }
270
271 fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
272 cx.emit(DismissEvent);
273 }
274
275 fn render_section(&self) -> Section {
276 Section::new()
277 .child(self.input.provider_name.clone())
278 .child(self.input.api_url.clone())
279 .child(self.input.api_key.clone())
280 }
281
282 fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
283 Section::new().child(
284 v_flex()
285 .gap_2()
286 .child(
287 h_flex()
288 .justify_between()
289 .child(Label::new("Models").size(LabelSize::Small))
290 .child(
291 Button::new("add-model", "Add Model")
292 .icon(IconName::Plus)
293 .icon_position(IconPosition::Start)
294 .icon_size(IconSize::XSmall)
295 .icon_color(Color::Muted)
296 .label_size(LabelSize::Small)
297 .on_click(cx.listener(|this, _, window, cx| {
298 this.input.add_model(window, cx);
299 cx.notify();
300 })),
301 ),
302 )
303 .children(
304 self.input
305 .models
306 .iter()
307 .enumerate()
308 .map(|(ix, _)| self.render_model(ix, cx)),
309 ),
310 )
311 }
312
313 fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
314 let has_more_than_one_model = self.input.models.len() > 1;
315 let model = &self.input.models[ix];
316
317 v_flex()
318 .p_2()
319 .gap_2()
320 .rounded_sm()
321 .border_1()
322 .border_dashed()
323 .border_color(cx.theme().colors().border.opacity(0.6))
324 .bg(cx.theme().colors().element_active.opacity(0.15))
325 .child(model.name.clone())
326 .child(
327 h_flex()
328 .gap_2()
329 .child(model.max_completion_tokens.clone())
330 .child(model.max_output_tokens.clone()),
331 )
332 .child(model.max_tokens.clone())
333 .when(has_more_than_one_model, |this| {
334 this.child(
335 Button::new(("remove-model", ix), "Remove Model")
336 .icon(IconName::Trash)
337 .icon_position(IconPosition::Start)
338 .icon_size(IconSize::XSmall)
339 .icon_color(Color::Muted)
340 .label_size(LabelSize::Small)
341 .style(ButtonStyle::Outlined)
342 .full_width()
343 .on_click(cx.listener(move |this, _, _window, cx| {
344 this.input.remove_model(ix);
345 cx.notify();
346 })),
347 )
348 })
349 }
350}
351
352impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
353
354impl Focusable for AddLlmProviderModal {
355 fn focus_handle(&self, _cx: &App) -> FocusHandle {
356 self.focus_handle.clone()
357 }
358}
359
360impl ModalView for AddLlmProviderModal {}
361
362impl Render for AddLlmProviderModal {
363 fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
364 let focus_handle = self.focus_handle(cx);
365
366 div()
367 .id("add-llm-provider-modal")
368 .key_context("AddLlmProviderModal")
369 .w(rems(34.))
370 .elevation_3(cx)
371 .on_action(cx.listener(Self::cancel))
372 .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
373 this.focus_handle(cx).focus(window);
374 }))
375 .child(
376 Modal::new("configure-context-server", None)
377 .header(ModalHeader::new().headline("Add LLM Provider").description(
378 match self.provider {
379 LlmCompatibleProvider::OpenAi => {
380 "This provider will use an OpenAI compatible API."
381 }
382 },
383 ))
384 .when_some(self.last_error.clone(), |this, error| {
385 this.section(
386 Section::new().child(
387 Banner::new()
388 .severity(ui::Severity::Warning)
389 .child(div().text_xs().child(error)),
390 ),
391 )
392 })
393 .child(
394 v_flex()
395 .id("modal_content")
396 .max_h_128()
397 .overflow_y_scroll()
398 .gap_2()
399 .child(self.render_section())
400 .child(self.render_model_section(cx)),
401 )
402 .footer(
403 ModalFooter::new().end_slot(
404 h_flex()
405 .gap_1()
406 .child(
407 Button::new("cancel", "Cancel")
408 .key_binding(
409 KeyBinding::for_action_in(
410 &menu::Cancel,
411 &focus_handle,
412 window,
413 cx,
414 )
415 .map(|kb| kb.size(rems_from_px(12.))),
416 )
417 .on_click(cx.listener(|this, _event, window, cx| {
418 this.cancel(&menu::Cancel, window, cx)
419 })),
420 )
421 .child(
422 Button::new("save-server", "Save Provider")
423 .key_binding(
424 KeyBinding::for_action_in(
425 &menu::Confirm,
426 &focus_handle,
427 window,
428 cx,
429 )
430 .map(|kb| kb.size(rems_from_px(12.))),
431 )
432 .on_click(cx.listener(|this, _event, window, cx| {
433 this.confirm(&menu::Confirm, window, cx)
434 })),
435 ),
436 ),
437 ),
438 )
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use editor::EditorSettings;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language::language_settings;
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::{Settings as _, SettingsStore};
    use util::path;

    /// Each validation rule checked by `save_provider_to_settings` (including the
    /// per-model checks in `ModelInput::parse`) yields its expected message.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
            Some("API Key cannot be empty".into())
        );

        // Model tuples below are (name, max_tokens, max_completion_tokens,
        // max_output_tokens), matching the destructuring in the helper.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    /// A provider name equal to the id of a provider already registered in the
    /// `LanguageModelRegistry` is rejected.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    ),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    /// Sets up the test environment and returns the window's visual test context:
    /// - registers the settings used by the editor, project and language model
    ///   crates;
    /// - installs a `FakeFs` as the global `Fs`;
    /// - opens a test workspace window.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            workspace::init_settings(cx);
            Project::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            language_settings::init(cx);
            EditorSettings::register(cx);
            language_model::init_settings(cx);
            language_models::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    /// Runs `save_provider_to_settings` and returns its error, if any.
    ///
    /// The form is built for the OpenAI provider and filled with the given
    /// values. Each model tuple is (name, max_tokens, max_completion_tokens,
    /// max_output_tokens). When `models` is empty, the form's single default
    /// model entry (empty name) is left in place.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        /// Replaces the text of `input`'s underlying editor.
        fn set_text(
            input: &Entity<SingleLineInput>,
            text: &str,
            window: &mut Window,
            cx: &mut App,
        ) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one model entry; add more as needed.
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(window, cx));
                }
                let model = &mut input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}