1use crate::ExtensionSettings;
2use crate::LEGACY_LLM_EXTENSION_IDS;
3use crate::wasm_host::WasmExtension;
4use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
5use collections::HashSet;
6
7use crate::wasm_host::wit::{
8 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
9 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
10 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
11 LlmToolUse,
12};
13use anyhow::{Result, anyhow};
14use credentials_provider::CredentialsProvider;
15use editor::Editor;
16use extension::{LanguageModelAuthConfig, OAuthConfig};
17use futures::future::BoxFuture;
18use futures::stream::BoxStream;
19use futures::{FutureExt, StreamExt};
20use gpui::Focusable;
21use gpui::{
22 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
23 TextStyleRefinement, UnderlineStyle, Window, px,
24};
25use language_model::tool_schema::LanguageModelToolSchemaFormat;
26use language_model::{
27 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
28 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
29 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
30 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
31 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
32};
33use markdown::{Markdown, MarkdownElement, MarkdownStyle};
34use settings::Settings;
35use std::sync::Arc;
36use theme::ThemeSettings;
37use ui::{Label, LabelSize, prelude::*};
38use util::ResultExt as _;
39use workspace::Workspace;
40use workspace::oauth_device_flow_modal::{
41 OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
42};
43
/// An extension-based language model provider.
///
/// Bundles the WASM extension that implements the provider with the metadata
/// it reported and the shared state entity tracking authentication and models.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata reported by the extension; its `id` and `name` are
    /// used to build the provider id and display name.
    pub provider_info: LlmProviderInfo,
    /// Optional icon path, returned as-is from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Authentication options (env vars, OAuth, credential label), if declared.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared authentication and model-list state.
    state: Entity<ExtensionLlmProviderState>,
}
52
/// Mutable state shared between an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Cached authentication flag; updated by the credential, env var and
    /// OAuth flows in this file.
    is_authenticated: bool,
    /// Models currently offered by the provider.
    available_models: Vec<LlmModelInfo>,
    /// Set of env var names that are allowed to be read for this provider.
    allowed_env_vars: HashSet<String>,
    /// If authenticated via env var, which one was used.
    env_var_name_used: Option<String>,
}
61
// The state entity emits unit events; the `cx.subscribe` calls on it in this
// file rely on this impl.
impl EventEmitter<()> for ExtensionLlmProviderState {}
63
64impl ExtensionLanguageModelProvider {
65 pub fn new(
66 extension: WasmExtension,
67 provider_info: LlmProviderInfo,
68 models: Vec<LlmModelInfo>,
69 is_authenticated: bool,
70 icon_path: Option<SharedString>,
71 auth_config: Option<LanguageModelAuthConfig>,
72 cx: &mut App,
73 ) -> Self {
74 let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
75
76 // Build set of allowed env vars for this provider
77 let settings = ExtensionSettings::get_global(cx);
78 let is_legacy_extension =
79 LEGACY_LLM_EXTENSION_IDS.contains(&extension.manifest.id.as_ref());
80
81 let mut allowed_env_vars = HashSet::default();
82 if let Some(env_vars) = auth_config.as_ref().and_then(|c| c.env_vars.as_ref()) {
83 for env_var_name in env_vars {
84 let key = format!("{}:{}", provider_id_string, env_var_name);
85 // For legacy extensions, auto-allow if env var is set (migration will persist this)
86 let env_var_is_set = std::env::var(env_var_name)
87 .map(|v| !v.is_empty())
88 .unwrap_or(false);
89 if settings.allowed_env_var_providers.contains(key.as_str())
90 || (is_legacy_extension && env_var_is_set)
91 {
92 allowed_env_vars.insert(env_var_name.clone());
93 }
94 }
95 }
96
97 // Check if any allowed env var is set
98 let env_var_name_used = allowed_env_vars.iter().find_map(|env_var_name| {
99 if let Ok(value) = std::env::var(env_var_name) {
100 if !value.is_empty() {
101 return Some(env_var_name.clone());
102 }
103 }
104 None
105 });
106
107 let is_authenticated = if env_var_name_used.is_some() {
108 true
109 } else {
110 is_authenticated
111 };
112
113 let state = cx.new(|_| ExtensionLlmProviderState {
114 is_authenticated,
115 available_models: models,
116 allowed_env_vars,
117 env_var_name_used,
118 });
119
120 Self {
121 extension,
122 provider_info,
123 icon_path,
124 auth_config,
125 state,
126 }
127 }
128
129 fn provider_id_string(&self) -> String {
130 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
131 }
132
133 /// The credential key used for storing the API key in the system keychain.
134 fn credential_key(&self) -> String {
135 format!("extension-llm-{}", self.provider_id_string())
136 }
137}
138
139impl LanguageModelProvider for ExtensionLanguageModelProvider {
140 fn id(&self) -> LanguageModelProviderId {
141 LanguageModelProviderId::from(self.provider_id_string())
142 }
143
144 fn name(&self) -> LanguageModelProviderName {
145 LanguageModelProviderName::from(format!("(Extension) {}", self.provider_info.name))
146 }
147
148 fn icon(&self) -> ui::IconName {
149 ui::IconName::ZedAssistant
150 }
151
152 fn icon_path(&self) -> Option<SharedString> {
153 self.icon_path.clone()
154 }
155
156 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
157 let state = self.state.read(cx);
158 state
159 .available_models
160 .iter()
161 .find(|m| m.is_default)
162 .or_else(|| state.available_models.first())
163 .map(|model_info| {
164 Arc::new(ExtensionLanguageModel {
165 extension: self.extension.clone(),
166 model_info: model_info.clone(),
167 provider_id: self.id(),
168 provider_name: self.name(),
169 provider_info: self.provider_info.clone(),
170 }) as Arc<dyn LanguageModel>
171 })
172 }
173
174 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
175 let state = self.state.read(cx);
176 state
177 .available_models
178 .iter()
179 .find(|m| m.is_default_fast)
180 .map(|model_info| {
181 Arc::new(ExtensionLanguageModel {
182 extension: self.extension.clone(),
183 model_info: model_info.clone(),
184 provider_id: self.id(),
185 provider_name: self.name(),
186 provider_info: self.provider_info.clone(),
187 }) as Arc<dyn LanguageModel>
188 })
189 }
190
191 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
192 let state = self.state.read(cx);
193 state
194 .available_models
195 .iter()
196 .map(|model_info| {
197 Arc::new(ExtensionLanguageModel {
198 extension: self.extension.clone(),
199 model_info: model_info.clone(),
200 provider_id: self.id(),
201 provider_name: self.name(),
202 provider_info: self.provider_info.clone(),
203 }) as Arc<dyn LanguageModel>
204 })
205 .collect()
206 }
207
208 fn is_authenticated(&self, cx: &App) -> bool {
209 // First check cached state
210 if self.state.read(cx).is_authenticated {
211 return true;
212 }
213
214 // Also check env var dynamically (in case settings changed after provider creation)
215 if let Some(ref auth_config) = self.auth_config {
216 if let Some(ref env_vars) = auth_config.env_vars {
217 let provider_id_string = self.provider_id_string();
218 let settings = ExtensionSettings::get_global(cx);
219 let is_legacy_extension =
220 LEGACY_LLM_EXTENSION_IDS.contains(&self.extension.manifest.id.as_ref());
221
222 for env_var_name in env_vars {
223 let key = format!("{}:{}", provider_id_string, env_var_name);
224 // For legacy extensions, auto-allow if env var is set
225 let env_var_is_set = std::env::var(env_var_name)
226 .map(|v| !v.is_empty())
227 .unwrap_or(false);
228 if settings.allowed_env_var_providers.contains(key.as_str())
229 || (is_legacy_extension && env_var_is_set)
230 {
231 if let Ok(value) = std::env::var(env_var_name) {
232 if !value.is_empty() {
233 return true;
234 }
235 }
236 }
237 }
238 }
239 }
240
241 false
242 }
243
244 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
245 // Check if already authenticated via is_authenticated
246 if self.is_authenticated(cx) {
247 return Task::ready(Ok(()));
248 }
249
250 // Not authenticated - return error indicating credentials not found
251 Task::ready(Err(AuthenticateError::CredentialsNotFound))
252 }
253
254 fn configuration_view(
255 &self,
256 _target_agent: ConfigurationViewTargetAgent,
257 window: &mut Window,
258 cx: &mut App,
259 ) -> AnyView {
260 let credential_key = self.credential_key();
261 let extension = self.extension.clone();
262 let extension_provider_id = self.provider_info.id.clone();
263 let full_provider_id = self.provider_id_string();
264 let state = self.state.clone();
265 let auth_config = self.auth_config.clone();
266
267 let icon_path = self.icon_path.clone();
268 cx.new(|cx| {
269 ExtensionProviderConfigurationView::new(
270 credential_key,
271 extension,
272 extension_provider_id,
273 full_provider_id,
274 auth_config,
275 state,
276 icon_path,
277 window,
278 cx,
279 )
280 })
281 .into()
282 }
283
284 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
285 let extension = self.extension.clone();
286 let provider_id = self.provider_info.id.clone();
287 let state = self.state.clone();
288 let credential_key = self.credential_key();
289
290 let credentials_provider = <dyn CredentialsProvider>::global(cx);
291
292 cx.spawn(async move |cx| {
293 // Delete from system keychain
294 credentials_provider
295 .delete_credentials(&credential_key, cx)
296 .await
297 .log_err();
298
299 // Call extension's reset_credentials
300 let result = extension
301 .call(|extension, store| {
302 async move {
303 extension
304 .call_llm_provider_reset_credentials(store, &provider_id)
305 .await
306 }
307 .boxed()
308 })
309 .await;
310
311 // Update state
312 cx.update(|cx| {
313 state.update(cx, |state, _| {
314 state.is_authenticated = false;
315 });
316 })?;
317
318 match result {
319 Ok(Ok(Ok(()))) => Ok(()),
320 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
321 Ok(Err(e)) => Err(e),
322 Err(e) => Err(e),
323 }
324 })
325 }
326}
327
328impl LanguageModelProviderState for ExtensionLanguageModelProvider {
329 type ObservableEntity = ExtensionLlmProviderState;
330
331 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
332 Some(self.state.clone())
333 }
334
335 fn subscribe<T: 'static>(
336 &self,
337 cx: &mut Context<T>,
338 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
339 ) -> Option<Subscription> {
340 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
341 }
342}
343
/// Configuration view for extension-based LLM providers.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Handle to the WASM extension implementing the provider.
    extension: WasmExtension,
    /// Provider id as known to the extension (passed to extension calls).
    extension_provider_id: String,
    /// Provider id used as the prefix of env var permission keys in settings.
    full_provider_id: String,
    /// Authentication options declared for the provider, if any.
    auth_config: Option<LanguageModelAuthConfig>,
    /// Shared provider state (authentication, models, env var permissions).
    state: Entity<ExtensionLlmProviderState>,
    /// Settings markdown returned by the extension, once loaded.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter the API key.
    api_key_editor: Entity<Editor>,
    /// `true` until the settings markdown request has finished.
    loading_settings: bool,
    /// `true` until the initial credentials check has finished.
    loading_credentials: bool,
    /// `true` while an OAuth device-flow sign-in is running.
    oauth_in_progress: bool,
    /// Error message of the last failed OAuth attempt, if any.
    oauth_error: Option<String>,
    /// Optional icon path forwarded to the OAuth modal.
    icon_path: Option<SharedString>,
    /// Keeps the view's subscriptions alive for as long as the view exists.
    _subscriptions: Vec<Subscription>,
}
361
362impl ExtensionProviderConfigurationView {
363 fn new(
364 credential_key: String,
365 extension: WasmExtension,
366 extension_provider_id: String,
367 full_provider_id: String,
368 auth_config: Option<LanguageModelAuthConfig>,
369 state: Entity<ExtensionLlmProviderState>,
370 icon_path: Option<SharedString>,
371 window: &mut Window,
372 cx: &mut Context<Self>,
373 ) -> Self {
374 // Subscribe to state changes
375 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
376 cx.notify();
377 });
378
379 // Create API key editor
380 let api_key_editor = cx.new(|cx| {
381 let mut editor = Editor::single_line(window, cx);
382 editor.set_placeholder_text("Enter API key...", window, cx);
383 editor
384 });
385
386 let mut this = Self {
387 credential_key,
388 extension,
389 extension_provider_id,
390 full_provider_id,
391 auth_config,
392 state,
393 settings_markdown: None,
394 api_key_editor,
395 loading_settings: true,
396 loading_credentials: true,
397 oauth_in_progress: false,
398 oauth_error: None,
399 icon_path,
400 _subscriptions: vec![state_subscription],
401 };
402
403 // Load settings text from extension
404 this.load_settings_text(cx);
405
406 // Load existing credentials
407 this.load_credentials(cx);
408
409 this
410 }
411
412 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
413 let extension = self.extension.clone();
414 let provider_id = self.extension_provider_id.clone();
415
416 cx.spawn(async move |this, cx| {
417 let result = extension
418 .call({
419 let provider_id = provider_id.clone();
420 |ext, store| {
421 async move {
422 ext.call_llm_provider_settings_markdown(store, &provider_id)
423 .await
424 }
425 .boxed()
426 }
427 })
428 .await;
429
430 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
431
432 this.update(cx, |this, cx| {
433 this.loading_settings = false;
434 if let Some(text) = settings_text {
435 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
436 this.settings_markdown = Some(markdown);
437 }
438 cx.notify();
439 })
440 .log_err();
441 })
442 .detach();
443 }
444
445 fn load_credentials(&mut self, cx: &mut Context<Self>) {
446 let credential_key = self.credential_key.clone();
447 let credentials_provider = <dyn CredentialsProvider>::global(cx);
448 let state = self.state.clone();
449
450 // Check if we should use env var (already set in state during provider construction)
451 let using_env_var = self.state.read(cx).env_var_name_used.is_some();
452
453 cx.spawn(async move |this, cx| {
454 // If using env var, we're already authenticated
455 if using_env_var {
456 this.update(cx, |this, cx| {
457 this.loading_credentials = false;
458 cx.notify();
459 })
460 .log_err();
461 return;
462 }
463
464 let credentials = credentials_provider
465 .read_credentials(&credential_key, cx)
466 .await
467 .log_err()
468 .flatten();
469
470 let has_credentials = credentials.is_some();
471
472 // Update authentication state based on stored credentials
473 cx.update(|cx| {
474 state.update(cx, |state, cx| {
475 state.is_authenticated = has_credentials;
476 cx.notify();
477 });
478 })
479 .log_err();
480
481 this.update(cx, |this, cx| {
482 this.loading_credentials = false;
483 cx.notify();
484 })
485 .log_err();
486 })
487 .detach();
488 }
489
490 fn toggle_env_var_permission(&mut self, env_var_name: String, cx: &mut Context<Self>) {
491 let full_provider_id = self.full_provider_id.clone();
492 let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
493
494 let state = self.state.clone();
495 let currently_allowed = self.state.read(cx).allowed_env_vars.contains(&env_var_name);
496
497 // Update settings file
498 settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
499 move |settings, _| {
500 let allowed = settings
501 .extension
502 .allowed_env_var_providers
503 .get_or_insert_with(Vec::new);
504
505 if currently_allowed {
506 allowed.retain(|id| id.as_ref() != settings_key.as_ref());
507 } else {
508 if !allowed
509 .iter()
510 .any(|id| id.as_ref() == settings_key.as_ref())
511 {
512 allowed.push(settings_key.clone());
513 }
514 }
515 }
516 });
517
518 // Update local state
519 let new_allowed = !currently_allowed;
520
521 state.update(cx, |state, cx| {
522 if new_allowed {
523 state.allowed_env_vars.insert(env_var_name.clone());
524 // Check if this env var is set and update env_var_name_used
525 if let Ok(value) = std::env::var(&env_var_name) {
526 if !value.is_empty() && state.env_var_name_used.is_none() {
527 state.env_var_name_used = Some(env_var_name.clone());
528 state.is_authenticated = true;
529 }
530 }
531 } else {
532 state.allowed_env_vars.remove(&env_var_name);
533 // If this was the env var being used, clear it and find another
534 if state.env_var_name_used.as_ref() == Some(&env_var_name) {
535 state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
536 if let Ok(value) = std::env::var(var) {
537 if !value.is_empty() {
538 return Some(var.clone());
539 }
540 }
541 None
542 });
543 if state.env_var_name_used.is_none() {
544 // No env var auth available, need to check keychain
545 state.is_authenticated = false;
546 }
547 }
548 }
549 cx.notify();
550 });
551
552 // If all env vars are being disabled, reload credentials from keychain
553 if !new_allowed && self.state.read(cx).allowed_env_vars.is_empty() {
554 self.reload_keychain_credentials(cx);
555 }
556
557 cx.notify();
558 }
559
560 fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
561 let credential_key = self.credential_key.clone();
562 let credentials_provider = <dyn CredentialsProvider>::global(cx);
563 let state = self.state.clone();
564
565 cx.spawn(async move |_this, cx| {
566 let credentials = credentials_provider
567 .read_credentials(&credential_key, cx)
568 .await
569 .log_err()
570 .flatten();
571
572 let has_credentials = credentials.is_some();
573
574 cx.update(|cx| {
575 state.update(cx, |state, cx| {
576 state.is_authenticated = has_credentials;
577 cx.notify();
578 });
579 })
580 .log_err();
581 })
582 .detach();
583 }
584
585 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
586 let api_key = self.api_key_editor.read(cx).text(cx);
587 if api_key.is_empty() {
588 return;
589 }
590
591 // Clear the editor
592 self.api_key_editor
593 .update(cx, |editor, cx| editor.set_text("", window, cx));
594
595 let credential_key = self.credential_key.clone();
596 let credentials_provider = <dyn CredentialsProvider>::global(cx);
597 let state = self.state.clone();
598
599 cx.spawn(async move |_this, cx| {
600 // Store in system keychain
601 credentials_provider
602 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
603 .await
604 .log_err();
605
606 // Update state to authenticated
607 cx.update(|cx| {
608 state.update(cx, |state, cx| {
609 state.is_authenticated = true;
610 cx.notify();
611 });
612 })
613 .log_err();
614 })
615 .detach();
616 }
617
618 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
619 // Clear the editor
620 self.api_key_editor
621 .update(cx, |editor, cx| editor.set_text("", window, cx));
622
623 let credential_key = self.credential_key.clone();
624 let credentials_provider = <dyn CredentialsProvider>::global(cx);
625 let state = self.state.clone();
626
627 cx.spawn(async move |_this, cx| {
628 // Delete from system keychain
629 credentials_provider
630 .delete_credentials(&credential_key, cx)
631 .await
632 .log_err();
633
634 // Update state to unauthenticated
635 cx.update(|cx| {
636 state.update(cx, |state, cx| {
637 state.is_authenticated = false;
638 cx.notify();
639 });
640 })
641 .log_err();
642 })
643 .detach();
644 }
645
646 fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
647 if self.oauth_in_progress {
648 return;
649 }
650
651 self.oauth_in_progress = true;
652 self.oauth_error = None;
653 cx.notify();
654
655 let extension = self.extension.clone();
656 let provider_id = self.extension_provider_id.clone();
657 let state = self.state.clone();
658 let icon_path = self.icon_path.clone();
659 let this_handle = cx.weak_entity();
660
661 // Get workspace to show modal
662 let Some(workspace) = window.root::<Workspace>().flatten() else {
663 self.oauth_in_progress = false;
664 self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string());
665 cx.notify();
666 return;
667 };
668
669 let workspace = workspace.downgrade();
670 let state = state.downgrade();
671 cx.spawn_in(window, async move |_this, cx| {
672 // Step 1: Start device flow - get prompt info from extension
673 let start_result = extension
674 .call({
675 let provider_id = provider_id.clone();
676 |ext, store| {
677 async move {
678 ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
679 .await
680 }
681 .boxed()
682 }
683 })
684 .await;
685
686 let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
687 Ok(Ok(Ok(info))) => info,
688 Ok(Ok(Err(e))) => {
689 log::error!("Device flow start failed: {}", e);
690 this_handle
691 .update_in(cx, |this, _window, cx| {
692 this.oauth_in_progress = false;
693 this.oauth_error = Some(e);
694 cx.notify();
695 })
696 .log_err();
697 return;
698 }
699 Ok(Err(e)) | Err(e) => {
700 log::error!("Device flow start error: {}", e);
701 this_handle
702 .update_in(cx, |this, _window, cx| {
703 this.oauth_in_progress = false;
704 this.oauth_error = Some(e.to_string());
705 cx.notify();
706 })
707 .log_err();
708 return;
709 }
710 };
711
712 // Step 2: Create state entity and show the modal
713 let modal_config = OAuthDeviceFlowModalConfig {
714 user_code: prompt_info.user_code,
715 verification_url: prompt_info.verification_url,
716 headline: prompt_info.headline,
717 description: prompt_info.description,
718 connect_button_label: prompt_info.connect_button_label,
719 success_headline: prompt_info.success_headline,
720 success_message: prompt_info.success_message,
721 icon_path,
722 };
723
724 let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace
725 .update_in(cx, |workspace, window, cx| {
726 let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
727 let flow_state_clone = flow_state.clone();
728 workspace.toggle_modal(window, cx, |_window, cx| {
729 OAuthDeviceFlowModal::new(flow_state_clone, cx)
730 });
731 flow_state
732 })
733 .ok();
734
735 let Some(flow_state) = flow_state else {
736 this_handle
737 .update_in(cx, |this, _window, cx| {
738 this.oauth_in_progress = false;
739 this.oauth_error = Some("Failed to show sign-in modal".to_string());
740 cx.notify();
741 })
742 .log_err();
743 return;
744 };
745
746 // Step 3: Poll for authentication completion
747 let poll_result = extension
748 .call({
749 let provider_id = provider_id.clone();
750 |ext, store| {
751 async move {
752 ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
753 .await
754 }
755 .boxed()
756 }
757 })
758 .await;
759
760 match poll_result {
761 Ok(Ok(Ok(()))) => {
762 // After successful auth, refresh the models list
763 let models_result = extension
764 .call({
765 let provider_id = provider_id.clone();
766 |ext, store| {
767 async move {
768 ext.call_llm_provider_models(store, &provider_id).await
769 }
770 .boxed()
771 }
772 })
773 .await;
774
775 let new_models: Vec<LlmModelInfo> = match models_result {
776 Ok(Ok(Ok(models))) => models,
777 _ => Vec::new(),
778 };
779
780 state
781 .update_in(cx, |state, _window, cx| {
782 state.is_authenticated = true;
783 state.available_models = new_models;
784 cx.notify();
785 })
786 .log_err();
787
788 // Update flow state to show success
789 flow_state
790 .update_in(cx, |state, _window, cx| {
791 state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
792 })
793 .log_err();
794 }
795 Ok(Ok(Err(e))) => {
796 log::error!("Device flow poll failed: {}", e);
797 flow_state
798 .update_in(cx, |state, _window, cx| {
799 state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
800 })
801 .log_err();
802 this_handle
803 .update_in(cx, |this, _window, cx| {
804 this.oauth_error = Some(e);
805 cx.notify();
806 })
807 .log_err();
808 }
809 Ok(Err(e)) | Err(e) => {
810 log::error!("Device flow poll error: {}", e);
811 let error_string = e.to_string();
812 flow_state
813 .update_in(cx, |state, _window, cx| {
814 state.set_status(
815 OAuthDeviceFlowStatus::Failed(error_string.clone()),
816 cx,
817 );
818 })
819 .log_err();
820 this_handle
821 .update_in(cx, |this, _window, cx| {
822 this.oauth_error = Some(error_string);
823 cx.notify();
824 })
825 .log_err();
826 }
827 };
828
829 this_handle
830 .update_in(cx, |this, _window, cx| {
831 this.oauth_in_progress = false;
832 cx.notify();
833 })
834 .log_err();
835 })
836 .detach();
837 }
838
839 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
840 self.state.read(cx).is_authenticated
841 }
842
843 fn has_oauth_config(&self) -> bool {
844 self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
845 }
846
847 fn oauth_config(&self) -> Option<&OAuthConfig> {
848 self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
849 }
850
851 fn has_api_key_config(&self) -> bool {
852 // API key is available if there's a credential_label or no oauth-only config
853 self.auth_config
854 .as_ref()
855 .map(|c| c.credential_label.is_some() || c.oauth.is_none())
856 .unwrap_or(true)
857 }
858}
859
impl gpui::Render for ExtensionProviderConfigurationView {
    /// Renders the configuration UI for the provider.
    ///
    /// The layout is built top to bottom and returns early in several cases:
    ///
    /// 1. While settings or credentials are still loading, only a "Loading..."
    ///    label is shown.
    /// 2. Otherwise the settings markdown (if any) and one checkbox per declared
    ///    env var are rendered. If an env var is in use, a status row with a
    ///    disabled reset button follows and rendering stops there.
    /// 3. If authenticated by other means, a status row with a working
    ///    reset/sign-out button is shown and rendering stops there.
    /// 4. Otherwise the OAuth sign-in button and/or the API key input are shown,
    ///    followed by a hint for the provider whose id is `"openai"`.
    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_loading = self.loading_settings || self.loading_credentials;
        let is_authenticated = self.is_authenticated(cx);
        // Cloned so the state borrow on `cx` does not outlive these statements.
        let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
        let env_var_name_used = self.state.read(cx).env_var_name_used.clone();
        let has_oauth = self.has_oauth_config();
        let has_api_key = self.has_api_key_config();

        if is_loading {
            return v_flex()
                .gap_2()
                .child(Label::new("Loading...").color(Color::Muted))
                .into_any_element();
        }

        let mut content = v_flex().gap_4().size_full();

        // Render settings markdown if available
        if let Some(markdown) = &self.settings_markdown {
            let style = settings_markdown_style(window, cx);
            content = content.child(MarkdownElement::new(markdown.clone(), style));
        }

        // Render env var checkboxes - one for each env var the extension declares
        if let Some(auth_config) = &self.auth_config {
            if let Some(env_vars) = &auth_config.env_vars {
                for env_var_name in env_vars {
                    let is_allowed = allowed_env_vars.contains(env_var_name);
                    let checkbox_label =
                        format!("Read API key from {} environment variable", env_var_name);
                    // Owned copy for the click handler, which must be `'static`.
                    let env_var_for_click = env_var_name.clone();

                    content = content.child(
                        h_flex()
                            .gap_2()
                            .child(
                                ui::Checkbox::new(
                                    SharedString::from(format!("env-var-{}", env_var_name)),
                                    is_allowed.into(),
                                )
                                .on_click(cx.listener(
                                    move |this, _, _window, cx| {
                                        this.toggle_env_var_permission(
                                            env_var_for_click.clone(),
                                            cx,
                                        );
                                    },
                                )),
                            )
                            .child(Label::new(checkbox_label).size(LabelSize::Small)),
                    );
                }

                // Show status if any env var is being used
                if let Some(used_var) = &env_var_name_used {
                    let tooltip_label = format!(
                        "To reset this API key, unset the {} environment variable.",
                        used_var
                    );
                    content = content.child(
                        h_flex()
                            .mt_0p5()
                            .p_1()
                            .justify_between()
                            .rounded_md()
                            .border_1()
                            .border_color(cx.theme().colors().border)
                            .bg(cx.theme().colors().background)
                            .child(
                                h_flex()
                                    .flex_1()
                                    .min_w_0()
                                    .gap_1()
                                    .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                                    .child(
                                        Label::new(format!(
                                            "API key set in {} environment variable",
                                            used_var
                                        ))
                                        .truncate(),
                                    ),
                            )
                            .child(
                                // Disabled: the tooltip explains that the key can
                                // only be reset by unsetting the env var.
                                ui::Button::new("reset-key", "Reset Key")
                                    .label_size(LabelSize::Small)
                                    .icon(ui::IconName::Undo)
                                    .icon_size(ui::IconSize::Small)
                                    .icon_color(Color::Muted)
                                    .icon_position(ui::IconPosition::Start)
                                    .disabled(true)
                                    .tooltip(ui::Tooltip::text(tooltip_label)),
                            ),
                    );
                    // Env var auth is active: no other auth options are offered.
                    return content.into_any_element();
                }
            }
        }

        // If authenticated, show success state with sign out option
        if is_authenticated && env_var_name_used.is_none() {
            // OAuth-only providers get sign-in wording; everything else is
            // described in terms of an API key.
            let reset_label = if has_oauth && !has_api_key {
                "Sign Out"
            } else {
                "Reset Key"
            };

            let status_label = if has_oauth && !has_api_key {
                "Signed in"
            } else {
                "API key configured"
            };

            content = content.child(
                h_flex()
                    .mt_0p5()
                    .p_1()
                    .justify_between()
                    .rounded_md()
                    .border_1()
                    .border_color(cx.theme().colors().border)
                    .bg(cx.theme().colors().background)
                    .child(
                        h_flex()
                            .flex_1()
                            .min_w_0()
                            .gap_1()
                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
                            .child(Label::new(status_label).truncate()),
                    )
                    .child(
                        ui::Button::new("reset-key", reset_label)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::Undo)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::Start)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.reset_api_key(window, cx);
                            })),
                    ),
            );

            return content.into_any_element();
        }

        // Not authenticated - show available auth options
        if env_var_name_used.is_none() {
            // Render OAuth sign-in button if configured
            if has_oauth {
                let oauth_config = self.oauth_config();
                let button_label = oauth_config
                    .and_then(|c| c.sign_in_button_label.clone())
                    .unwrap_or_else(|| "Sign In".to_string());
                // Only the "github" icon name is mapped; any other name yields
                // a button without icon.
                let button_icon = oauth_config
                    .and_then(|c| c.sign_in_button_icon.as_ref())
                    .and_then(|icon_name| match icon_name.as_str() {
                        "github" => Some(ui::IconName::Github),
                        _ => None,
                    });

                let oauth_in_progress = self.oauth_in_progress;

                let oauth_error = self.oauth_error.clone();

                let mut button = ui::Button::new("oauth-sign-in", button_label)
                    .full_width()
                    .style(ui::ButtonStyle::Outlined)
                    .disabled(oauth_in_progress)
                    .on_click(cx.listener(|this, _, window, cx| {
                        this.start_oauth_sign_in(window, cx);
                    }));
                if let Some(icon) = button_icon {
                    button = button
                        .icon(icon)
                        .icon_position(ui::IconPosition::Start)
                        .icon_size(ui::IconSize::Small)
                        .icon_color(Color::Muted);
                }

                content = content.child(
                    v_flex()
                        .gap_2()
                        .child(button)
                        .when(oauth_in_progress, |this| {
                            this.child(
                                Label::new("Sign-in in progress...")
                                    .size(LabelSize::Small)
                                    .color(Color::Muted),
                            )
                        })
                        .when_some(oauth_error, |this, error| {
                            this.child(
                                v_flex()
                                    .gap_1()
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(
                                                ui::Icon::new(ui::IconName::Warning)
                                                    .color(Color::Error)
                                                    .size(ui::IconSize::Small),
                                            )
                                            .child(
                                                Label::new("Authentication failed")
                                                    .color(Color::Error)
                                                    .size(LabelSize::Small),
                                            ),
                                    )
                                    .child(
                                        div().pl_6().child(
                                            Label::new(error)
                                                .color(Color::Error)
                                                .size(LabelSize::Small),
                                        ),
                                    ),
                            )
                        }),
                );
            }

            // Render API key input if configured (and we have both options, show a separator)
            if has_api_key {
                if has_oauth {
                    content = content.child(
                        h_flex()
                            .gap_2()
                            .items_center()
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border))
                            .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
                            .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
                    );
                }

                let credential_label = self
                    .auth_config
                    .as_ref()
                    .and_then(|c| c.credential_label.clone())
                    .unwrap_or_else(|| "API Key".to_string());

                content = content.child(
                    v_flex()
                        .gap_2()
                        // The confirm action on this container saves the key.
                        .on_action(cx.listener(Self::save_api_key))
                        .child(
                            Label::new(credential_label)
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        )
                        .child(self.api_key_editor.clone())
                        .child(
                            Label::new("Enter your API key and press Enter to save")
                                .size(LabelSize::Small)
                                .color(Color::Muted),
                        ),
                );
            }
        }

        // Show OpenAI-compatible models notification for OpenAI extension
        if self.extension_provider_id == "openai" {
            content = content.child(
                h_flex()
                    .gap_1()
                    .child(
                        ui::Icon::new(ui::IconName::Info)
                            .size(ui::IconSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        Label::new("Zed also supports OpenAI-compatible models.")
                            .size(LabelSize::Small)
                            .color(Color::Muted),
                    )
                    .child(
                        ui::Button::new("learn-more", "Learn More")
                            .style(ui::ButtonStyle::Subtle)
                            .label_size(LabelSize::Small)
                            .icon(ui::IconName::ArrowUpRight)
                            .icon_size(ui::IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(ui::IconPosition::End)
                            .on_click(|_, _, cx| {
                                cx.open_url("https://zed.dev/docs/configuring-llm-providers#openai-compatible-providers");
                            }),
                    ),
            );
        }

        content.into_any_element()
    }
}
1152
1153impl Focusable for ExtensionProviderConfigurationView {
1154 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
1155 self.api_key_editor.focus_handle(cx)
1156 }
1157}
1158
1159fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
1160 let theme_settings = ThemeSettings::get_global(cx);
1161 let colors = cx.theme().colors();
1162 let mut text_style = window.text_style();
1163 text_style.refine(&TextStyleRefinement {
1164 font_family: Some(theme_settings.ui_font.family.clone()),
1165 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
1166 font_features: Some(theme_settings.ui_font.features.clone()),
1167 color: Some(colors.text),
1168 ..Default::default()
1169 });
1170
1171 MarkdownStyle {
1172 base_text_style: text_style,
1173 selection_background_color: colors.element_selection_background,
1174 inline_code: TextStyleRefinement {
1175 background_color: Some(colors.editor_background),
1176 ..Default::default()
1177 },
1178 link: TextStyleRefinement {
1179 color: Some(colors.text_accent),
1180 underline: Some(UnderlineStyle {
1181 color: Some(colors.text_accent.opacity(0.5)),
1182 thickness: px(1.),
1183 ..Default::default()
1184 }),
1185 ..Default::default()
1186 },
1187 syntax: cx.theme().syntax().clone(),
1188 ..Default::default()
1189 }
1190}
1191
/// An extension-based language model.
///
/// Implements [`LanguageModel`] by forwarding token counting and streaming
/// completion requests to a WASM extension.
pub struct ExtensionLanguageModel {
    /// Handle used to invoke the extension's LLM entry points.
    extension: WasmExtension,
    /// Model metadata: id, display name, capabilities and token limits.
    model_info: LlmModelInfo,
    /// Provider id returned from [`LanguageModel::provider_id`].
    provider_id: LanguageModelProviderId,
    /// Provider name returned from [`LanguageModel::provider_name`].
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is the provider id passed to the extension
    /// on each call.
    provider_info: LlmProviderInfo,
}
1200
1201impl LanguageModel for ExtensionLanguageModel {
1202 fn id(&self) -> LanguageModelId {
1203 LanguageModelId::from(self.model_info.id.clone())
1204 }
1205
1206 fn name(&self) -> LanguageModelName {
1207 LanguageModelName::from(self.model_info.name.clone())
1208 }
1209
1210 fn provider_id(&self) -> LanguageModelProviderId {
1211 self.provider_id.clone()
1212 }
1213
1214 fn provider_name(&self) -> LanguageModelProviderName {
1215 self.provider_name.clone()
1216 }
1217
1218 fn telemetry_id(&self) -> String {
1219 format!("extension-{}", self.model_info.id)
1220 }
1221
1222 fn supports_images(&self) -> bool {
1223 self.model_info.capabilities.supports_images
1224 }
1225
1226 fn supports_tools(&self) -> bool {
1227 self.model_info.capabilities.supports_tools
1228 }
1229
1230 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
1231 match choice {
1232 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
1233 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
1234 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
1235 }
1236 }
1237
1238 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
1239 match self.model_info.capabilities.tool_input_format {
1240 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
1241 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
1242 }
1243 }
1244
1245 fn max_token_count(&self) -> u64 {
1246 self.model_info.max_token_count
1247 }
1248
1249 fn max_output_tokens(&self) -> Option<u64> {
1250 self.model_info.max_output_tokens
1251 }
1252
1253 fn count_tokens(
1254 &self,
1255 request: LanguageModelRequest,
1256 cx: &App,
1257 ) -> BoxFuture<'static, Result<u64>> {
1258 let extension = self.extension.clone();
1259 let provider_id = self.provider_info.id.clone();
1260 let model_id = self.model_info.id.clone();
1261
1262 let wit_request = convert_request_to_wit(request);
1263
1264 cx.background_spawn(async move {
1265 extension
1266 .call({
1267 let provider_id = provider_id.clone();
1268 let model_id = model_id.clone();
1269 let wit_request = wit_request.clone();
1270 |ext, store| {
1271 async move {
1272 let count = ext
1273 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
1274 .await?
1275 .map_err(|e| anyhow!("{}", e))?;
1276 Ok(count)
1277 }
1278 .boxed()
1279 }
1280 })
1281 .await?
1282 })
1283 .boxed()
1284 }
1285
1286 fn stream_completion(
1287 &self,
1288 request: LanguageModelRequest,
1289 _cx: &AsyncApp,
1290 ) -> BoxFuture<
1291 'static,
1292 Result<
1293 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
1294 LanguageModelCompletionError,
1295 >,
1296 > {
1297 let extension = self.extension.clone();
1298 let provider_id = self.provider_info.id.clone();
1299 let model_id = self.model_info.id.clone();
1300
1301 let wit_request = convert_request_to_wit(request);
1302
1303 async move {
1304 // Start the stream
1305 let stream_id_result = extension
1306 .call({
1307 let provider_id = provider_id.clone();
1308 let model_id = model_id.clone();
1309 let wit_request = wit_request.clone();
1310 |ext, store| {
1311 async move {
1312 let id = ext
1313 .call_llm_stream_completion_start(
1314 store,
1315 &provider_id,
1316 &model_id,
1317 &wit_request,
1318 )
1319 .await?
1320 .map_err(|e| anyhow!("{}", e))?;
1321 Ok(id)
1322 }
1323 .boxed()
1324 }
1325 })
1326 .await;
1327
1328 let stream_id = stream_id_result
1329 .map_err(LanguageModelCompletionError::Other)?
1330 .map_err(LanguageModelCompletionError::Other)?;
1331
1332 // Create a stream that polls for events
1333 let stream = futures::stream::unfold(
1334 (extension.clone(), stream_id, false),
1335 move |(extension, stream_id, done)| async move {
1336 if done {
1337 return None;
1338 }
1339
1340 let result = extension
1341 .call({
1342 let stream_id = stream_id.clone();
1343 |ext, store| {
1344 async move {
1345 let event = ext
1346 .call_llm_stream_completion_next(store, &stream_id)
1347 .await?
1348 .map_err(|e| anyhow!("{}", e))?;
1349 Ok(event)
1350 }
1351 .boxed()
1352 }
1353 })
1354 .await
1355 .and_then(|inner| inner);
1356
1357 match result {
1358 Ok(Some(event)) => {
1359 let converted = convert_completion_event(event);
1360 let is_done =
1361 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
1362 Some((converted, (extension, stream_id, is_done)))
1363 }
1364 Ok(None) => {
1365 // Stream complete, close it
1366 let _ = extension
1367 .call({
1368 let stream_id = stream_id.clone();
1369 |ext, store| {
1370 async move {
1371 ext.call_llm_stream_completion_close(store, &stream_id)
1372 .await?;
1373 Ok::<(), anyhow::Error>(())
1374 }
1375 .boxed()
1376 }
1377 })
1378 .await;
1379 None
1380 }
1381 Err(e) => Some((
1382 Err(LanguageModelCompletionError::Other(e)),
1383 (extension, stream_id, true),
1384 )),
1385 }
1386 },
1387 );
1388
1389 Ok(stream.boxed())
1390 }
1391 .boxed()
1392 }
1393
1394 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
1395 // Extensions can implement this via llm_cache_configuration
1396 None
1397 }
1398}
1399
1400fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
1401 use language_model::{MessageContent, Role};
1402
1403 let messages: Vec<LlmRequestMessage> = request
1404 .messages
1405 .into_iter()
1406 .map(|msg| {
1407 let role = match msg.role {
1408 Role::User => LlmMessageRole::User,
1409 Role::Assistant => LlmMessageRole::Assistant,
1410 Role::System => LlmMessageRole::System,
1411 };
1412
1413 let content: Vec<LlmMessageContent> = msg
1414 .content
1415 .into_iter()
1416 .map(|c| match c {
1417 MessageContent::Text(text) => LlmMessageContent::Text(text),
1418 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
1419 source: image.source.to_string(),
1420 width: Some(image.size.width.0 as u32),
1421 height: Some(image.size.height.0 as u32),
1422 }),
1423 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
1424 id: tool_use.id.to_string(),
1425 name: tool_use.name.to_string(),
1426 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1427 is_input_complete: tool_use.is_input_complete,
1428 thought_signature: tool_use.thought_signature,
1429 }),
1430 MessageContent::ToolResult(tool_result) => {
1431 let content = match tool_result.content {
1432 language_model::LanguageModelToolResultContent::Text(text) => {
1433 LlmToolResultContent::Text(text.to_string())
1434 }
1435 language_model::LanguageModelToolResultContent::Image(image) => {
1436 LlmToolResultContent::Image(LlmImageData {
1437 source: image.source.to_string(),
1438 width: Some(image.size.width.0 as u32),
1439 height: Some(image.size.height.0 as u32),
1440 })
1441 }
1442 };
1443 LlmMessageContent::ToolResult(LlmToolResult {
1444 tool_use_id: tool_result.tool_use_id.to_string(),
1445 tool_name: tool_result.tool_name.to_string(),
1446 is_error: tool_result.is_error,
1447 content,
1448 })
1449 }
1450 MessageContent::Thinking { text, signature } => {
1451 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
1452 }
1453 MessageContent::RedactedThinking(data) => {
1454 LlmMessageContent::RedactedThinking(data)
1455 }
1456 })
1457 .collect();
1458
1459 LlmRequestMessage {
1460 role,
1461 content,
1462 cache: msg.cache,
1463 }
1464 })
1465 .collect();
1466
1467 let tools: Vec<LlmToolDefinition> = request
1468 .tools
1469 .into_iter()
1470 .map(|tool| LlmToolDefinition {
1471 name: tool.name,
1472 description: tool.description,
1473 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
1474 })
1475 .collect();
1476
1477 let tool_choice = request.tool_choice.map(|tc| match tc {
1478 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
1479 LanguageModelToolChoice::Any => LlmToolChoice::Any,
1480 LanguageModelToolChoice::None => LlmToolChoice::None,
1481 });
1482
1483 LlmCompletionRequest {
1484 messages,
1485 tools,
1486 tool_choice,
1487 stop_sequences: request.stop,
1488 temperature: request.temperature,
1489 thinking_allowed: false,
1490 max_tokens: None,
1491 }
1492}
1493
1494fn convert_completion_event(
1495 event: LlmCompletionEvent,
1496) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
1497 match event {
1498 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
1499 message_id: String::new(),
1500 }),
1501 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
1502 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
1503 text: thinking.text,
1504 signature: thinking.signature,
1505 }),
1506 LlmCompletionEvent::RedactedThinking(data) => {
1507 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
1508 }
1509 LlmCompletionEvent::ToolUse(tool_use) => {
1510 let raw_input = tool_use.input.clone();
1511 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
1512 Ok(LanguageModelCompletionEvent::ToolUse(
1513 LanguageModelToolUse {
1514 id: LanguageModelToolUseId::from(tool_use.id),
1515 name: tool_use.name.into(),
1516 raw_input,
1517 input,
1518 is_input_complete: tool_use.is_input_complete,
1519 thought_signature: tool_use.thought_signature,
1520 },
1521 ))
1522 }
1523 LlmCompletionEvent::ToolUseJsonParseError(error) => {
1524 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1525 id: LanguageModelToolUseId::from(error.id),
1526 tool_name: error.tool_name.into(),
1527 raw_input: error.raw_input.into(),
1528 json_parse_error: error.error,
1529 })
1530 }
1531 LlmCompletionEvent::Stop(reason) => {
1532 let stop_reason = match reason {
1533 LlmStopReason::EndTurn => StopReason::EndTurn,
1534 LlmStopReason::MaxTokens => StopReason::MaxTokens,
1535 LlmStopReason::ToolUse => StopReason::ToolUse,
1536 LlmStopReason::Refusal => StopReason::Refusal,
1537 };
1538 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
1539 }
1540 LlmCompletionEvent::Usage(usage) => {
1541 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1542 input_tokens: usage.input_tokens,
1543 output_tokens: usage.output_tokens,
1544 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1545 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1546 }))
1547 }
1548 LlmCompletionEvent::ReasoningDetails(json) => {
1549 Ok(LanguageModelCompletionEvent::ReasoningDetails(
1550 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
1551 ))
1552 }
1553 }
1554}