1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::LanguageModelAuthConfig;
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
20 TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Bundles the WASM extension that services provider calls with the provider
/// metadata and the shared state entity that tracks authentication and models.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension that provider calls are dispatched to.
    pub extension: WasmExtension,
    /// Provider metadata; its `id` and `name` feed the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional icon path handed back from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional auth configuration; this file reads its `env_var` and `credential_label`.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared state holding the authentication flags and the available models.
    state: Entity<ExtensionLlmProviderState>,
}
45
/// State shared between an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered authenticated.
    is_authenticated: bool,
    /// Models offered by the provider, as supplied when the provider was constructed.
    available_models: Vec<LlmModelInfo>,
    /// Whether this provider is allowed to read its API key from an environment variable.
    env_var_allowed: bool,
    /// Whether a non-empty API key was found in the configured environment variable.
    api_key_from_env: bool,
}
52
// Declares the state as an emitter of unit (`()`) events so other entities can `subscribe` to it.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(self.provider_info.name.clone())
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 self.state.read(cx).is_authenticated
183 }
184
185 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
186 let extension = self.extension.clone();
187 let provider_id = self.provider_info.id.clone();
188 let state = self.state.clone();
189
190 cx.spawn(async move |cx| {
191 let result = extension
192 .call(|extension, store| {
193 async move {
194 extension
195 .call_llm_provider_authenticate(store, &provider_id)
196 .await
197 }
198 .boxed()
199 })
200 .await;
201
202 match result {
203 Ok(Ok(Ok(()))) => {
204 cx.update(|cx| {
205 state.update(cx, |state, _| {
206 state.is_authenticated = true;
207 });
208 })?;
209 Ok(())
210 }
211 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
212 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
213 Err(e) => Err(AuthenticateError::Other(e)),
214 }
215 })
216 }
217
218 fn configuration_view(
219 &self,
220 _target_agent: ConfigurationViewTargetAgent,
221 window: &mut Window,
222 cx: &mut App,
223 ) -> AnyView {
224 let credential_key = self.credential_key();
225 let extension = self.extension.clone();
226 let extension_provider_id = self.provider_info.id.clone();
227 let full_provider_id = self.provider_id_string();
228 let state = self.state.clone();
229 let auth_config = self.auth_config.clone();
230
231 cx.new(|cx| {
232 ExtensionProviderConfigurationView::new(
233 credential_key,
234 extension,
235 extension_provider_id,
236 full_provider_id,
237 auth_config,
238 state,
239 window,
240 cx,
241 )
242 })
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 let extension = self.extension.clone();
248 let provider_id = self.provider_info.id.clone();
249 let state = self.state.clone();
250 let credential_key = self.credential_key();
251
252 let credentials_provider = <dyn CredentialsProvider>::global(cx);
253
254 cx.spawn(async move |cx| {
255 // Delete from system keychain
256 credentials_provider
257 .delete_credentials(&credential_key, cx)
258 .await
259 .log_err();
260
261 // Call extension's reset_credentials
262 let result = extension
263 .call(|extension, store| {
264 async move {
265 extension
266 .call_llm_provider_reset_credentials(store, &provider_id)
267 .await
268 }
269 .boxed()
270 })
271 .await;
272
273 // Update state
274 cx.update(|cx| {
275 state.update(cx, |state, _| {
276 state.is_authenticated = false;
277 });
278 })?;
279
280 match result {
281 Ok(Ok(Ok(()))) => Ok(()),
282 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
283 Ok(Err(e)) => Err(e),
284 Err(e) => Err(e),
285 }
286 })
287 }
288}
289
290impl LanguageModelProviderState for ExtensionLanguageModelProvider {
291 type ObservableEntity = ExtensionLlmProviderState;
292
293 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
294 Some(self.state.clone())
295 }
296
297 fn subscribe<T: 'static>(
298 &self,
299 cx: &mut Context<T>,
300 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
301 ) -> Option<Subscription> {
302 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
303 }
304}
305
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Extension queried for the settings markdown.
    extension: WasmExtension,
    /// Provider id as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// Provider id as written to the `allowed_env_var_providers` setting.
    full_provider_id: String,
    /// Optional auth configuration; read for `env_var` and `credential_label`.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared provider state (authentication and env-var flags).
    state: Entity<ExtensionLlmProviderState>,
    /// Settings text from the extension, once loaded and if any was returned.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings-text request has completed.
    loading_settings: bool,
    /// True until the initial credential lookup has completed.
    loading_credentials: bool,
    /// Keeps subscriptions alive for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
320
321impl ExtensionProviderConfigurationView {
322 fn new(
323 credential_key: String,
324 extension: WasmExtension,
325 extension_provider_id: String,
326 full_provider_id: String,
327 auth_config: Option<LanguageModelAuthConfig>,
328 state: Entity<ExtensionLlmProviderState>,
329 window: &mut Window,
330 cx: &mut Context<Self>,
331 ) -> Self {
332 // Subscribe to state changes
333 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
334 cx.notify();
335 });
336
337 // Create API key editor
338 let api_key_editor = cx.new(|cx| {
339 let mut editor = Editor::single_line(window, cx);
340 editor.set_placeholder_text("Enter API key...", window, cx);
341 editor
342 });
343
344 let mut this = Self {
345 credential_key,
346 extension,
347 extension_provider_id,
348 full_provider_id,
349 auth_config,
350 state,
351 settings_markdown: None,
352 api_key_editor,
353 loading_settings: true,
354 loading_credentials: true,
355 _subscriptions: vec![state_subscription],
356 };
357
358 // Load settings text from extension
359 this.load_settings_text(cx);
360
361 // Load existing credentials
362 this.load_credentials(cx);
363
364 this
365 }
366
367 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
368 let extension = self.extension.clone();
369 let provider_id = self.extension_provider_id.clone();
370
371 cx.spawn(async move |this, cx| {
372 let result = extension
373 .call({
374 let provider_id = provider_id.clone();
375 |ext, store| {
376 async move {
377 ext.call_llm_provider_settings_markdown(store, &provider_id)
378 .await
379 }
380 .boxed()
381 }
382 })
383 .await;
384
385 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
386
387 this.update(cx, |this, cx| {
388 this.loading_settings = false;
389 if let Some(text) = settings_text {
390 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
391 this.settings_markdown = Some(markdown);
392 }
393 cx.notify();
394 })
395 .log_err();
396 })
397 .detach();
398 }
399
400 fn load_credentials(&mut self, cx: &mut Context<Self>) {
401 let credential_key = self.credential_key.clone();
402 let credentials_provider = <dyn CredentialsProvider>::global(cx);
403 let state = self.state.clone();
404
405 // Check if we should use env var (already set in state during provider construction)
406 let api_key_from_env = self.state.read(cx).api_key_from_env;
407
408 cx.spawn(async move |this, cx| {
409 // If using env var, we're already authenticated
410 if api_key_from_env {
411 this.update(cx, |this, cx| {
412 this.loading_credentials = false;
413 cx.notify();
414 })
415 .log_err();
416 return;
417 }
418
419 let credentials = credentials_provider
420 .read_credentials(&credential_key, cx)
421 .await
422 .log_err()
423 .flatten();
424
425 let has_credentials = credentials.is_some();
426
427 // Update authentication state based on stored credentials
428 let _ = cx.update(|cx| {
429 state.update(cx, |state, cx| {
430 state.is_authenticated = has_credentials;
431 cx.notify();
432 });
433 });
434
435 this.update(cx, |this, cx| {
436 this.loading_credentials = false;
437 cx.notify();
438 })
439 .log_err();
440 })
441 .detach();
442 }
443
444 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
445 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
446 let env_var_name = match &self.auth_config {
447 Some(config) => config.env_var.clone(),
448 None => return,
449 };
450
451 let state = self.state.clone();
452 let currently_allowed = self.state.read(cx).env_var_allowed;
453
454 // Update settings file
455 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
456 let providers = settings
457 .extension
458 .allowed_env_var_providers
459 .get_or_insert_with(Vec::new);
460
461 if currently_allowed {
462 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
463 } else {
464 if !providers
465 .iter()
466 .any(|id| id.as_ref() == full_provider_id.as_ref())
467 {
468 providers.push(full_provider_id.clone());
469 }
470 }
471 });
472
473 // Update local state
474 let new_allowed = !currently_allowed;
475 let new_from_env = if new_allowed {
476 if let Some(var_name) = &env_var_name {
477 if let Ok(value) = std::env::var(var_name) {
478 !value.is_empty()
479 } else {
480 false
481 }
482 } else {
483 false
484 }
485 } else {
486 false
487 };
488
489 state.update(cx, |state, cx| {
490 state.env_var_allowed = new_allowed;
491 state.api_key_from_env = new_from_env;
492 if new_from_env {
493 state.is_authenticated = true;
494 }
495 cx.notify();
496 });
497
498 // If env var is being disabled, reload credentials from keychain
499 if !new_allowed {
500 self.reload_keychain_credentials(cx);
501 }
502
503 cx.notify();
504 }
505
506 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
507 let credential_key = self.credential_key.clone();
508 let credentials_provider = <dyn CredentialsProvider>::global(cx);
509 let state = self.state.clone();
510
511 cx.spawn(async move |_this, cx| {
512 let credentials = credentials_provider
513 .read_credentials(&credential_key, cx)
514 .await
515 .log_err()
516 .flatten();
517
518 let has_credentials = credentials.is_some();
519
520 let _ = cx.update(|cx| {
521 state.update(cx, |state, cx| {
522 state.is_authenticated = has_credentials;
523 cx.notify();
524 });
525 });
526 })
527 .detach();
528 }
529
530 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
531 let api_key = self.api_key_editor.read(cx).text(cx);
532 if api_key.is_empty() {
533 return;
534 }
535
536 // Clear the editor
537 self.api_key_editor
538 .update(cx, |editor, cx| editor.set_text("", window, cx));
539
540 let credential_key = self.credential_key.clone();
541 let credentials_provider = <dyn CredentialsProvider>::global(cx);
542 let state = self.state.clone();
543
544 cx.spawn(async move |_this, cx| {
545 // Store in system keychain
546 credentials_provider
547 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
548 .await
549 .log_err();
550
551 // Update state to authenticated
552 let _ = cx.update(|cx| {
553 state.update(cx, |state, cx| {
554 state.is_authenticated = true;
555 cx.notify();
556 });
557 });
558 })
559 .detach();
560 }
561
562 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
563 // Clear the editor
564 self.api_key_editor
565 .update(cx, |editor, cx| editor.set_text("", window, cx));
566
567 let credential_key = self.credential_key.clone();
568 let credentials_provider = <dyn CredentialsProvider>::global(cx);
569 let state = self.state.clone();
570
571 cx.spawn(async move |_this, cx| {
572 // Delete from system keychain
573 credentials_provider
574 .delete_credentials(&credential_key, cx)
575 .await
576 .log_err();
577
578 // Update state to unauthenticated
579 let _ = cx.update(|cx| {
580 state.update(cx, |state, cx| {
581 state.is_authenticated = false;
582 cx.notify();
583 });
584 });
585 })
586 .detach();
587 }
588
589 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
590 self.state.read(cx).is_authenticated
591 }
592}
593
594impl gpui::Render for ExtensionProviderConfigurationView {
595 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
596 let is_loading = self.loading_settings || self.loading_credentials;
597 let is_authenticated = self.is_authenticated(cx);
598 let env_var_allowed = self.state.read(cx).env_var_allowed;
599 let api_key_from_env = self.state.read(cx).api_key_from_env;
600
601 if is_loading {
602 return v_flex()
603 .gap_2()
604 .child(Label::new("Loading...").color(Color::Muted))
605 .into_any_element();
606 }
607
608 let mut content = v_flex().gap_4().size_full();
609
610 // Render settings markdown if available
611 if let Some(markdown) = &self.settings_markdown {
612 let style = settings_markdown_style(_window, cx);
613 content = content.child(
614 div()
615 .p_2()
616 .rounded_md()
617 .bg(cx.theme().colors().surface_background)
618 .child(MarkdownElement::new(markdown.clone(), style)),
619 );
620 }
621
622 // Render env var checkbox if the extension specifies an env var
623 if let Some(auth_config) = &self.auth_config {
624 if let Some(env_var_name) = &auth_config.env_var {
625 let env_var_name = env_var_name.clone();
626 let checkbox_label =
627 format!("Read API key from {} environment variable", env_var_name);
628
629 content = content.child(
630 h_flex()
631 .gap_2()
632 .child(
633 ui::Checkbox::new("env-var-permission", env_var_allowed.into())
634 .on_click(cx.listener(|this, _, _window, cx| {
635 this.toggle_env_var_permission(cx);
636 })),
637 )
638 .child(Label::new(checkbox_label).size(LabelSize::Small)),
639 );
640
641 // Show status if env var is allowed
642 if env_var_allowed {
643 if api_key_from_env {
644 content = content.child(
645 h_flex()
646 .gap_2()
647 .child(
648 ui::Icon::new(ui::IconName::Check)
649 .color(Color::Success)
650 .size(ui::IconSize::Small),
651 )
652 .child(
653 Label::new(format!("API key loaded from {}", env_var_name))
654 .color(Color::Success),
655 ),
656 );
657 return content.into_any_element();
658 } else {
659 content = content.child(
660 h_flex()
661 .gap_2()
662 .child(
663 ui::Icon::new(ui::IconName::Warning)
664 .color(Color::Warning)
665 .size(ui::IconSize::Small),
666 )
667 .child(
668 Label::new(format!(
669 "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
670 env_var_name
671 ))
672 .color(Color::Warning)
673 .size(LabelSize::Small),
674 ),
675 );
676 }
677 }
678 }
679 }
680
681 // Render API key section
682 if is_authenticated && !api_key_from_env {
683 content = content.child(
684 v_flex()
685 .gap_2()
686 .child(
687 h_flex()
688 .gap_2()
689 .child(
690 ui::Icon::new(ui::IconName::Check)
691 .color(Color::Success)
692 .size(ui::IconSize::Small),
693 )
694 .child(Label::new("API key configured").color(Color::Success)),
695 )
696 .child(
697 ui::Button::new("reset-api-key", "Reset API Key")
698 .style(ui::ButtonStyle::Subtle)
699 .on_click(cx.listener(|this, _, window, cx| {
700 this.reset_api_key(window, cx);
701 })),
702 ),
703 );
704 } else if !api_key_from_env {
705 let credential_label = self
706 .auth_config
707 .as_ref()
708 .and_then(|c| c.credential_label.clone())
709 .unwrap_or_else(|| "API Key".to_string());
710
711 content = content.child(
712 v_flex()
713 .gap_2()
714 .on_action(cx.listener(Self::save_api_key))
715 .child(
716 Label::new(credential_label)
717 .size(LabelSize::Small)
718 .color(Color::Muted),
719 )
720 .child(self.api_key_editor.clone())
721 .child(
722 Label::new("Enter your API key and press Enter to save")
723 .size(LabelSize::Small)
724 .color(Color::Muted),
725 ),
726 );
727 }
728
729 content.into_any_element()
730 }
731}
732
733impl Focusable for ExtensionProviderConfigurationView {
734 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
735 self.api_key_editor.focus_handle(cx)
736 }
737}
738
739fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
740 let theme_settings = ThemeSettings::get_global(cx);
741 let colors = cx.theme().colors();
742 let mut text_style = window.text_style();
743 text_style.refine(&TextStyleRefinement {
744 font_family: Some(theme_settings.ui_font.family.clone()),
745 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
746 font_features: Some(theme_settings.ui_font.features.clone()),
747 color: Some(colors.text),
748 ..Default::default()
749 });
750
751 MarkdownStyle {
752 base_text_style: text_style,
753 selection_background_color: colors.element_selection_background,
754 inline_code: TextStyleRefinement {
755 background_color: Some(colors.editor_background),
756 ..Default::default()
757 },
758 link: TextStyleRefinement {
759 color: Some(colors.text_accent),
760 underline: Some(UnderlineStyle {
761 color: Some(colors.text_accent.opacity(0.5)),
762 thickness: px(1.),
763 ..Default::default()
764 }),
765 ..Default::default()
766 },
767 syntax: cx.theme().syntax().clone(),
768 ..Default::default()
769 }
770}
771
/// An extension-based language model.
///
/// One instance describes a single model of an extension provider; token
/// counting and completions are forwarded to the extension.
pub struct ExtensionLanguageModel {
    /// Extension that services requests for this model.
    extension: WasmExtension,
    /// Model metadata (id, name, capabilities, token limits).
    model_info: LlmModelInfo,
    /// Id of the owning provider, returned from `provider_id()`.
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider, returned from `provider_name()`.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is passed to extension calls.
    provider_info: LlmProviderInfo,
}
780
781impl LanguageModel for ExtensionLanguageModel {
782 fn id(&self) -> LanguageModelId {
783 LanguageModelId::from(self.model_info.id.clone())
784 }
785
786 fn name(&self) -> LanguageModelName {
787 LanguageModelName::from(self.model_info.name.clone())
788 }
789
790 fn provider_id(&self) -> LanguageModelProviderId {
791 self.provider_id.clone()
792 }
793
794 fn provider_name(&self) -> LanguageModelProviderName {
795 self.provider_name.clone()
796 }
797
798 fn telemetry_id(&self) -> String {
799 format!("extension-{}", self.model_info.id)
800 }
801
802 fn supports_images(&self) -> bool {
803 self.model_info.capabilities.supports_images
804 }
805
806 fn supports_tools(&self) -> bool {
807 self.model_info.capabilities.supports_tools
808 }
809
810 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
811 match choice {
812 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
813 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
814 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
815 }
816 }
817
818 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
819 match self.model_info.capabilities.tool_input_format {
820 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
821 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
822 }
823 }
824
825 fn max_token_count(&self) -> u64 {
826 self.model_info.max_token_count
827 }
828
829 fn max_output_tokens(&self) -> Option<u64> {
830 self.model_info.max_output_tokens
831 }
832
833 fn count_tokens(
834 &self,
835 request: LanguageModelRequest,
836 cx: &App,
837 ) -> BoxFuture<'static, Result<u64>> {
838 let extension = self.extension.clone();
839 let provider_id = self.provider_info.id.clone();
840 let model_id = self.model_info.id.clone();
841
842 let wit_request = convert_request_to_wit(request);
843
844 cx.background_spawn(async move {
845 extension
846 .call({
847 let provider_id = provider_id.clone();
848 let model_id = model_id.clone();
849 let wit_request = wit_request.clone();
850 |ext, store| {
851 async move {
852 let count = ext
853 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
854 .await?
855 .map_err(|e| anyhow!("{}", e))?;
856 Ok(count)
857 }
858 .boxed()
859 }
860 })
861 .await?
862 })
863 .boxed()
864 }
865
866 fn stream_completion(
867 &self,
868 request: LanguageModelRequest,
869 _cx: &AsyncApp,
870 ) -> BoxFuture<
871 'static,
872 Result<
873 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
874 LanguageModelCompletionError,
875 >,
876 > {
877 let extension = self.extension.clone();
878 let provider_id = self.provider_info.id.clone();
879 let model_id = self.model_info.id.clone();
880
881 let wit_request = convert_request_to_wit(request);
882
883 async move {
884 // Start the stream
885 let stream_id_result = extension
886 .call({
887 let provider_id = provider_id.clone();
888 let model_id = model_id.clone();
889 let wit_request = wit_request.clone();
890 |ext, store| {
891 async move {
892 let id = ext
893 .call_llm_stream_completion_start(
894 store,
895 &provider_id,
896 &model_id,
897 &wit_request,
898 )
899 .await?
900 .map_err(|e| anyhow!("{}", e))?;
901 Ok(id)
902 }
903 .boxed()
904 }
905 })
906 .await;
907
908 let stream_id = stream_id_result
909 .map_err(LanguageModelCompletionError::Other)?
910 .map_err(LanguageModelCompletionError::Other)?;
911
912 // Create a stream that polls for events
913 let stream = futures::stream::unfold(
914 (extension.clone(), stream_id, false),
915 move |(extension, stream_id, done)| async move {
916 if done {
917 return None;
918 }
919
920 let result = extension
921 .call({
922 let stream_id = stream_id.clone();
923 |ext, store| {
924 async move {
925 let event = ext
926 .call_llm_stream_completion_next(store, &stream_id)
927 .await?
928 .map_err(|e| anyhow!("{}", e))?;
929 Ok(event)
930 }
931 .boxed()
932 }
933 })
934 .await
935 .and_then(|inner| inner);
936
937 match result {
938 Ok(Some(event)) => {
939 let converted = convert_completion_event(event);
940 let is_done =
941 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
942 Some((converted, (extension, stream_id, is_done)))
943 }
944 Ok(None) => {
945 // Stream complete, close it
946 let _ = extension
947 .call({
948 let stream_id = stream_id.clone();
949 |ext, store| {
950 async move {
951 ext.call_llm_stream_completion_close(store, &stream_id)
952 .await?;
953 Ok::<(), anyhow::Error>(())
954 }
955 .boxed()
956 }
957 })
958 .await;
959 None
960 }
961 Err(e) => Some((
962 Err(LanguageModelCompletionError::Other(e)),
963 (extension, stream_id, true),
964 )),
965 }
966 },
967 );
968
969 Ok(stream.boxed())
970 }
971 .boxed()
972 }
973
974 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
975 // Extensions can implement this via llm_cache_configuration
976 None
977 }
978}
979
980fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
981 use language_model::{MessageContent, Role};
982
983 let messages: Vec<LlmRequestMessage> = request
984 .messages
985 .into_iter()
986 .map(|msg| {
987 let role = match msg.role {
988 Role::User => LlmMessageRole::User,
989 Role::Assistant => LlmMessageRole::Assistant,
990 Role::System => LlmMessageRole::System,
991 };
992
993 let content: Vec<LlmMessageContent> = msg
994 .content
995 .into_iter()
996 .map(|c| match c {
997 MessageContent::Text(text) => LlmMessageContent::Text(text),
998 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
999 source: image.source.to_string(),
1000 width: Some(image.size.width.0 as u32),
1001 height: Some(image.size.height.0 as u32),
1002 }),
1003 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1004 id: tool_use.id.to_string(),
1005 name: tool_use.name.to_string(),
1006 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1007 thought_signature: tool_use.thought_signature,
1008 }),
1009 MessageContent::ToolResult(tool_result) => {
1010 let content = match tool_result.content {
1011 language_model::LanguageModelToolResultContent::Text(text) => {
1012 LlmToolResultContent::Text(text.to_string())
1013 }
1014 language_model::LanguageModelToolResultContent::Image(image) => {
1015 LlmToolResultContent::Image(LlmImageData {
1016 source: image.source.to_string(),
1017 width: Some(image.size.width.0 as u32),
1018 height: Some(image.size.height.0 as u32),
1019 })
1020 }
1021 };
1022 LlmMessageContent::ToolResult(LlmToolResult {
1023 tool_use_id: tool_result.tool_use_id.to_string(),
1024 tool_name: tool_result.tool_name.to_string(),
1025 is_error: tool_result.is_error,
1026 content,
1027 })
1028 }
1029 MessageContent::Thinking { text, signature } => {
1030 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1031 }
1032 MessageContent::RedactedThinking(data) => {
1033 LlmMessageContent::RedactedThinking(data)
1034 }
1035 })
1036 .collect();
1037
1038 LlmRequestMessage {
1039 role,
1040 content,
1041 cache: msg.cache,
1042 }
1043 })
1044 .collect();
1045
1046 let tools: Vec<LlmToolDefinition> = request
1047 .tools
1048 .into_iter()
1049 .map(|tool| LlmToolDefinition {
1050 name: tool.name,
1051 description: tool.description,
1052 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1053 })
1054 .collect();
1055
1056 let tool_choice = request.tool_choice.map(|tc| match tc {
1057 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1058 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1059 LanguageModelToolChoice::None => LlmToolChoice::None,
1060 });
1061
1062 LlmCompletionRequest {
1063 messages,
1064 tools,
1065 tool_choice,
1066 stop_sequences: request.stop,
1067 temperature: request.temperature,
1068 thinking_allowed: false,
1069 max_tokens: None,
1070 }
1071}
1072
1073fn convert_completion_event(
1074 event: LlmCompletionEvent,
1075) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1076 match event {
1077 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1078 message_id: String::new(),
1079 }),
1080 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1081 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1082 text: thinking.text,
1083 signature: thinking.signature,
1084 }),
1085 LlmCompletionEvent::RedactedThinking(data) => {
1086 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1087 }
1088 LlmCompletionEvent::ToolUse(tool_use) => {
1089 let raw_input = tool_use.input.clone();
1090 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1091 Ok(LanguageModelCompletionEvent::ToolUse(
1092 LanguageModelToolUse {
1093 id: LanguageModelToolUseId::from(tool_use.id),
1094 name: tool_use.name.into(),
1095 raw_input,
1096 input,
1097 is_input_complete: true,
1098 thought_signature: tool_use.thought_signature,
1099 },
1100 ))
1101 }
1102 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1103 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1104 id: LanguageModelToolUseId::from(error.id),
1105 tool_name: error.tool_name.into(),
1106 raw_input: error.raw_input.into(),
1107 json_parse_error: error.error,
1108 })
1109 }
1110 LlmCompletionEvent::Stop(reason) => {
1111 let stop_reason = match reason {
1112 LlmStopReason::EndTurn => StopReason::EndTurn,
1113 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1114 LlmStopReason::ToolUse => StopReason::ToolUse,
1115 LlmStopReason::Refusal => StopReason::Refusal,
1116 };
1117 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1118 }
1119 LlmCompletionEvent::Usage(usage) => {
1120 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1121 input_tokens: usage.input_tokens,
1122 output_tokens: usage.output_tokens,
1123 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1124 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1125 }))
1126 }
1127 LlmCompletionEvent::ReasoningDetails(json) => {
1128 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1129 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1130 ))
1131 }
1132 }
1133}