1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::LanguageModelAuthConfig;
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
20 TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Bundles the WASM extension that implements the provider with the metadata
/// passed to [`ExtensionLanguageModelProvider::new`] and a shared state entity
/// that tracks authentication and the available models.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension; every provider call in this file goes through it.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` are read in this file.
    pub provider_info: LlmProviderInfo,
    /// Optional icon path, returned as-is by `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional authentication settings; `env_var` and `credential_label` are read in this file.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state entity (authentication flags and model list).
    state: Entity<ExtensionLlmProviderState>,
}
45
/// Mutable state shared between the provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider currently counts as authenticated. Written by the
    /// provider's `authenticate`/`reset_credentials` and by the configuration
    /// view (keychain lookups, saving/resetting a key, env-var toggling).
    is_authenticated: bool,
    /// Models passed to the provider's constructor.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider's full id is listed in the
    /// `allowed_env_var_providers` extension setting.
    env_var_allowed: bool,
    /// True when env-var use is allowed and the configured environment
    /// variable was found set to a non-empty value.
    api_key_from_env: bool,
}
52
// Lets the state entity be used with `cx.subscribe` for unit (`()`) events.
// NOTE(review): nothing visible in this file calls `cx.emit(())`; state
// changes are announced with `cx.notify()` — confirm whether events are
// emitted elsewhere.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
    /// Returns the provider's full id in the form `<extension id>:<provider id>`.
    fn provider_id_string(&self) -> String {
        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
    }
105
    /// The credential key used for storing the API key in the system keychain.
    ///
    /// Built as `extension-llm-` followed by the provider's full id.
    fn credential_key(&self) -> String {
        format!("extension-llm-{}", self.provider_id_string())
    }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
    /// Returns the provider id, built from the full `<extension id>:<provider id>` string.
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId::from(self.provider_id_string())
    }
116
    /// Returns the provider name taken from `provider_info.name`.
    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName::from(self.provider_info.name.clone())
    }
120
    /// Returns the built-in icon; always `IconName::ZedAssistant` regardless of the extension.
    fn icon(&self) -> ui::IconName {
        ui::IconName::ZedAssistant
    }
124
    /// Returns a copy of the icon path supplied at construction, if any.
    fn icon_path(&self) -> Option<SharedString> {
        self.icon_path.clone()
    }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
    /// Reports the `is_authenticated` flag currently stored in the shared state.
    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated
    }
184
185 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
186 let extension = self.extension.clone();
187 let provider_id = self.provider_info.id.clone();
188 let state = self.state.clone();
189
190 cx.spawn(async move |cx| {
191 let result = extension
192 .call(|extension, store| {
193 async move {
194 extension
195 .call_llm_provider_authenticate(store, &provider_id)
196 .await
197 }
198 .boxed()
199 })
200 .await;
201
202 match result {
203 Ok(Ok(Ok(()))) => {
204 cx.update(|cx| {
205 state.update(cx, |state, _| {
206 state.is_authenticated = true;
207 });
208 })?;
209 Ok(())
210 }
211 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
212 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
213 Err(e) => Err(AuthenticateError::Other(e)),
214 }
215 })
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
291 type ObservableEntity = ExtensionLlmProviderState;
292
293 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
294 Some(self.state.clone())
295 }
296
297 fn subscribe<T: 'static>(
298 &self,
299 cx: &mut Context<T>,
300 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
301 ) -> Option<Subscription> {
302 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
303 }
304}
305
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is read, written and deleted.
    credential_key: String,
    /// Extension handle, used to fetch the provider's settings markdown.
    extension: WasmExtension,
    /// Provider id as passed to the extension's own calls.
    extension_provider_id: String,
    /// Full provider id as stored in the `allowed_env_var_providers` setting.
    full_provider_id: String,
    /// Optional auth settings (`env_var`, `credential_label`).
    auth_config: Option<LanguageModelAuthConfig>,
    /// State entity shared with the provider.
    state: Entity<ExtensionLlmProviderState>,
    /// Markdown built from the extension's settings text; stays `None` when
    /// the extension returned no text or the call failed.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor in which the user types the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request has completed.
    loading_settings: bool,
    /// True until the initial credential lookup has completed.
    loading_credentials: bool,
    /// Holds the state subscription created in `new` for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
320
321impl ExtensionProviderConfigurationView {
322 fn new(
323 credential_key: String,
324 extension: WasmExtension,
325 extension_provider_id: String,
326 full_provider_id: String,
327 auth_config: Option<LanguageModelAuthConfig>,
328 state: Entity<ExtensionLlmProviderState>,
329 window: &mut Window,
330 cx: &mut Context<Self>,
331 ) -> Self {
332 // Subscribe to state changes
333 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
334 cx.notify();
335 });
336
337 // Create API key editor
338 let api_key_editor = cx.new(|cx| {
339 let mut editor = Editor::single_line(window, cx);
340 editor.set_placeholder_text("Enter API key...", window, cx);
341 editor
342 });
343
344 let mut this = Self {
345 credential_key,
346 extension,
347 extension_provider_id,
348 full_provider_id,
349 auth_config,
350 state,
351 settings_markdown: None,
352 api_key_editor,
353 loading_settings: true,
354 loading_credentials: true,
355 _subscriptions: vec![state_subscription],
356 };
357
358 // Load settings text from extension
359 this.load_settings_text(cx);
360
361 // Load existing credentials
362 this.load_credentials(cx);
363
364 this
365 }
366
367 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
368 let extension = self.extension.clone();
369 let provider_id = self.extension_provider_id.clone();
370
371 cx.spawn(async move |this, cx| {
372 let result = extension
373 .call({
374 let provider_id = provider_id.clone();
375 |ext, store| {
376 async move {
377 ext.call_llm_provider_settings_markdown(store, &provider_id)
378 .await
379 }
380 .boxed()
381 }
382 })
383 .await;
384
385 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
386
387 this.update(cx, |this, cx| {
388 this.loading_settings = false;
389 if let Some(text) = settings_text {
390 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
391 this.settings_markdown = Some(markdown);
392 }
393 cx.notify();
394 })
395 .log_err();
396 })
397 .detach();
398 }
399
400 fn load_credentials(&mut self, cx: &mut Context<Self>) {
401 let credential_key = self.credential_key.clone();
402 let credentials_provider = <dyn CredentialsProvider>::global(cx);
403 let state = self.state.clone();
404
405 // Check if we should use env var (already set in state during provider construction)
406 let api_key_from_env = self.state.read(cx).api_key_from_env;
407
408 cx.spawn(async move |this, cx| {
409 // If using env var, we're already authenticated
410 if api_key_from_env {
411 this.update(cx, |this, cx| {
412 this.loading_credentials = false;
413 cx.notify();
414 })
415 .log_err();
416 return;
417 }
418
419 let credentials = credentials_provider
420 .read_credentials(&credential_key, cx)
421 .await
422 .log_err()
423 .flatten();
424
425 let has_credentials = credentials.is_some();
426
427 // Update authentication state based on stored credentials
428 let _ = cx.update(|cx| {
429 state.update(cx, |state, cx| {
430 state.is_authenticated = has_credentials;
431 cx.notify();
432 });
433 });
434
435 this.update(cx, |this, cx| {
436 this.loading_credentials = false;
437 cx.notify();
438 })
439 .log_err();
440 })
441 .detach();
442 }
443
444 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
445 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
446 let env_var_name = match &self.auth_config {
447 Some(config) => config.env_var.clone(),
448 None => return,
449 };
450
451 let state = self.state.clone();
452 let currently_allowed = self.state.read(cx).env_var_allowed;
453
454 // Update settings file
455 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
456 let providers = settings
457 .extension
458 .allowed_env_var_providers
459 .get_or_insert_with(Vec::new);
460
461 if currently_allowed {
462 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
463 } else {
464 if !providers
465 .iter()
466 .any(|id| id.as_ref() == full_provider_id.as_ref())
467 {
468 providers.push(full_provider_id.clone());
469 }
470 }
471 });
472
473 // Update local state
474 let new_allowed = !currently_allowed;
475 let new_from_env = if new_allowed {
476 if let Some(var_name) = &env_var_name {
477 if let Ok(value) = std::env::var(var_name) {
478 !value.is_empty()
479 } else {
480 false
481 }
482 } else {
483 false
484 }
485 } else {
486 false
487 };
488
489 state.update(cx, |state, cx| {
490 state.env_var_allowed = new_allowed;
491 state.api_key_from_env = new_from_env;
492 if new_from_env {
493 state.is_authenticated = true;
494 }
495 cx.notify();
496 });
497
498 // If env var is being enabled, clear any stored keychain credentials
499 // so there's only one source of truth for the API key
500 if new_allowed {
501 let credential_key = self.credential_key.clone();
502 let credentials_provider = <dyn CredentialsProvider>::global(cx);
503 cx.spawn(async move |_this, cx| {
504 credentials_provider
505 .delete_credentials(&credential_key, cx)
506 .await
507 .log_err();
508 })
509 .detach();
510 }
511
512 // If env var is being disabled, reload credentials from keychain
513 if !new_allowed {
514 self.reload_keychain_credentials(cx);
515 }
516
517 cx.notify();
518 }
519
520 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
521 let credential_key = self.credential_key.clone();
522 let credentials_provider = <dyn CredentialsProvider>::global(cx);
523 let state = self.state.clone();
524
525 cx.spawn(async move |_this, cx| {
526 let credentials = credentials_provider
527 .read_credentials(&credential_key, cx)
528 .await
529 .log_err()
530 .flatten();
531
532 let has_credentials = credentials.is_some();
533
534 let _ = cx.update(|cx| {
535 state.update(cx, |state, cx| {
536 state.is_authenticated = has_credentials;
537 cx.notify();
538 });
539 });
540 })
541 .detach();
542 }
543
544 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
545 let api_key = self.api_key_editor.read(cx).text(cx);
546 if api_key.is_empty() {
547 return;
548 }
549
550 // Clear the editor
551 self.api_key_editor
552 .update(cx, |editor, cx| editor.set_text("", window, cx));
553
554 let credential_key = self.credential_key.clone();
555 let credentials_provider = <dyn CredentialsProvider>::global(cx);
556 let state = self.state.clone();
557
558 cx.spawn(async move |_this, cx| {
559 // Store in system keychain
560 credentials_provider
561 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
562 .await
563 .log_err();
564
565 // Update state to authenticated
566 let _ = cx.update(|cx| {
567 state.update(cx, |state, cx| {
568 state.is_authenticated = true;
569 cx.notify();
570 });
571 });
572 })
573 .detach();
574 }
575
576 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
577 // Clear the editor
578 self.api_key_editor
579 .update(cx, |editor, cx| editor.set_text("", window, cx));
580
581 let credential_key = self.credential_key.clone();
582 let credentials_provider = <dyn CredentialsProvider>::global(cx);
583 let state = self.state.clone();
584
585 cx.spawn(async move |_this, cx| {
586 // Delete from system keychain
587 credentials_provider
588 .delete_credentials(&credential_key, cx)
589 .await
590 .log_err();
591
592 // Update state to unauthenticated
593 let _ = cx.update(|cx| {
594 state.update(cx, |state, cx| {
595 state.is_authenticated = false;
596 cx.notify();
597 });
598 });
599 })
600 .detach();
601 }
602
    /// Reads the `is_authenticated` flag from the shared state.
    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
        self.state.read(cx).is_authenticated
    }
606}
607
608impl gpui::Render for ExtensionProviderConfigurationView {
609 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
610 let is_loading = self.loading_settings || self.loading_credentials;
611 let is_authenticated = self.is_authenticated(cx);
612 let env_var_allowed = self.state.read(cx).env_var_allowed;
613 let api_key_from_env = self.state.read(cx).api_key_from_env;
614
615 if is_loading {
616 return v_flex()
617 .gap_2()
618 .child(Label::new("Loading...").color(Color::Muted))
619 .into_any_element();
620 }
621
622 let mut content = v_flex().gap_4().size_full();
623
624 // Render settings markdown if available
625 if let Some(markdown) = &self.settings_markdown {
626 let style = settings_markdown_style(_window, cx);
627 content = content.child(
628 div()
629 .p_2()
630 .rounded_md()
631 .bg(cx.theme().colors().surface_background)
632 .child(MarkdownElement::new(markdown.clone(), style)),
633 );
634 }
635
636 // Render env var checkbox if the extension specifies an env var
637 if let Some(auth_config) = &self.auth_config {
638 if let Some(env_var_name) = &auth_config.env_var {
639 let env_var_name = env_var_name.clone();
640 let checkbox_label =
641 format!("Read API key from {} environment variable", env_var_name);
642
643 content = content.child(
644 h_flex()
645 .gap_2()
646 .child(
647 ui::Checkbox::new("env-var-permission", env_var_allowed.into())
648 .on_click(cx.listener(|this, _, _window, cx| {
649 this.toggle_env_var_permission(cx);
650 })),
651 )
652 .child(Label::new(checkbox_label).size(LabelSize::Small)),
653 );
654
655 // Show status if env var is allowed
656 if env_var_allowed {
657 if api_key_from_env {
658 content = content.child(
659 h_flex()
660 .gap_2()
661 .child(
662 ui::Icon::new(ui::IconName::Check)
663 .color(Color::Success)
664 .size(ui::IconSize::Small),
665 )
666 .child(
667 Label::new(format!("API key loaded from {}", env_var_name))
668 .color(Color::Success),
669 ),
670 );
671 return content.into_any_element();
672 } else {
673 content = content.child(
674 h_flex()
675 .gap_2()
676 .child(
677 ui::Icon::new(ui::IconName::Warning)
678 .color(Color::Warning)
679 .size(ui::IconSize::Small),
680 )
681 .child(
682 Label::new(format!(
683 "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
684 env_var_name
685 ))
686 .color(Color::Warning)
687 .size(LabelSize::Small),
688 ),
689 );
690 }
691 }
692 }
693 }
694
695 // Render API key section
696 if is_authenticated && !api_key_from_env {
697 content = content.child(
698 v_flex()
699 .gap_2()
700 .child(
701 h_flex()
702 .gap_2()
703 .child(
704 ui::Icon::new(ui::IconName::Check)
705 .color(Color::Success)
706 .size(ui::IconSize::Small),
707 )
708 .child(Label::new("API key configured").color(Color::Success)),
709 )
710 .child(
711 ui::Button::new("reset-api-key", "Reset API Key")
712 .style(ui::ButtonStyle::Subtle)
713 .on_click(cx.listener(|this, _, window, cx| {
714 this.reset_api_key(window, cx);
715 })),
716 ),
717 );
718 } else if !api_key_from_env {
719 let credential_label = self
720 .auth_config
721 .as_ref()
722 .and_then(|c| c.credential_label.clone())
723 .unwrap_or_else(|| "API Key".to_string());
724
725 content = content.child(
726 v_flex()
727 .gap_2()
728 .on_action(cx.listener(Self::save_api_key))
729 .child(
730 Label::new(credential_label)
731 .size(LabelSize::Small)
732 .color(Color::Muted),
733 )
734 .child(self.api_key_editor.clone())
735 .child(
736 Label::new("Enter your API key and press Enter to save")
737 .size(LabelSize::Small)
738 .color(Color::Muted),
739 ),
740 );
741 }
742
743 content.into_any_element()
744 }
745}
746
impl Focusable for ExtensionProviderConfigurationView {
    /// Delegates focus to the API key editor.
    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
        self.api_key_editor.focus_handle(cx)
    }
}
752
753fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
754 let theme_settings = ThemeSettings::get_global(cx);
755 let colors = cx.theme().colors();
756 let mut text_style = window.text_style();
757 text_style.refine(&TextStyleRefinement {
758 font_family: Some(theme_settings.ui_font.family.clone()),
759 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
760 font_features: Some(theme_settings.ui_font.features.clone()),
761 color: Some(colors.text),
762 ..Default::default()
763 });
764
765 MarkdownStyle {
766 base_text_style: text_style,
767 selection_background_color: colors.element_selection_background,
768 inline_code: TextStyleRefinement {
769 background_color: Some(colors.editor_background),
770 ..Default::default()
771 },
772 link: TextStyleRefinement {
773 color: Some(colors.text_accent),
774 underline: Some(UnderlineStyle {
775 color: Some(colors.text_accent.opacity(0.5)),
776 thickness: px(1.),
777 ..Default::default()
778 }),
779 ..Default::default()
780 },
781 syntax: cx.theme().syntax().clone(),
782 ..Default::default()
783 }
784}
785
/// An extension-based language model.
///
/// One instance is created per model by the provider's model accessors.
pub struct ExtensionLanguageModel {
    /// Extension that serves token counting and completion streaming.
    extension: WasmExtension,
    /// Model metadata: id, name, capabilities and token limits.
    model_info: LlmModelInfo,
    /// Id of the owning provider, as returned by the provider's `id()`.
    provider_id: LanguageModelProviderId,
    /// Name of the owning provider, as returned by the provider's `name()`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is passed to the extension on every call.
    provider_info: LlmProviderInfo,
}
794
795impl LanguageModel for ExtensionLanguageModel {
    /// Returns the model id taken from `model_info.id`.
    fn id(&self) -> LanguageModelId {
        LanguageModelId::from(self.model_info.id.clone())
    }
799
    /// Returns the model name taken from `model_info.name`.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model_info.name.clone())
    }
803
    /// Returns a copy of the owning provider's id.
    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }
807
    /// Returns a copy of the owning provider's name.
    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }
811
    /// Returns the telemetry id: `extension-` followed by the model id.
    fn telemetry_id(&self) -> String {
        format!("extension-{}", self.model_info.id)
    }
815
    /// Reports the `supports_images` capability from the model metadata.
    fn supports_images(&self) -> bool {
        self.model_info.capabilities.supports_images
    }
819
    /// Reports the `supports_tools` capability from the model metadata.
    fn supports_tools(&self) -> bool {
        self.model_info.capabilities.supports_tools
    }
823
824 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
825 match choice {
826 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
827 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
828 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
829 }
830 }
831
832 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
833 match self.model_info.capabilities.tool_input_format {
834 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
835 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
836 }
837 }
838
    /// Returns `max_token_count` from the model metadata.
    fn max_token_count(&self) -> u64 {
        self.model_info.max_token_count
    }
842
    /// Returns `max_output_tokens` from the model metadata, if the extension declared one.
    fn max_output_tokens(&self) -> Option<u64> {
        self.model_info.max_output_tokens
    }
846
847 fn count_tokens(
848 &self,
849 request: LanguageModelRequest,
850 cx: &App,
851 ) -> BoxFuture<'static, Result<u64>> {
852 let extension = self.extension.clone();
853 let provider_id = self.provider_info.id.clone();
854 let model_id = self.model_info.id.clone();
855
856 let wit_request = convert_request_to_wit(request);
857
858 cx.background_spawn(async move {
859 extension
860 .call({
861 let provider_id = provider_id.clone();
862 let model_id = model_id.clone();
863 let wit_request = wit_request.clone();
864 |ext, store| {
865 async move {
866 let count = ext
867 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
868 .await?
869 .map_err(|e| anyhow!("{}", e))?;
870 Ok(count)
871 }
872 .boxed()
873 }
874 })
875 .await?
876 })
877 .boxed()
878 }
879
880 fn stream_completion(
881 &self,
882 request: LanguageModelRequest,
883 _cx: &AsyncApp,
884 ) -> BoxFuture<
885 'static,
886 Result<
887 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
888 LanguageModelCompletionError,
889 >,
890 > {
891 let extension = self.extension.clone();
892 let provider_id = self.provider_info.id.clone();
893 let model_id = self.model_info.id.clone();
894
895 let wit_request = convert_request_to_wit(request);
896
897 async move {
898 // Start the stream
899 let stream_id_result = extension
900 .call({
901 let provider_id = provider_id.clone();
902 let model_id = model_id.clone();
903 let wit_request = wit_request.clone();
904 |ext, store| {
905 async move {
906 let id = ext
907 .call_llm_stream_completion_start(
908 store,
909 &provider_id,
910 &model_id,
911 &wit_request,
912 )
913 .await?
914 .map_err(|e| anyhow!("{}", e))?;
915 Ok(id)
916 }
917 .boxed()
918 }
919 })
920 .await;
921
922 let stream_id = stream_id_result
923 .map_err(LanguageModelCompletionError::Other)?
924 .map_err(LanguageModelCompletionError::Other)?;
925
926 // Create a stream that polls for events
927 let stream = futures::stream::unfold(
928 (extension.clone(), stream_id, false),
929 move |(extension, stream_id, done)| async move {
930 if done {
931 return None;
932 }
933
934 let result = extension
935 .call({
936 let stream_id = stream_id.clone();
937 |ext, store| {
938 async move {
939 let event = ext
940 .call_llm_stream_completion_next(store, &stream_id)
941 .await?
942 .map_err(|e| anyhow!("{}", e))?;
943 Ok(event)
944 }
945 .boxed()
946 }
947 })
948 .await
949 .and_then(|inner| inner);
950
951 match result {
952 Ok(Some(event)) => {
953 let converted = convert_completion_event(event);
954 let is_done =
955 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
956 Some((converted, (extension, stream_id, is_done)))
957 }
958 Ok(None) => {
959 // Stream complete, close it
960 let _ = extension
961 .call({
962 let stream_id = stream_id.clone();
963 |ext, store| {
964 async move {
965 ext.call_llm_stream_completion_close(store, &stream_id)
966 .await?;
967 Ok::<(), anyhow::Error>(())
968 }
969 .boxed()
970 }
971 })
972 .await;
973 None
974 }
975 Err(e) => Some((
976 Err(LanguageModelCompletionError::Other(e)),
977 (extension, stream_id, true),
978 )),
979 }
980 },
981 );
982
983 Ok(stream.boxed())
984 }
985 .boxed()
986 }
987
    /// Returns the cache configuration for this model; currently always `None`.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        // Extensions can implement this via llm_cache_configuration
        None
    }
992}
993
994fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
995 use language_model::{MessageContent, Role};
996
997 let messages: Vec<LlmRequestMessage> = request
998 .messages
999 .into_iter()
1000 .map(|msg| {
1001 let role = match msg.role {
1002 Role::User => LlmMessageRole::User,
1003 Role::Assistant => LlmMessageRole::Assistant,
1004 Role::System => LlmMessageRole::System,
1005 };
1006
1007 let content: Vec<LlmMessageContent> = msg
1008 .content
1009 .into_iter()
1010 .map(|c| match c {
1011 MessageContent::Text(text) => LlmMessageContent::Text(text),
1012 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1013 source: image.source.to_string(),
1014 width: Some(image.size.width.0 as u32),
1015 height: Some(image.size.height.0 as u32),
1016 }),
1017 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1018 id: tool_use.id.to_string(),
1019 name: tool_use.name.to_string(),
1020 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1021 thought_signature: tool_use.thought_signature,
1022 }),
1023 MessageContent::ToolResult(tool_result) => {
1024 let content = match tool_result.content {
1025 language_model::LanguageModelToolResultContent::Text(text) => {
1026 LlmToolResultContent::Text(text.to_string())
1027 }
1028 language_model::LanguageModelToolResultContent::Image(image) => {
1029 LlmToolResultContent::Image(LlmImageData {
1030 source: image.source.to_string(),
1031 width: Some(image.size.width.0 as u32),
1032 height: Some(image.size.height.0 as u32),
1033 })
1034 }
1035 };
1036 LlmMessageContent::ToolResult(LlmToolResult {
1037 tool_use_id: tool_result.tool_use_id.to_string(),
1038 tool_name: tool_result.tool_name.to_string(),
1039 is_error: tool_result.is_error,
1040 content,
1041 })
1042 }
1043 MessageContent::Thinking { text, signature } => {
1044 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1045 }
1046 MessageContent::RedactedThinking(data) => {
1047 LlmMessageContent::RedactedThinking(data)
1048 }
1049 })
1050 .collect();
1051
1052 LlmRequestMessage {
1053 role,
1054 content,
1055 cache: msg.cache,
1056 }
1057 })
1058 .collect();
1059
1060 let tools: Vec<LlmToolDefinition> = request
1061 .tools
1062 .into_iter()
1063 .map(|tool| LlmToolDefinition {
1064 name: tool.name,
1065 description: tool.description,
1066 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1067 })
1068 .collect();
1069
1070 let tool_choice = request.tool_choice.map(|tc| match tc {
1071 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1072 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1073 LanguageModelToolChoice::None => LlmToolChoice::None,
1074 });
1075
1076 LlmCompletionRequest {
1077 messages,
1078 tools,
1079 tool_choice,
1080 stop_sequences: request.stop,
1081 temperature: request.temperature,
1082 thinking_allowed: false,
1083 max_tokens: None,
1084 }
1085}
1086
1087fn convert_completion_event(
1088 event: LlmCompletionEvent,
1089) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1090 match event {
1091 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1092 message_id: String::new(),
1093 }),
1094 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1095 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1096 text: thinking.text,
1097 signature: thinking.signature,
1098 }),
1099 LlmCompletionEvent::RedactedThinking(data) => {
1100 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1101 }
1102 LlmCompletionEvent::ToolUse(tool_use) => {
1103 let raw_input = tool_use.input.clone();
1104 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1105 Ok(LanguageModelCompletionEvent::ToolUse(
1106 LanguageModelToolUse {
1107 id: LanguageModelToolUseId::from(tool_use.id),
1108 name: tool_use.name.into(),
1109 raw_input,
1110 input,
1111 is_input_complete: true,
1112 thought_signature: tool_use.thought_signature,
1113 },
1114 ))
1115 }
1116 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1117 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1118 id: LanguageModelToolUseId::from(error.id),
1119 tool_name: error.tool_name.into(),
1120 raw_input: error.raw_input.into(),
1121 json_parse_error: error.error,
1122 })
1123 }
1124 LlmCompletionEvent::Stop(reason) => {
1125 let stop_reason = match reason {
1126 LlmStopReason::EndTurn => StopReason::EndTurn,
1127 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1128 LlmStopReason::ToolUse => StopReason::ToolUse,
1129 LlmStopReason::Refusal => StopReason::Refusal,
1130 };
1131 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1132 }
1133 LlmCompletionEvent::Usage(usage) => {
1134 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1135 input_tokens: usage.input_tokens,
1136 output_tokens: usage.output_tokens,
1137 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1138 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1139 }))
1140 }
1141 LlmCompletionEvent::ReasoningDetails(json) => {
1142 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1143 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1144 ))
1145 }
1146 }
1147}