1use crate::ExtensionSettings;
2use crate::wasm_host::WasmExtension;
3
4use crate::wasm_host::wit::{
5 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
6 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
7 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
8 LlmToolUse,
9};
10use anyhow::{Result, anyhow};
11use credentials_provider::CredentialsProvider;
12use editor::Editor;
13use extension::LanguageModelAuthConfig;
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, StreamExt};
17use gpui::Focusable;
18use gpui::{
19 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
20 TextStyleRefinement, UnderlineStyle, Window, px,
21};
22use language_model::tool_schema::LanguageModelToolSchemaFormat;
23use language_model::{
24 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
25 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
27 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
28 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
29};
30use markdown::{Markdown, MarkdownElement, MarkdownStyle};
31use settings::Settings;
32use std::sync::Arc;
33use theme::ThemeSettings;
34use ui::{Label, LabelSize, prelude::*};
35use util::ResultExt as _;
36
/// An extension-based language model provider.
///
/// Adapts an LLM provider implemented by a WASM extension to Zed's
/// [`LanguageModelProvider`] trait. Authentication status and the model list
/// live in the shared `state` entity, which is also handed to the
/// configuration view.
pub struct ExtensionLanguageModelProvider {
    /// The extension instance that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata reported by the extension (id and display name).
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon, returned verbatim from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Optional auth configuration (env var name, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Observable provider state shared with the configuration view.
    state: Entity<ExtensionLlmProviderState>,
}
45
/// Mutable state of an extension LLM provider, shared between the provider
/// and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether an API key is currently considered available (keychain,
    /// environment variable, or the value supplied at construction time).
    is_authenticated: bool,
    /// Models advertised by the extension; used for default/provided models.
    available_models: Vec<LlmModelInfo>,
    /// Whether the user allowed this provider to read its key from the
    /// environment (the `allowed_env_var_providers` extension setting).
    env_var_allowed: bool,
    /// Whether the API key is currently being taken from the environment.
    api_key_from_env: bool,
}

// NOTE(review): nothing in this file calls `cx.emit(())` on this entity.
// State changes are signalled with `cx.notify()`, so only observers (not
// event subscribers) see them.
impl EventEmitter<()> for ExtensionLlmProviderState {}
54
55impl ExtensionLanguageModelProvider {
56 pub fn new(
57 extension: WasmExtension,
58 provider_info: LlmProviderInfo,
59 models: Vec<LlmModelInfo>,
60 is_authenticated: bool,
61 icon_path: Option<SharedString>,
62 auth_config: Option<LanguageModelAuthConfig>,
63 cx: &mut App,
64 ) -> Self {
65 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
66 let env_var_allowed = ExtensionSettings::get_global(cx)
67 .allowed_env_var_providers
68 .contains(provider_id_string.as_str());
69
70 let (is_authenticated, api_key_from_env) =
71 if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
72 let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
73 if let Ok(value) = std::env::var(env_var_name) {
74 if !value.is_empty() {
75 (true, true)
76 } else {
77 (is_authenticated, false)
78 }
79 } else {
80 (is_authenticated, false)
81 }
82 } else {
83 (is_authenticated, false)
84 };
85
86 let state = cx.new(|_| ExtensionLlmProviderState {
87 is_authenticated,
88 available_models: models,
89 env_var_allowed,
90 api_key_from_env,
91 });
92
93 Self {
94 extension,
95 provider_info,
96 icon_path,
97 auth_config,
98 state,
99 }
100 }
101
102 fn provider_id_string(&self) -> String {
103 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
104 }
105
106 /// The credential key used for storing the API key in the system keychain.
107 fn credential_key(&self) -> String {
108 format!("extension-llm-{}", self.provider_id_string())
109 }
110}
111
112impl LanguageModelProvider for ExtensionLanguageModelProvider {
113 fn id(&self) -> LanguageModelProviderId {
114 LanguageModelProviderId::from(self.provider_id_string())
115 }
116
117 fn name(&self) -> LanguageModelProviderName {
118 LanguageModelProviderName::from(self.provider_info.name.clone())
119 }
120
121 fn icon(&self) -> ui::IconName {
122 ui::IconName::ZedAssistant
123 }
124
125 fn icon_path(&self) -> Option<SharedString> {
126 self.icon_path.clone()
127 }
128
129 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
130 let state = self.state.read(cx);
131 state
132 .available_models
133 .iter()
134 .find(|m| m.is_default)
135 .or_else(|| state.available_models.first())
136 .map(|model_info| {
137 Arc::new(ExtensionLanguageModel {
138 extension: self.extension.clone(),
139 model_info: model_info.clone(),
140 provider_id: self.id(),
141 provider_name: self.name(),
142 provider_info: self.provider_info.clone(),
143 }) as Arc<dyn LanguageModel>
144 })
145 }
146
147 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
148 let state = self.state.read(cx);
149 state
150 .available_models
151 .iter()
152 .find(|m| m.is_default_fast)
153 .map(|model_info| {
154 Arc::new(ExtensionLanguageModel {
155 extension: self.extension.clone(),
156 model_info: model_info.clone(),
157 provider_id: self.id(),
158 provider_name: self.name(),
159 provider_info: self.provider_info.clone(),
160 }) as Arc<dyn LanguageModel>
161 })
162 }
163
164 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 let state = self.state.read(cx);
166 state
167 .available_models
168 .iter()
169 .map(|model_info| {
170 Arc::new(ExtensionLanguageModel {
171 extension: self.extension.clone(),
172 model_info: model_info.clone(),
173 provider_id: self.id(),
174 provider_name: self.name(),
175 provider_info: self.provider_info.clone(),
176 }) as Arc<dyn LanguageModel>
177 })
178 .collect()
179 }
180
181 fn is_authenticated(&self, cx: &App) -> bool {
182 self.state.read(cx).is_authenticated
183 }
184
185 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
186 // Authentication is handled via the configuration view UI
187 Task::ready(Ok(()))
188 }
189
190 fn configuration_view(
191 &self,
192 _target_agent: ConfigurationViewTargetAgent,
193 window: &mut Window,
194 cx: &mut App,
195 ) -> AnyView {
196 let credential_key = self.credential_key();
197 let extension = self.extension.clone();
198 let extension_provider_id = self.provider_info.id.clone();
199 let full_provider_id = self.provider_id_string();
200 let state = self.state.clone();
201 let auth_config = self.auth_config.clone();
202
203 cx.new(|cx| {
204 ExtensionProviderConfigurationView::new(
205 credential_key,
206 extension,
207 extension_provider_id,
208 full_provider_id,
209 auth_config,
210 state,
211 window,
212 cx,
213 )
214 })
215 .into()
216 }
217
218 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
219 let extension = self.extension.clone();
220 let provider_id = self.provider_info.id.clone();
221 let state = self.state.clone();
222 let credential_key = self.credential_key();
223
224 let credentials_provider = <dyn CredentialsProvider>::global(cx);
225
226 cx.spawn(async move |cx| {
227 // Delete from system keychain
228 credentials_provider
229 .delete_credentials(&credential_key, cx)
230 .await
231 .log_err();
232
233 // Call extension's reset_credentials
234 let result = extension
235 .call(|extension, store| {
236 async move {
237 extension
238 .call_llm_provider_reset_credentials(store, &provider_id)
239 .await
240 }
241 .boxed()
242 })
243 .await;
244
245 // Update state
246 cx.update(|cx| {
247 state.update(cx, |state, _| {
248 state.is_authenticated = false;
249 });
250 })?;
251
252 match result {
253 Ok(Ok(Ok(()))) => Ok(()),
254 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
255 Ok(Err(e)) => Err(e),
256 Err(e) => Err(e),
257 }
258 })
259 }
260}
261
262impl LanguageModelProviderState for ExtensionLanguageModelProvider {
263 type ObservableEntity = ExtensionLlmProviderState;
264
265 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
266 Some(self.state.clone())
267 }
268
269 fn subscribe<T: 'static>(
270 &self,
271 cx: &mut Context<T>,
272 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
273 ) -> Option<Subscription> {
274 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
275 }
276}
277
/// Configuration view for extension-based LLM providers.
///
/// It shows:
/// - the extension-provided settings markdown;
/// - an optional opt-in for reading the API key from an environment variable;
/// - an API key entry field backed by the system keychain.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// The extension that implements the provider.
    extension: WasmExtension,
    /// Provider id as known to the extension (without the extension prefix).
    extension_provider_id: String,
    /// `<extension id>:<provider id>`, as stored in `allowed_env_var_providers`.
    full_provider_id: String,
    /// Optional auth configuration (env var name, credential label).
    auth_config: Option<LanguageModelAuthConfig>,
    /// Provider state shared with `ExtensionLanguageModelProvider`.
    state: Entity<ExtensionLlmProviderState>,
    /// Rendered settings text, once loaded (None if the extension has none).
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter a new API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request completes.
    loading_settings: bool,
    /// True until the initial credential check completes.
    loading_credentials: bool,
    /// Kept alive for as long as the view exists.
    _subscriptions: Vec<Subscription>,
}
292
293impl ExtensionProviderConfigurationView {
294 fn new(
295 credential_key: String,
296 extension: WasmExtension,
297 extension_provider_id: String,
298 full_provider_id: String,
299 auth_config: Option<LanguageModelAuthConfig>,
300 state: Entity<ExtensionLlmProviderState>,
301 window: &mut Window,
302 cx: &mut Context<Self>,
303 ) -> Self {
304 // Subscribe to state changes
305 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
306 cx.notify();
307 });
308
309 // Create API key editor
310 let api_key_editor = cx.new(|cx| {
311 let mut editor = Editor::single_line(window, cx);
312 editor.set_placeholder_text("Enter API key...", window, cx);
313 editor
314 });
315
316 let mut this = Self {
317 credential_key,
318 extension,
319 extension_provider_id,
320 full_provider_id,
321 auth_config,
322 state,
323 settings_markdown: None,
324 api_key_editor,
325 loading_settings: true,
326 loading_credentials: true,
327 _subscriptions: vec![state_subscription],
328 };
329
330 // Load settings text from extension
331 this.load_settings_text(cx);
332
333 // Load existing credentials
334 this.load_credentials(cx);
335
336 this
337 }
338
339 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
340 let extension = self.extension.clone();
341 let provider_id = self.extension_provider_id.clone();
342
343 cx.spawn(async move |this, cx| {
344 let result = extension
345 .call({
346 let provider_id = provider_id.clone();
347 |ext, store| {
348 async move {
349 ext.call_llm_provider_settings_markdown(store, &provider_id)
350 .await
351 }
352 .boxed()
353 }
354 })
355 .await;
356
357 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
358
359 this.update(cx, |this, cx| {
360 this.loading_settings = false;
361 if let Some(text) = settings_text {
362 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
363 this.settings_markdown = Some(markdown);
364 }
365 cx.notify();
366 })
367 .log_err();
368 })
369 .detach();
370 }
371
372 fn load_credentials(&mut self, cx: &mut Context<Self>) {
373 let credential_key = self.credential_key.clone();
374 let credentials_provider = <dyn CredentialsProvider>::global(cx);
375 let state = self.state.clone();
376
377 // Check if we should use env var (already set in state during provider construction)
378 let api_key_from_env = self.state.read(cx).api_key_from_env;
379
380 cx.spawn(async move |this, cx| {
381 // If using env var, we're already authenticated
382 if api_key_from_env {
383 this.update(cx, |this, cx| {
384 this.loading_credentials = false;
385 cx.notify();
386 })
387 .log_err();
388 return;
389 }
390
391 let credentials = credentials_provider
392 .read_credentials(&credential_key, cx)
393 .await
394 .log_err()
395 .flatten();
396
397 let has_credentials = credentials.is_some();
398
399 // Update authentication state based on stored credentials
400 let _ = cx.update(|cx| {
401 state.update(cx, |state, cx| {
402 state.is_authenticated = has_credentials;
403 cx.notify();
404 });
405 });
406
407 this.update(cx, |this, cx| {
408 this.loading_credentials = false;
409 cx.notify();
410 })
411 .log_err();
412 })
413 .detach();
414 }
415
416 fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
417 let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
418 let env_var_name = match &self.auth_config {
419 Some(config) => config.env_var.clone(),
420 None => return,
421 };
422
423 let state = self.state.clone();
424 let currently_allowed = self.state.read(cx).env_var_allowed;
425
426 // Update settings file
427 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
428 let providers = settings
429 .extension
430 .allowed_env_var_providers
431 .get_or_insert_with(Vec::new);
432
433 if currently_allowed {
434 providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
435 } else {
436 if !providers
437 .iter()
438 .any(|id| id.as_ref() == full_provider_id.as_ref())
439 {
440 providers.push(full_provider_id.clone());
441 }
442 }
443 });
444
445 // Update local state
446 let new_allowed = !currently_allowed;
447 let new_from_env = if new_allowed {
448 if let Some(var_name) = &env_var_name {
449 if let Ok(value) = std::env::var(var_name) {
450 !value.is_empty()
451 } else {
452 false
453 }
454 } else {
455 false
456 }
457 } else {
458 false
459 };
460
461 state.update(cx, |state, cx| {
462 state.env_var_allowed = new_allowed;
463 state.api_key_from_env = new_from_env;
464 if new_from_env {
465 state.is_authenticated = true;
466 }
467 cx.notify();
468 });
469
470 // If env var is being enabled, clear any stored keychain credentials
471 // so there's only one source of truth for the API key
472 if new_allowed {
473 let credential_key = self.credential_key.clone();
474 let credentials_provider = <dyn CredentialsProvider>::global(cx);
475 cx.spawn(async move |_this, cx| {
476 credentials_provider
477 .delete_credentials(&credential_key, cx)
478 .await
479 .log_err();
480 })
481 .detach();
482 }
483
484 // If env var is being disabled, reload credentials from keychain
485 if !new_allowed {
486 self.reload_keychain_credentials(cx);
487 }
488
489 cx.notify();
490 }
491
492 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
493 let credential_key = self.credential_key.clone();
494 let credentials_provider = <dyn CredentialsProvider>::global(cx);
495 let state = self.state.clone();
496
497 cx.spawn(async move |_this, cx| {
498 let credentials = credentials_provider
499 .read_credentials(&credential_key, cx)
500 .await
501 .log_err()
502 .flatten();
503
504 let has_credentials = credentials.is_some();
505
506 let _ = cx.update(|cx| {
507 state.update(cx, |state, cx| {
508 state.is_authenticated = has_credentials;
509 cx.notify();
510 });
511 });
512 })
513 .detach();
514 }
515
516 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
517 let api_key = self.api_key_editor.read(cx).text(cx);
518 if api_key.is_empty() {
519 return;
520 }
521
522 // Clear the editor
523 self.api_key_editor
524 .update(cx, |editor, cx| editor.set_text("", window, cx));
525
526 let credential_key = self.credential_key.clone();
527 let credentials_provider = <dyn CredentialsProvider>::global(cx);
528 let state = self.state.clone();
529
530 cx.spawn(async move |_this, cx| {
531 // Store in system keychain
532 credentials_provider
533 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
534 .await
535 .log_err();
536
537 // Update state to authenticated
538 let _ = cx.update(|cx| {
539 state.update(cx, |state, cx| {
540 state.is_authenticated = true;
541 cx.notify();
542 });
543 });
544 })
545 .detach();
546 }
547
548 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
549 // Clear the editor
550 self.api_key_editor
551 .update(cx, |editor, cx| editor.set_text("", window, cx));
552
553 let credential_key = self.credential_key.clone();
554 let credentials_provider = <dyn CredentialsProvider>::global(cx);
555 let state = self.state.clone();
556
557 cx.spawn(async move |_this, cx| {
558 // Delete from system keychain
559 credentials_provider
560 .delete_credentials(&credential_key, cx)
561 .await
562 .log_err();
563
564 // Update state to unauthenticated
565 let _ = cx.update(|cx| {
566 state.update(cx, |state, cx| {
567 state.is_authenticated = false;
568 cx.notify();
569 });
570 });
571 })
572 .detach();
573 }
574
575 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
576 self.state.read(cx).is_authenticated
577 }
578}
579
580impl gpui::Render for ExtensionProviderConfigurationView {
581 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
582 let is_loading = self.loading_settings || self.loading_credentials;
583 let is_authenticated = self.is_authenticated(cx);
584 let env_var_allowed = self.state.read(cx).env_var_allowed;
585 let api_key_from_env = self.state.read(cx).api_key_from_env;
586
587 if is_loading {
588 return v_flex()
589 .gap_2()
590 .child(Label::new("Loading...").color(Color::Muted))
591 .into_any_element();
592 }
593
594 let mut content = v_flex().gap_4().size_full();
595
596 // Render settings markdown if available
597 if let Some(markdown) = &self.settings_markdown {
598 let style = settings_markdown_style(_window, cx);
599 content = content.child(
600 div()
601 .p_2()
602 .rounded_md()
603 .bg(cx.theme().colors().surface_background)
604 .child(MarkdownElement::new(markdown.clone(), style)),
605 );
606 }
607
608 // Render env var checkbox if the extension specifies an env var
609 if let Some(auth_config) = &self.auth_config {
610 if let Some(env_var_name) = &auth_config.env_var {
611 let env_var_name = env_var_name.clone();
612 let checkbox_label =
613 format!("Read API key from {} environment variable", env_var_name);
614
615 content = content.child(
616 h_flex()
617 .gap_2()
618 .child(
619 ui::Checkbox::new("env-var-permission", env_var_allowed.into())
620 .on_click(cx.listener(|this, _, _window, cx| {
621 this.toggle_env_var_permission(cx);
622 })),
623 )
624 .child(Label::new(checkbox_label).size(LabelSize::Small)),
625 );
626
627 // Show status if env var is allowed
628 if env_var_allowed {
629 if api_key_from_env {
630 content = content.child(
631 h_flex()
632 .gap_2()
633 .child(
634 ui::Icon::new(ui::IconName::Check)
635 .color(Color::Success)
636 .size(ui::IconSize::Small),
637 )
638 .child(
639 Label::new(format!("API key loaded from {}", env_var_name))
640 .color(Color::Success),
641 ),
642 );
643 return content.into_any_element();
644 } else {
645 content = content.child(
646 h_flex()
647 .gap_2()
648 .child(
649 ui::Icon::new(ui::IconName::Warning)
650 .color(Color::Warning)
651 .size(ui::IconSize::Small),
652 )
653 .child(
654 Label::new(format!(
655 "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
656 env_var_name
657 ))
658 .color(Color::Warning)
659 .size(LabelSize::Small),
660 ),
661 );
662 }
663 }
664 }
665 }
666
667 // Render API key section
668 if is_authenticated && !api_key_from_env {
669 content = content.child(
670 v_flex()
671 .gap_2()
672 .child(
673 h_flex()
674 .gap_2()
675 .child(
676 ui::Icon::new(ui::IconName::Check)
677 .color(Color::Success)
678 .size(ui::IconSize::Small),
679 )
680 .child(Label::new("API key configured").color(Color::Success)),
681 )
682 .child(
683 ui::Button::new("reset-api-key", "Reset API Key")
684 .style(ui::ButtonStyle::Subtle)
685 .on_click(cx.listener(|this, _, window, cx| {
686 this.reset_api_key(window, cx);
687 })),
688 ),
689 );
690 } else if !api_key_from_env {
691 let credential_label = self
692 .auth_config
693 .as_ref()
694 .and_then(|c| c.credential_label.clone())
695 .unwrap_or_else(|| "API Key".to_string());
696
697 content = content.child(
698 v_flex()
699 .gap_2()
700 .on_action(cx.listener(Self::save_api_key))
701 .child(
702 Label::new(credential_label)
703 .size(LabelSize::Small)
704 .color(Color::Muted),
705 )
706 .child(self.api_key_editor.clone())
707 .child(
708 Label::new("Enter your API key and press Enter to save")
709 .size(LabelSize::Small)
710 .color(Color::Muted),
711 ),
712 );
713 }
714
715 content.into_any_element()
716 }
717}
718
719impl Focusable for ExtensionProviderConfigurationView {
720 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
721 self.api_key_editor.focus_handle(cx)
722 }
723}
724
725fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
726 let theme_settings = ThemeSettings::get_global(cx);
727 let colors = cx.theme().colors();
728 let mut text_style = window.text_style();
729 text_style.refine(&TextStyleRefinement {
730 font_family: Some(theme_settings.ui_font.family.clone()),
731 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
732 font_features: Some(theme_settings.ui_font.features.clone()),
733 color: Some(colors.text),
734 ..Default::default()
735 });
736
737 MarkdownStyle {
738 base_text_style: text_style,
739 selection_background_color: colors.element_selection_background,
740 inline_code: TextStyleRefinement {
741 background_color: Some(colors.editor_background),
742 ..Default::default()
743 },
744 link: TextStyleRefinement {
745 color: Some(colors.text_accent),
746 underline: Some(UnderlineStyle {
747 color: Some(colors.text_accent.opacity(0.5)),
748 thickness: px(1.),
749 ..Default::default()
750 }),
751 ..Default::default()
752 },
753 syntax: cx.theme().syntax().clone(),
754 ..Default::default()
755 }
756}
757
/// An extension-based language model.
///
/// One instance per model advertised by an extension provider. Requests are
/// forwarded to the extension through its WASM interface.
pub struct ExtensionLanguageModel {
    /// The extension that implements the owning provider.
    extension: WasmExtension,
    /// Model descriptor (id, name, capabilities, limits) from the extension.
    model_info: LlmModelInfo,
    /// Zed-level provider id (`<extension id>:<provider id>`).
    provider_id: LanguageModelProviderId,
    /// Provider display name.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is what the extension expects in calls.
    provider_info: LlmProviderInfo,
}
766
767impl LanguageModel for ExtensionLanguageModel {
768 fn id(&self) -> LanguageModelId {
769 LanguageModelId::from(self.model_info.id.clone())
770 }
771
772 fn name(&self) -> LanguageModelName {
773 LanguageModelName::from(self.model_info.name.clone())
774 }
775
776 fn provider_id(&self) -> LanguageModelProviderId {
777 self.provider_id.clone()
778 }
779
780 fn provider_name(&self) -> LanguageModelProviderName {
781 self.provider_name.clone()
782 }
783
784 fn telemetry_id(&self) -> String {
785 format!("extension-{}", self.model_info.id)
786 }
787
788 fn supports_images(&self) -> bool {
789 self.model_info.capabilities.supports_images
790 }
791
792 fn supports_tools(&self) -> bool {
793 self.model_info.capabilities.supports_tools
794 }
795
796 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
797 match choice {
798 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
799 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
800 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
801 }
802 }
803
804 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
805 match self.model_info.capabilities.tool_input_format {
806 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
807 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
808 }
809 }
810
811 fn max_token_count(&self) -> u64 {
812 self.model_info.max_token_count
813 }
814
815 fn max_output_tokens(&self) -> Option<u64> {
816 self.model_info.max_output_tokens
817 }
818
819 fn count_tokens(
820 &self,
821 request: LanguageModelRequest,
822 cx: &App,
823 ) -> BoxFuture<'static, Result<u64>> {
824 let extension = self.extension.clone();
825 let provider_id = self.provider_info.id.clone();
826 let model_id = self.model_info.id.clone();
827
828 let wit_request = convert_request_to_wit(request);
829
830 cx.background_spawn(async move {
831 extension
832 .call({
833 let provider_id = provider_id.clone();
834 let model_id = model_id.clone();
835 let wit_request = wit_request.clone();
836 |ext, store| {
837 async move {
838 let count = ext
839 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
840 .await?
841 .map_err(|e| anyhow!("{}", e))?;
842 Ok(count)
843 }
844 .boxed()
845 }
846 })
847 .await?
848 })
849 .boxed()
850 }
851
852 fn stream_completion(
853 &self,
854 request: LanguageModelRequest,
855 _cx: &AsyncApp,
856 ) -> BoxFuture<
857 'static,
858 Result<
859 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
860 LanguageModelCompletionError,
861 >,
862 > {
863 let extension = self.extension.clone();
864 let provider_id = self.provider_info.id.clone();
865 let model_id = self.model_info.id.clone();
866
867 let wit_request = convert_request_to_wit(request);
868
869 async move {
870 // Start the stream
871 let stream_id_result = extension
872 .call({
873 let provider_id = provider_id.clone();
874 let model_id = model_id.clone();
875 let wit_request = wit_request.clone();
876 |ext, store| {
877 async move {
878 let id = ext
879 .call_llm_stream_completion_start(
880 store,
881 &provider_id,
882 &model_id,
883 &wit_request,
884 )
885 .await?
886 .map_err(|e| anyhow!("{}", e))?;
887 Ok(id)
888 }
889 .boxed()
890 }
891 })
892 .await;
893
894 let stream_id = stream_id_result
895 .map_err(LanguageModelCompletionError::Other)?
896 .map_err(LanguageModelCompletionError::Other)?;
897
898 // Create a stream that polls for events
899 let stream = futures::stream::unfold(
900 (extension.clone(), stream_id, false),
901 move |(extension, stream_id, done)| async move {
902 if done {
903 return None;
904 }
905
906 let result = extension
907 .call({
908 let stream_id = stream_id.clone();
909 |ext, store| {
910 async move {
911 let event = ext
912 .call_llm_stream_completion_next(store, &stream_id)
913 .await?
914 .map_err(|e| anyhow!("{}", e))?;
915 Ok(event)
916 }
917 .boxed()
918 }
919 })
920 .await
921 .and_then(|inner| inner);
922
923 match result {
924 Ok(Some(event)) => {
925 let converted = convert_completion_event(event);
926 let is_done =
927 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
928 Some((converted, (extension, stream_id, is_done)))
929 }
930 Ok(None) => {
931 // Stream complete, close it
932 let _ = extension
933 .call({
934 let stream_id = stream_id.clone();
935 |ext, store| {
936 async move {
937 ext.call_llm_stream_completion_close(store, &stream_id)
938 .await?;
939 Ok::<(), anyhow::Error>(())
940 }
941 .boxed()
942 }
943 })
944 .await;
945 None
946 }
947 Err(e) => Some((
948 Err(LanguageModelCompletionError::Other(e)),
949 (extension, stream_id, true),
950 )),
951 }
952 },
953 );
954
955 Ok(stream.boxed())
956 }
957 .boxed()
958 }
959
960 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
961 // Extensions can implement this via llm_cache_configuration
962 None
963 }
964}
965
966fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
967 use language_model::{MessageContent, Role};
968
969 let messages: Vec<LlmRequestMessage> = request
970 .messages
971 .into_iter()
972 .map(|msg| {
973 let role = match msg.role {
974 Role::User => LlmMessageRole::User,
975 Role::Assistant => LlmMessageRole::Assistant,
976 Role::System => LlmMessageRole::System,
977 };
978
979 let content: Vec<LlmMessageContent> = msg
980 .content
981 .into_iter()
982 .map(|c| match c {
983 MessageContent::Text(text) => LlmMessageContent::Text(text),
984 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
985 source: image.source.to_string(),
986 width: Some(image.size.width.0 as u32),
987 height: Some(image.size.height.0 as u32),
988 }),
989 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
990 id: tool_use.id.to_string(),
991 name: tool_use.name.to_string(),
992 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
993 thought_signature: tool_use.thought_signature,
994 }),
995 MessageContent::ToolResult(tool_result) => {
996 let content = match tool_result.content {
997 language_model::LanguageModelToolResultContent::Text(text) => {
998 LlmToolResultContent::Text(text.to_string())
999 }
1000 language_model::LanguageModelToolResultContent::Image(image) => {
1001 LlmToolResultContent::Image(LlmImageData {
1002 source: image.source.to_string(),
1003 width: Some(image.size.width.0 as u32),
1004 height: Some(image.size.height.0 as u32),
1005 })
1006 }
1007 };
1008 LlmMessageContent::ToolResult(LlmToolResult {
1009 tool_use_id: tool_result.tool_use_id.to_string(),
1010 tool_name: tool_result.tool_name.to_string(),
1011 is_error: tool_result.is_error,
1012 content,
1013 })
1014 }
1015 MessageContent::Thinking { text, signature } => {
1016 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1017 }
1018 MessageContent::RedactedThinking(data) => {
1019 LlmMessageContent::RedactedThinking(data)
1020 }
1021 })
1022 .collect();
1023
1024 LlmRequestMessage {
1025 role,
1026 content,
1027 cache: msg.cache,
1028 }
1029 })
1030 .collect();
1031
1032 let tools: Vec<LlmToolDefinition> = request
1033 .tools
1034 .into_iter()
1035 .map(|tool| LlmToolDefinition {
1036 name: tool.name,
1037 description: tool.description,
1038 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1039 })
1040 .collect();
1041
1042 let tool_choice = request.tool_choice.map(|tc| match tc {
1043 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1044 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1045 LanguageModelToolChoice::None => LlmToolChoice::None,
1046 });
1047
1048 LlmCompletionRequest {
1049 messages,
1050 tools,
1051 tool_choice,
1052 stop_sequences: request.stop,
1053 temperature: request.temperature,
1054 thinking_allowed: false,
1055 max_tokens: None,
1056 }
1057}
1058
1059fn convert_completion_event(
1060 event: LlmCompletionEvent,
1061) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1062 match event {
1063 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1064 message_id: String::new(),
1065 }),
1066 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1067 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1068 text: thinking.text,
1069 signature: thinking.signature,
1070 }),
1071 LlmCompletionEvent::RedactedThinking(data) => {
1072 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1073 }
1074 LlmCompletionEvent::ToolUse(tool_use) => {
1075 let raw_input = tool_use.input.clone();
1076 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1077 Ok(LanguageModelCompletionEvent::ToolUse(
1078 LanguageModelToolUse {
1079 id: LanguageModelToolUseId::from(tool_use.id),
1080 name: tool_use.name.into(),
1081 raw_input,
1082 input,
1083 is_input_complete: true,
1084 thought_signature: tool_use.thought_signature,
1085 },
1086 ))
1087 }
1088 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1089 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1090 id: LanguageModelToolUseId::from(error.id),
1091 tool_name: error.tool_name.into(),
1092 raw_input: error.raw_input.into(),
1093 json_parse_error: error.error,
1094 })
1095 }
1096 LlmCompletionEvent::Stop(reason) => {
1097 let stop_reason = match reason {
1098 LlmStopReason::EndTurn => StopReason::EndTurn,
1099 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1100 LlmStopReason::ToolUse => StopReason::ToolUse,
1101 LlmStopReason::Refusal => StopReason::Refusal,
1102 };
1103 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1104 }
1105 LlmCompletionEvent::Usage(usage) => {
1106 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1107 input_tokens: usage.input_tokens,
1108 output_tokens: usage.output_tokens,
1109 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1110 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1111 }))
1112 }
1113 LlmCompletionEvent::ReasoningDetails(json) => {
1114 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1115 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1116 ))
1117 }
1118 }
1119}