1use crate::wasm_host::WasmExtension;
2
3use crate::wasm_host::wit::{
4 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
5 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
6 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
7 LlmToolUse,
8};
9use anyhow::{Result, anyhow};
10use credentials_provider::CredentialsProvider;
11use editor::Editor;
12use futures::future::BoxFuture;
13use futures::stream::BoxStream;
14use futures::{FutureExt, StreamExt};
15use gpui::Focusable;
16use gpui::{
17 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
18 TextStyleRefinement, UnderlineStyle, Window, px,
19};
20use language_model::tool_schema::LanguageModelToolSchemaFormat;
21use language_model::{
22 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
23 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
24 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
25 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
26 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
27};
28use markdown::{Markdown, MarkdownElement, MarkdownStyle};
29use settings::Settings;
30use std::sync::Arc;
31use theme::ThemeSettings;
32use ui::{Label, LabelSize, prelude::*};
33use util::ResultExt as _;
34
/// An extension-based language model provider.
///
/// Pairs a WASM extension with the provider metadata it declared, and owns the
/// shared [`ExtensionLlmProviderState`] entity (authentication flag and model
/// list) that is also handed to this provider's configuration view.
pub struct ExtensionLanguageModelProvider {
    /// The extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata reported by the extension (its `id` and `name` are used here).
    pub provider_info: LlmProviderInfo,
    /// Optional custom icon, returned verbatim from `icon_path()`.
    icon_path: Option<SharedString>,
    /// Shared, observable provider state.
    state: Entity<ExtensionLlmProviderState>,
}
42
/// Observable state shared by an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether the provider is currently considered authenticated.
    is_authenticated: bool,
    /// Models reported by the extension; defaults are chosen from this list via
    /// the `is_default` / `is_default_fast` flags.
    available_models: Vec<LlmModelInfo>,
}

// Allows `cx.subscribe` on the state entity.
// NOTE(review): nothing in this file calls `cx.emit(())` on this entity, so event
// subscriptions never fire; code that needs change notifications should `cx.observe` it.
impl EventEmitter<()> for ExtensionLlmProviderState {}
49
50impl ExtensionLanguageModelProvider {
51 pub fn new(
52 extension: WasmExtension,
53 provider_info: LlmProviderInfo,
54 models: Vec<LlmModelInfo>,
55 is_authenticated: bool,
56 icon_path: Option<SharedString>,
57 cx: &mut App,
58 ) -> Self {
59 let state = cx.new(|_| ExtensionLlmProviderState {
60 is_authenticated,
61 available_models: models,
62 });
63
64 Self {
65 extension,
66 provider_info,
67 icon_path,
68 state,
69 }
70 }
71
72 fn provider_id_string(&self) -> String {
73 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
74 }
75
76 /// The credential key used for storing the API key in the system keychain.
77 fn credential_key(&self) -> String {
78 format!("extension-llm-{}", self.provider_id_string())
79 }
80}
81
82impl LanguageModelProvider for ExtensionLanguageModelProvider {
83 fn id(&self) -> LanguageModelProviderId {
84 LanguageModelProviderId::from(self.provider_id_string())
85 }
86
87 fn name(&self) -> LanguageModelProviderName {
88 LanguageModelProviderName::from(self.provider_info.name.clone())
89 }
90
91 fn icon(&self) -> ui::IconName {
92 ui::IconName::ZedAssistant
93 }
94
95 fn icon_path(&self) -> Option<SharedString> {
96 self.icon_path.clone()
97 }
98
99 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
100 let state = self.state.read(cx);
101 state
102 .available_models
103 .iter()
104 .find(|m| m.is_default)
105 .or_else(|| state.available_models.first())
106 .map(|model_info| {
107 Arc::new(ExtensionLanguageModel {
108 extension: self.extension.clone(),
109 model_info: model_info.clone(),
110 provider_id: self.id(),
111 provider_name: self.name(),
112 provider_info: self.provider_info.clone(),
113 }) as Arc<dyn LanguageModel>
114 })
115 }
116
117 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
118 let state = self.state.read(cx);
119 state
120 .available_models
121 .iter()
122 .find(|m| m.is_default_fast)
123 .map(|model_info| {
124 Arc::new(ExtensionLanguageModel {
125 extension: self.extension.clone(),
126 model_info: model_info.clone(),
127 provider_id: self.id(),
128 provider_name: self.name(),
129 provider_info: self.provider_info.clone(),
130 }) as Arc<dyn LanguageModel>
131 })
132 }
133
134 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
135 let state = self.state.read(cx);
136 state
137 .available_models
138 .iter()
139 .map(|model_info| {
140 Arc::new(ExtensionLanguageModel {
141 extension: self.extension.clone(),
142 model_info: model_info.clone(),
143 provider_id: self.id(),
144 provider_name: self.name(),
145 provider_info: self.provider_info.clone(),
146 }) as Arc<dyn LanguageModel>
147 })
148 .collect()
149 }
150
151 fn is_authenticated(&self, cx: &App) -> bool {
152 self.state.read(cx).is_authenticated
153 }
154
155 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
156 let extension = self.extension.clone();
157 let provider_id = self.provider_info.id.clone();
158 let state = self.state.clone();
159
160 cx.spawn(async move |cx| {
161 let result = extension
162 .call(|extension, store| {
163 async move {
164 extension
165 .call_llm_provider_authenticate(store, &provider_id)
166 .await
167 }
168 .boxed()
169 })
170 .await;
171
172 match result {
173 Ok(Ok(Ok(()))) => {
174 cx.update(|cx| {
175 state.update(cx, |state, _| {
176 state.is_authenticated = true;
177 });
178 })?;
179 Ok(())
180 }
181 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
182 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
183 Err(e) => Err(AuthenticateError::Other(e)),
184 }
185 })
186 }
187
188 fn configuration_view(
189 &self,
190 _target_agent: ConfigurationViewTargetAgent,
191 window: &mut Window,
192 cx: &mut App,
193 ) -> AnyView {
194 let credential_key = self.credential_key();
195 let extension = self.extension.clone();
196 let extension_provider_id = self.provider_info.id.clone();
197 let state = self.state.clone();
198
199 cx.new(|cx| {
200 ExtensionProviderConfigurationView::new(
201 credential_key,
202 extension,
203 extension_provider_id,
204 state,
205 window,
206 cx,
207 )
208 })
209 .into()
210 }
211
212 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
213 let extension = self.extension.clone();
214 let provider_id = self.provider_info.id.clone();
215 let state = self.state.clone();
216 let credential_key = self.credential_key();
217
218 let credentials_provider = <dyn CredentialsProvider>::global(cx);
219
220 cx.spawn(async move |cx| {
221 // Delete from system keychain
222 credentials_provider
223 .delete_credentials(&credential_key, cx)
224 .await
225 .log_err();
226
227 // Call extension's reset_credentials
228 let result = extension
229 .call(|extension, store| {
230 async move {
231 extension
232 .call_llm_provider_reset_credentials(store, &provider_id)
233 .await
234 }
235 .boxed()
236 })
237 .await;
238
239 // Update state
240 cx.update(|cx| {
241 state.update(cx, |state, _| {
242 state.is_authenticated = false;
243 });
244 })?;
245
246 match result {
247 Ok(Ok(Ok(()))) => Ok(()),
248 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
249 Ok(Err(e)) => Err(e),
250 Err(e) => Err(e),
251 }
252 })
253 }
254}
255
256impl LanguageModelProviderState for ExtensionLanguageModelProvider {
257 type ObservableEntity = ExtensionLlmProviderState;
258
259 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
260 Some(self.state.clone())
261 }
262
263 fn subscribe<T: 'static>(
264 &self,
265 cx: &mut Context<T>,
266 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
267 ) -> Option<Subscription> {
268 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
269 }
270}
271
/// Configuration view for extension-based LLM providers.
///
/// Shows the extension's optional settings markdown, followed by either an API
/// key input or a "configured" indicator with a reset button, depending on the
/// shared provider state.
struct ExtensionProviderConfigurationView {
    /// Keychain key the API key is stored under.
    credential_key: String,
    /// Extension that backs this provider (used to fetch the settings markdown).
    extension: WasmExtension,
    /// Provider id as known to the extension (not the `<extension>:<provider>` Zed id).
    extension_provider_id: String,
    /// Shared provider state; its `is_authenticated` flag selects which UI is shown.
    state: Entity<ExtensionLlmProviderState>,
    /// Settings text supplied by the extension, if any.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor where the user types an API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings markdown request has finished.
    loading_settings: bool,
    /// True until the stored credentials have been checked.
    loading_credentials: bool,
    /// Keeps subscriptions alive for as long as the view exists.
    _subscriptions: Vec<Subscription>,
}
284
285impl ExtensionProviderConfigurationView {
286 fn new(
287 credential_key: String,
288 extension: WasmExtension,
289 extension_provider_id: String,
290 state: Entity<ExtensionLlmProviderState>,
291 window: &mut Window,
292 cx: &mut Context<Self>,
293 ) -> Self {
294 // Subscribe to state changes
295 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
296 cx.notify();
297 });
298
299 // Create API key editor
300 let api_key_editor = cx.new(|cx| {
301 let mut editor = Editor::single_line(window, cx);
302 editor.set_placeholder_text("Enter API key...", window, cx);
303 editor
304 });
305
306 let mut this = Self {
307 credential_key,
308 extension,
309 extension_provider_id,
310 state,
311 settings_markdown: None,
312 api_key_editor,
313 loading_settings: true,
314 loading_credentials: true,
315 _subscriptions: vec![state_subscription],
316 };
317
318 // Load settings text from extension
319 this.load_settings_text(cx);
320
321 // Load existing credentials
322 this.load_credentials(cx);
323
324 this
325 }
326
327 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
328 let extension = self.extension.clone();
329 let provider_id = self.extension_provider_id.clone();
330
331 cx.spawn(async move |this, cx| {
332 let result = extension
333 .call({
334 let provider_id = provider_id.clone();
335 |ext, store| {
336 async move {
337 ext.call_llm_provider_settings_markdown(store, &provider_id)
338 .await
339 }
340 .boxed()
341 }
342 })
343 .await;
344
345 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
346
347 this.update(cx, |this, cx| {
348 this.loading_settings = false;
349 if let Some(text) = settings_text {
350 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
351 this.settings_markdown = Some(markdown);
352 }
353 cx.notify();
354 })
355 .log_err();
356 })
357 .detach();
358 }
359
360 fn load_credentials(&mut self, cx: &mut Context<Self>) {
361 let credential_key = self.credential_key.clone();
362 let credentials_provider = <dyn CredentialsProvider>::global(cx);
363 let state = self.state.clone();
364
365 cx.spawn(async move |this, cx| {
366 let credentials = credentials_provider
367 .read_credentials(&credential_key, cx)
368 .await
369 .log_err()
370 .flatten();
371
372 let has_credentials = credentials.is_some();
373
374 // Update authentication state based on stored credentials
375 let _ = cx.update(|cx| {
376 state.update(cx, |state, cx| {
377 state.is_authenticated = has_credentials;
378 cx.notify();
379 });
380 });
381
382 this.update(cx, |this, cx| {
383 this.loading_credentials = false;
384 cx.notify();
385 })
386 .log_err();
387 })
388 .detach();
389 }
390
391 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
392 let api_key = self.api_key_editor.read(cx).text(cx);
393 if api_key.is_empty() {
394 return;
395 }
396
397 // Clear the editor
398 self.api_key_editor
399 .update(cx, |editor, cx| editor.set_text("", window, cx));
400
401 let credential_key = self.credential_key.clone();
402 let credentials_provider = <dyn CredentialsProvider>::global(cx);
403 let state = self.state.clone();
404
405 cx.spawn(async move |_this, cx| {
406 // Store in system keychain
407 credentials_provider
408 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
409 .await
410 .log_err();
411
412 // Update state to authenticated
413 let _ = cx.update(|cx| {
414 state.update(cx, |state, cx| {
415 state.is_authenticated = true;
416 cx.notify();
417 });
418 });
419 })
420 .detach();
421 }
422
423 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
424 // Clear the editor
425 self.api_key_editor
426 .update(cx, |editor, cx| editor.set_text("", window, cx));
427
428 let credential_key = self.credential_key.clone();
429 let credentials_provider = <dyn CredentialsProvider>::global(cx);
430 let state = self.state.clone();
431
432 cx.spawn(async move |_this, cx| {
433 // Delete from system keychain
434 credentials_provider
435 .delete_credentials(&credential_key, cx)
436 .await
437 .log_err();
438
439 // Update state to unauthenticated
440 let _ = cx.update(|cx| {
441 state.update(cx, |state, cx| {
442 state.is_authenticated = false;
443 cx.notify();
444 });
445 });
446 })
447 .detach();
448 }
449
450 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
451 self.state.read(cx).is_authenticated
452 }
453}
454
455impl gpui::Render for ExtensionProviderConfigurationView {
456 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
457 let is_loading = self.loading_settings || self.loading_credentials;
458 let is_authenticated = self.is_authenticated(cx);
459
460 if is_loading {
461 return v_flex()
462 .gap_2()
463 .child(Label::new("Loading...").color(Color::Muted))
464 .into_any_element();
465 }
466
467 let mut content = v_flex().gap_4().size_full();
468
469 // Render settings markdown if available
470 if let Some(markdown) = &self.settings_markdown {
471 let style = settings_markdown_style(_window, cx);
472 content = content.child(
473 div()
474 .p_2()
475 .rounded_md()
476 .bg(cx.theme().colors().surface_background)
477 .child(MarkdownElement::new(markdown.clone(), style)),
478 );
479 }
480
481 // Render API key section
482 if is_authenticated {
483 content = content.child(
484 v_flex()
485 .gap_2()
486 .child(
487 h_flex()
488 .gap_2()
489 .child(
490 ui::Icon::new(ui::IconName::Check)
491 .color(Color::Success)
492 .size(ui::IconSize::Small),
493 )
494 .child(Label::new("API key configured").color(Color::Success)),
495 )
496 .child(
497 ui::Button::new("reset-api-key", "Reset API Key")
498 .style(ui::ButtonStyle::Subtle)
499 .on_click(cx.listener(|this, _, window, cx| {
500 this.reset_api_key(window, cx);
501 })),
502 ),
503 );
504 } else {
505 content = content.child(
506 v_flex()
507 .gap_2()
508 .on_action(cx.listener(Self::save_api_key))
509 .child(
510 Label::new("API Key")
511 .size(LabelSize::Small)
512 .color(Color::Muted),
513 )
514 .child(self.api_key_editor.clone())
515 .child(
516 Label::new("Enter your API key and press Enter to save")
517 .size(LabelSize::Small)
518 .color(Color::Muted),
519 ),
520 );
521 }
522
523 content.into_any_element()
524 }
525}
526
527impl Focusable for ExtensionProviderConfigurationView {
528 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
529 self.api_key_editor.focus_handle(cx)
530 }
531}
532
533fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
534 let theme_settings = ThemeSettings::get_global(cx);
535 let colors = cx.theme().colors();
536 let mut text_style = window.text_style();
537 text_style.refine(&TextStyleRefinement {
538 font_family: Some(theme_settings.ui_font.family.clone()),
539 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
540 font_features: Some(theme_settings.ui_font.features.clone()),
541 color: Some(colors.text),
542 ..Default::default()
543 });
544
545 MarkdownStyle {
546 base_text_style: text_style,
547 selection_background_color: colors.element_selection_background,
548 inline_code: TextStyleRefinement {
549 background_color: Some(colors.editor_background),
550 ..Default::default()
551 },
552 link: TextStyleRefinement {
553 color: Some(colors.text_accent),
554 underline: Some(UnderlineStyle {
555 color: Some(colors.text_accent.opacity(0.5)),
556 thickness: px(1.),
557 ..Default::default()
558 }),
559 ..Default::default()
560 },
561 syntax: cx.theme().syntax().clone(),
562 ..Default::default()
563 }
564}
565
/// An extension-based language model.
///
/// Completion and token-count calls are forwarded to `extension`, identified by
/// `provider_info.id` and `model_info.id`.
pub struct ExtensionLanguageModel {
    /// The extension that serves this model.
    extension: WasmExtension,
    /// Model metadata and capabilities reported by the extension.
    model_info: LlmModelInfo,
    /// Zed-facing id of the owning provider.
    provider_id: LanguageModelProviderId,
    /// Zed-facing display name of the owning provider.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is what the extension expects in calls.
    provider_info: LlmProviderInfo,
}
574
575impl LanguageModel for ExtensionLanguageModel {
576 fn id(&self) -> LanguageModelId {
577 LanguageModelId::from(self.model_info.id.clone())
578 }
579
580 fn name(&self) -> LanguageModelName {
581 LanguageModelName::from(self.model_info.name.clone())
582 }
583
584 fn provider_id(&self) -> LanguageModelProviderId {
585 self.provider_id.clone()
586 }
587
588 fn provider_name(&self) -> LanguageModelProviderName {
589 self.provider_name.clone()
590 }
591
592 fn telemetry_id(&self) -> String {
593 format!("extension-{}", self.model_info.id)
594 }
595
596 fn supports_images(&self) -> bool {
597 self.model_info.capabilities.supports_images
598 }
599
600 fn supports_tools(&self) -> bool {
601 self.model_info.capabilities.supports_tools
602 }
603
604 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
605 match choice {
606 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
607 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
608 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
609 }
610 }
611
612 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
613 match self.model_info.capabilities.tool_input_format {
614 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
615 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
616 }
617 }
618
619 fn max_token_count(&self) -> u64 {
620 self.model_info.max_token_count
621 }
622
623 fn max_output_tokens(&self) -> Option<u64> {
624 self.model_info.max_output_tokens
625 }
626
627 fn count_tokens(
628 &self,
629 request: LanguageModelRequest,
630 cx: &App,
631 ) -> BoxFuture<'static, Result<u64>> {
632 let extension = self.extension.clone();
633 let provider_id = self.provider_info.id.clone();
634 let model_id = self.model_info.id.clone();
635
636 let wit_request = convert_request_to_wit(request);
637
638 cx.background_spawn(async move {
639 extension
640 .call({
641 let provider_id = provider_id.clone();
642 let model_id = model_id.clone();
643 let wit_request = wit_request.clone();
644 |ext, store| {
645 async move {
646 let count = ext
647 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
648 .await?
649 .map_err(|e| anyhow!("{}", e))?;
650 Ok(count)
651 }
652 .boxed()
653 }
654 })
655 .await?
656 })
657 .boxed()
658 }
659
660 fn stream_completion(
661 &self,
662 request: LanguageModelRequest,
663 _cx: &AsyncApp,
664 ) -> BoxFuture<
665 'static,
666 Result<
667 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
668 LanguageModelCompletionError,
669 >,
670 > {
671 let extension = self.extension.clone();
672 let provider_id = self.provider_info.id.clone();
673 let model_id = self.model_info.id.clone();
674
675 let wit_request = convert_request_to_wit(request);
676
677 async move {
678 // Start the stream
679 let stream_id_result = extension
680 .call({
681 let provider_id = provider_id.clone();
682 let model_id = model_id.clone();
683 let wit_request = wit_request.clone();
684 |ext, store| {
685 async move {
686 let id = ext
687 .call_llm_stream_completion_start(
688 store,
689 &provider_id,
690 &model_id,
691 &wit_request,
692 )
693 .await?
694 .map_err(|e| anyhow!("{}", e))?;
695 Ok(id)
696 }
697 .boxed()
698 }
699 })
700 .await;
701
702 let stream_id = stream_id_result
703 .map_err(LanguageModelCompletionError::Other)?
704 .map_err(LanguageModelCompletionError::Other)?;
705
706 // Create a stream that polls for events
707 let stream = futures::stream::unfold(
708 (extension.clone(), stream_id, false),
709 move |(extension, stream_id, done)| async move {
710 if done {
711 return None;
712 }
713
714 let result = extension
715 .call({
716 let stream_id = stream_id.clone();
717 |ext, store| {
718 async move {
719 let event = ext
720 .call_llm_stream_completion_next(store, &stream_id)
721 .await?
722 .map_err(|e| anyhow!("{}", e))?;
723 Ok(event)
724 }
725 .boxed()
726 }
727 })
728 .await
729 .and_then(|inner| inner);
730
731 match result {
732 Ok(Some(event)) => {
733 let converted = convert_completion_event(event);
734 let is_done =
735 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
736 Some((converted, (extension, stream_id, is_done)))
737 }
738 Ok(None) => {
739 // Stream complete, close it
740 let _ = extension
741 .call({
742 let stream_id = stream_id.clone();
743 |ext, store| {
744 async move {
745 ext.call_llm_stream_completion_close(store, &stream_id)
746 .await?;
747 Ok::<(), anyhow::Error>(())
748 }
749 .boxed()
750 }
751 })
752 .await;
753 None
754 }
755 Err(e) => Some((
756 Err(LanguageModelCompletionError::Other(e)),
757 (extension, stream_id, true),
758 )),
759 }
760 },
761 );
762
763 Ok(stream.boxed())
764 }
765 .boxed()
766 }
767
768 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
769 // Extensions can implement this via llm_cache_configuration
770 None
771 }
772}
773
774fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
775 use language_model::{MessageContent, Role};
776
777 let messages: Vec<LlmRequestMessage> = request
778 .messages
779 .into_iter()
780 .map(|msg| {
781 let role = match msg.role {
782 Role::User => LlmMessageRole::User,
783 Role::Assistant => LlmMessageRole::Assistant,
784 Role::System => LlmMessageRole::System,
785 };
786
787 let content: Vec<LlmMessageContent> = msg
788 .content
789 .into_iter()
790 .map(|c| match c {
791 MessageContent::Text(text) => LlmMessageContent::Text(text),
792 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
793 source: image.source.to_string(),
794 width: Some(image.size.width.0 as u32),
795 height: Some(image.size.height.0 as u32),
796 }),
797 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
798 id: tool_use.id.to_string(),
799 name: tool_use.name.to_string(),
800 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
801 thought_signature: tool_use.thought_signature,
802 }),
803 MessageContent::ToolResult(tool_result) => {
804 let content = match tool_result.content {
805 language_model::LanguageModelToolResultContent::Text(text) => {
806 LlmToolResultContent::Text(text.to_string())
807 }
808 language_model::LanguageModelToolResultContent::Image(image) => {
809 LlmToolResultContent::Image(LlmImageData {
810 source: image.source.to_string(),
811 width: Some(image.size.width.0 as u32),
812 height: Some(image.size.height.0 as u32),
813 })
814 }
815 };
816 LlmMessageContent::ToolResult(LlmToolResult {
817 tool_use_id: tool_result.tool_use_id.to_string(),
818 tool_name: tool_result.tool_name.to_string(),
819 is_error: tool_result.is_error,
820 content,
821 })
822 }
823 MessageContent::Thinking { text, signature } => {
824 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
825 }
826 MessageContent::RedactedThinking(data) => {
827 LlmMessageContent::RedactedThinking(data)
828 }
829 })
830 .collect();
831
832 LlmRequestMessage {
833 role,
834 content,
835 cache: msg.cache,
836 }
837 })
838 .collect();
839
840 let tools: Vec<LlmToolDefinition> = request
841 .tools
842 .into_iter()
843 .map(|tool| LlmToolDefinition {
844 name: tool.name,
845 description: tool.description,
846 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
847 })
848 .collect();
849
850 let tool_choice = request.tool_choice.map(|tc| match tc {
851 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
852 LanguageModelToolChoice::Any => LlmToolChoice::Any,
853 LanguageModelToolChoice::None => LlmToolChoice::None,
854 });
855
856 LlmCompletionRequest {
857 messages,
858 tools,
859 tool_choice,
860 stop_sequences: request.stop,
861 temperature: request.temperature,
862 thinking_allowed: false,
863 max_tokens: None,
864 }
865}
866
867fn convert_completion_event(
868 event: LlmCompletionEvent,
869) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
870 match event {
871 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
872 message_id: String::new(),
873 }),
874 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
875 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
876 text: thinking.text,
877 signature: thinking.signature,
878 }),
879 LlmCompletionEvent::RedactedThinking(data) => {
880 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
881 }
882 LlmCompletionEvent::ToolUse(tool_use) => {
883 let raw_input = tool_use.input.clone();
884 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
885 Ok(LanguageModelCompletionEvent::ToolUse(
886 LanguageModelToolUse {
887 id: LanguageModelToolUseId::from(tool_use.id),
888 name: tool_use.name.into(),
889 raw_input,
890 input,
891 is_input_complete: true,
892 thought_signature: tool_use.thought_signature,
893 },
894 ))
895 }
896 LlmCompletionEvent::ToolUseJsonParseError(error) => {
897 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
898 id: LanguageModelToolUseId::from(error.id),
899 tool_name: error.tool_name.into(),
900 raw_input: error.raw_input.into(),
901 json_parse_error: error.error,
902 })
903 }
904 LlmCompletionEvent::Stop(reason) => {
905 let stop_reason = match reason {
906 LlmStopReason::EndTurn => StopReason::EndTurn,
907 LlmStopReason::MaxTokens => StopReason::MaxTokens,
908 LlmStopReason::ToolUse => StopReason::ToolUse,
909 LlmStopReason::Refusal => StopReason::Refusal,
910 };
911 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
912 }
913 LlmCompletionEvent::Usage(usage) => {
914 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
915 input_tokens: usage.input_tokens,
916 output_tokens: usage.output_tokens,
917 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
918 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
919 }))
920 }
921 LlmCompletionEvent::ReasoningDetails(json) => {
922 Ok(LanguageModelCompletionEvent::ReasoningDetails(
923 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
924 ))
925 }
926 }
927}