1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashSet;
5use fs::Fs;
6use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
7use language_model::LanguageModelRegistry;
8use language_models::{
9 AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
10 provider::open_ai_compatible::AvailableModel,
11};
12use settings::update_settings_file;
13use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
14use ui_input::SingleLineInput;
15use workspace::{ModalView, Workspace};
16
/// The kind of API a user-added provider is compatible with.
///
/// The modal uses the selected variant to pick the placeholder text of the
/// "Provider Name" and "API URL" fields and the description in its header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmCompatibleProvider {
    /// A provider that speaks an OpenAI compatible API.
    OpenAi,
}

impl LlmCompatibleProvider {
    /// Display name of the API family ("OpenAI").
    fn name(&self) -> &'static str {
        match self {
            LlmCompatibleProvider::OpenAi => "OpenAI",
        }
    }

    /// Base URL of the reference API for this family.
    fn api_url(&self) -> &'static str {
        match self {
            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
        }
    }
}
35
/// The editable fields of the "Add LLM Provider" form.
struct AddLlmProviderInput {
    /// Name the provider is saved under in the settings file.
    provider_name: Entity<SingleLineInput>,
    /// URL of the provider's API; also used as the key for the stored credentials.
    api_url: Entity<SingleLineInput>,
    /// API key; written through `write_credentials` rather than into the settings content.
    api_key: Entity<SingleLineInput>,
    /// One entry per model offered by the provider; `new` starts with a single entry.
    models: Vec<ModelInput>,
}
42
43impl AddLlmProviderInput {
44 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
45 let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
46 let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
47 let api_key = single_line_input(
48 "API Key",
49 "000000000000000000000000000000000000000000000000",
50 None,
51 window,
52 cx,
53 );
54
55 Self {
56 provider_name,
57 api_url,
58 api_key,
59 models: vec![ModelInput::new(window, cx)],
60 }
61 }
62
63 fn add_model(&mut self, window: &mut Window, cx: &mut App) {
64 self.models.push(ModelInput::new(window, cx));
65 }
66
67 fn remove_model(&mut self, index: usize) {
68 self.models.remove(index);
69 }
70}
71
/// The text fields describing a single model of the provider being added.
struct ModelInput {
    /// Model name; `parse` rejects an empty value.
    name: Entity<SingleLineInput>,
    /// Source of `AvailableModel::max_completion_tokens`; must parse as a `u64`.
    max_completion_tokens: Entity<SingleLineInput>,
    /// Source of `AvailableModel::max_output_tokens`; must parse as a `u64`.
    max_output_tokens: Entity<SingleLineInput>,
    /// Source of `AvailableModel::max_tokens`; must parse as a `u64`.
    max_tokens: Entity<SingleLineInput>,
}
78
79impl ModelInput {
80 fn new(window: &mut Window, cx: &mut App) -> Self {
81 let model_name = single_line_input(
82 "Model Name",
83 "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
84 None,
85 window,
86 cx,
87 );
88 let max_completion_tokens = single_line_input(
89 "Max Completion Tokens",
90 "200000",
91 Some("200000"),
92 window,
93 cx,
94 );
95 let max_output_tokens = single_line_input(
96 "Max Output Tokens",
97 "Max Output Tokens",
98 Some("32000"),
99 window,
100 cx,
101 );
102 let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
103 Self {
104 name: model_name,
105 max_completion_tokens,
106 max_output_tokens,
107 max_tokens,
108 }
109 }
110
111 fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
112 let name = self.name.read(cx).text(cx);
113 if name.is_empty() {
114 return Err(SharedString::from("Model Name cannot be empty"));
115 }
116 Ok(AvailableModel {
117 name,
118 display_name: None,
119 max_completion_tokens: Some(
120 self.max_completion_tokens
121 .read(cx)
122 .text(cx)
123 .parse::<u64>()
124 .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
125 ),
126 max_output_tokens: Some(
127 self.max_output_tokens
128 .read(cx)
129 .text(cx)
130 .parse::<u64>()
131 .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
132 ),
133 max_tokens: self
134 .max_tokens
135 .read(cx)
136 .text(cx)
137 .parse::<u64>()
138 .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
139 })
140 }
141}
142
143fn single_line_input(
144 label: impl Into<SharedString>,
145 placeholder: impl Into<SharedString>,
146 text: Option<&str>,
147 window: &mut Window,
148 cx: &mut App,
149) -> Entity<SingleLineInput> {
150 cx.new(|cx| {
151 let input = SingleLineInput::new(window, cx, placeholder).label(label);
152 if let Some(text) = text {
153 input
154 .editor()
155 .update(cx, |editor, cx| editor.set_text(text, window, cx));
156 }
157 input
158 })
159}
160
161fn save_provider_to_settings(
162 input: &AddLlmProviderInput,
163 cx: &mut App,
164) -> Task<Result<(), SharedString>> {
165 let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
166 if provider_name.is_empty() {
167 return Task::ready(Err("Provider Name cannot be empty".into()));
168 }
169
170 if LanguageModelRegistry::read_global(cx)
171 .providers()
172 .iter()
173 .any(|provider| {
174 provider.id().0.as_ref() == provider_name.as_ref()
175 || provider.name().0.as_ref() == provider_name.as_ref()
176 })
177 {
178 return Task::ready(Err(
179 "Provider Name is already taken by another provider".into()
180 ));
181 }
182
183 let api_url = input.api_url.read(cx).text(cx);
184 if api_url.is_empty() {
185 return Task::ready(Err("API URL cannot be empty".into()));
186 }
187
188 let api_key = input.api_key.read(cx).text(cx);
189 if api_key.is_empty() {
190 return Task::ready(Err("API Key cannot be empty".into()));
191 }
192
193 let mut models = Vec::new();
194 let mut model_names: HashSet<String> = HashSet::default();
195 for model in &input.models {
196 match model.parse(cx) {
197 Ok(model) => {
198 if !model_names.insert(model.name.clone()) {
199 return Task::ready(Err("Model Names must be unique".into()));
200 }
201 models.push(model)
202 }
203 Err(err) => return Task::ready(Err(err)),
204 }
205 }
206
207 let fs = <dyn Fs>::global(cx);
208 let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
209 cx.spawn(async move |cx| {
210 task.await
211 .map_err(|_| "Failed to write API key to keychain")?;
212 cx.update(|cx| {
213 update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
214 settings.openai_compatible.get_or_insert_default().insert(
215 provider_name,
216 OpenAiCompatibleSettingsContent {
217 api_url,
218 available_models: models,
219 },
220 );
221 });
222 })
223 .ok();
224 Ok(())
225 })
226}
227
/// Modal that collects the details of a new LLM provider and saves them.
pub struct AddLlmProviderModal {
    /// Kind of provider being added; selects the header description.
    provider: LlmCompatibleProvider,
    /// Current state of the form fields.
    input: AddLlmProviderInput,
    /// Focus handle of the modal itself, returned from `Focusable::focus_handle`.
    focus_handle: FocusHandle,
    /// Error from the most recent failed save attempt; rendered as a warning banner.
    last_error: Option<SharedString>,
}
234
235impl AddLlmProviderModal {
236 pub fn toggle(
237 provider: LlmCompatibleProvider,
238 workspace: &mut Workspace,
239 window: &mut Window,
240 cx: &mut Context<Workspace>,
241 ) {
242 workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
243 }
244
245 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
246 Self {
247 input: AddLlmProviderInput::new(provider, window, cx),
248 provider,
249 last_error: None,
250 focus_handle: cx.focus_handle(),
251 }
252 }
253
254 fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
255 let task = save_provider_to_settings(&self.input, cx);
256 cx.spawn(async move |this, cx| {
257 let result = task.await;
258 this.update(cx, |this, cx| match result {
259 Ok(_) => {
260 cx.emit(DismissEvent);
261 }
262 Err(error) => {
263 this.last_error = Some(error);
264 cx.notify();
265 }
266 })
267 })
268 .detach_and_log_err(cx);
269 }
270
271 fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
272 cx.emit(DismissEvent);
273 }
274
275 fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
276 v_flex()
277 .mt_1()
278 .gap_2()
279 .child(
280 h_flex()
281 .justify_between()
282 .child(Label::new("Models").size(LabelSize::Small))
283 .child(
284 Button::new("add-model", "Add Model")
285 .icon(IconName::Plus)
286 .icon_position(IconPosition::Start)
287 .icon_size(IconSize::XSmall)
288 .icon_color(Color::Muted)
289 .label_size(LabelSize::Small)
290 .on_click(cx.listener(|this, _, window, cx| {
291 this.input.add_model(window, cx);
292 cx.notify();
293 })),
294 ),
295 )
296 .children(
297 self.input
298 .models
299 .iter()
300 .enumerate()
301 .map(|(ix, _)| self.render_model(ix, cx)),
302 )
303 }
304
305 fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
306 let has_more_than_one_model = self.input.models.len() > 1;
307 let model = &self.input.models[ix];
308
309 v_flex()
310 .p_2()
311 .gap_2()
312 .rounded_sm()
313 .border_1()
314 .border_dashed()
315 .border_color(cx.theme().colors().border.opacity(0.6))
316 .bg(cx.theme().colors().element_active.opacity(0.15))
317 .child(model.name.clone())
318 .child(
319 h_flex()
320 .gap_2()
321 .child(model.max_completion_tokens.clone())
322 .child(model.max_output_tokens.clone()),
323 )
324 .child(model.max_tokens.clone())
325 .when(has_more_than_one_model, |this| {
326 this.child(
327 Button::new(("remove-model", ix), "Remove Model")
328 .icon(IconName::Trash)
329 .icon_position(IconPosition::Start)
330 .icon_size(IconSize::XSmall)
331 .icon_color(Color::Muted)
332 .label_size(LabelSize::Small)
333 .style(ButtonStyle::Outlined)
334 .full_width()
335 .on_click(cx.listener(move |this, _, _window, cx| {
336 this.input.remove_model(ix);
337 cx.notify();
338 })),
339 )
340 })
341 }
342}
343
// Lets the modal emit `DismissEvent`, which `confirm` (on success) and `cancel` both do.
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
345
impl Focusable for AddLlmProviderModal {
    /// Returns a clone of the focus handle created in `new`.
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}
351
// Implements `ModalView` without overriding anything; `toggle` hands this view to
// `Workspace::toggle_modal`.
impl ModalView for AddLlmProviderModal {}
353
354impl Render for AddLlmProviderModal {
355 fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
356 let focus_handle = self.focus_handle(cx);
357
358 div()
359 .id("add-llm-provider-modal")
360 .key_context("AddLlmProviderModal")
361 .w(rems(34.))
362 .elevation_3(cx)
363 .on_action(cx.listener(Self::cancel))
364 .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
365 this.focus_handle(cx).focus(window);
366 }))
367 .child(
368 Modal::new("configure-context-server", None)
369 .header(ModalHeader::new().headline("Add LLM Provider").description(
370 match self.provider {
371 LlmCompatibleProvider::OpenAi => {
372 "This provider will use an OpenAI compatible API."
373 }
374 },
375 ))
376 .when_some(self.last_error.clone(), |this, error| {
377 this.section(
378 Section::new().child(
379 Banner::new()
380 .severity(ui::Severity::Warning)
381 .child(div().text_xs().child(error)),
382 ),
383 )
384 })
385 .child(
386 v_flex()
387 .id("modal_content")
388 .size_full()
389 .max_h_128()
390 .overflow_y_scroll()
391 .px(DynamicSpacing::Base12.rems(cx))
392 .gap(DynamicSpacing::Base04.rems(cx))
393 .child(self.input.provider_name.clone())
394 .child(self.input.api_url.clone())
395 .child(self.input.api_key.clone())
396 .child(self.render_model_section(cx)),
397 )
398 .footer(
399 ModalFooter::new().end_slot(
400 h_flex()
401 .gap_1()
402 .child(
403 Button::new("cancel", "Cancel")
404 .key_binding(
405 KeyBinding::for_action_in(
406 &menu::Cancel,
407 &focus_handle,
408 window,
409 cx,
410 )
411 .map(|kb| kb.size(rems_from_px(12.))),
412 )
413 .on_click(cx.listener(|this, _event, window, cx| {
414 this.cancel(&menu::Cancel, window, cx)
415 })),
416 )
417 .child(
418 Button::new("save-server", "Save Provider")
419 .key_binding(
420 KeyBinding::for_action_in(
421 &menu::Confirm,
422 &focus_handle,
423 window,
424 cx,
425 )
426 .map(|kb| kb.size(rems_from_px(12.))),
427 )
428 .on_click(cx.listener(|this, _event, window, cx| {
429 this.confirm(&menu::Confirm, window, cx)
430 })),
431 ),
432 ),
433 ),
434 )
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use editor::EditorSettings;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language::language_settings;
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::{Settings as _, SettingsStore};
    use util::path;

    /// Every validation failure is reported with its own message.
    ///
    /// Each case is `(provider name, API URL, API key, models, expected error)`
    /// where a model is `(name, max tokens, max completion tokens, max output tokens)`.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        let cases: Vec<(
            &str,
            &str,
            &str,
            Vec<(&str, &str, &str, &str)>,
            &'static str,
        )> = vec![
            (
                "",
                "someurl",
                "somekey",
                vec![],
                "Provider Name cannot be empty",
            ),
            (
                "someprovider",
                "",
                "somekey",
                vec![],
                "API URL cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "",
                vec![],
                "API Key cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected_error) in cases {
            let error =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await;
            assert_eq!(error, Some(expected_error.into()));
        }
    }

    /// A provider name equal to the id of an already registered provider is rejected.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let existing_provider = FakeLanguageModelProvider::new(
                LanguageModelProviderId::new("someprovider"),
                LanguageModelProviderName::new("Some Provider"),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(existing_provider, cx);
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![("somemodel", "200000", "200000", "32000")],
            cx,
        )
        .await;

        assert_eq!(
            error,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    /// Registers the settings and globals the form needs and opens a test
    /// workspace window, returning the context for that window.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            workspace::init_settings(cx);
            Project::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            language_settings::init(cx);
            EditorSettings::register(cx);
            language_model::init_settings(cx);
            language_models::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, window_cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        window_cx
    }

    /// Fills a fresh form with the given values, runs `save_provider_to_settings`
    /// and returns the error it produced, if any.
    ///
    /// Each model is `(name, max tokens, max completion tokens, max output tokens)`.
    /// Extra model entries are appended to the form as needed.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(
            input: &Entity<SingleLineInput>,
            text: &str,
            window: &mut Window,
            cx: &mut App,
        ) {
            input.update(cx, |field, cx| {
                field.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let save = cx.update(|window, cx| {
            let mut form = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&form.provider_name, provider_name, window, cx);
            set_text(&form.api_url, api_url, window, cx);
            set_text(&form.api_key, api_key, window, cx);

            for (index, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one model entry; grow it for the rest.
                if index >= form.models.len() {
                    form.models.push(ModelInput::new(window, cx));
                }
                let entry = &form.models[index];
                set_text(&entry.name, name, window, cx);
                set_text(&entry.max_tokens, max_tokens, window, cx);
                set_text(
                    &entry.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&entry.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&form, cx)
        });

        save.await.err()
    }
}