1use crate::wasm_host::WasmExtension;
2
3use crate::wasm_host::wit::{
4 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
5 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
6 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
7 LlmToolUse,
8};
9use anyhow::{Result, anyhow};
10use credentials_provider::CredentialsProvider;
11use editor::Editor;
12use futures::future::BoxFuture;
13use futures::stream::BoxStream;
14use futures::{FutureExt, StreamExt};
15use gpui::Focusable;
16use gpui::{
17 AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
18 TextStyleRefinement, UnderlineStyle, Window, px,
19};
20use language_model::tool_schema::LanguageModelToolSchemaFormat;
21use language_model::{
22 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
23 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
24 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
25 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
26 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
27};
28use markdown::{Markdown, MarkdownElement, MarkdownStyle};
29use settings::Settings;
30use std::sync::Arc;
31use theme::ThemeSettings;
32use ui::{Label, LabelSize, prelude::*};
33use util::ResultExt as _;
34
/// An extension-based language model provider.
///
/// Adapts an LLM provider declared by a WASM extension to the
/// `LanguageModelProvider` trait.
pub struct ExtensionLanguageModelProvider {
    /// The extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata (id, name) reported by the extension.
    pub provider_info: LlmProviderInfo,
    /// Observable state (authentication flag and advertised models), shared
    /// with the provider's configuration view.
    state: Entity<ExtensionLlmProviderState>,
}
41
/// Mutable state shared between an extension provider and its configuration view.
pub struct ExtensionLlmProviderState {
    /// Whether credentials are currently considered configured for the provider.
    is_authenticated: bool,
    /// Models advertised by the extension for this provider.
    available_models: Vec<LlmModelInfo>,
}
46
// Allows `cx.subscribe` on the state entity. Note that nothing in this file emits `()`
// events: state changes are signalled with `cx.notify()`, which only `cx.observe`
// listeners receive.
impl EventEmitter<()> for ExtensionLlmProviderState {}
48
49impl ExtensionLanguageModelProvider {
50 pub fn new(
51 extension: WasmExtension,
52 provider_info: LlmProviderInfo,
53 models: Vec<LlmModelInfo>,
54 is_authenticated: bool,
55 cx: &mut App,
56 ) -> Self {
57 let state = cx.new(|_| ExtensionLlmProviderState {
58 is_authenticated,
59 available_models: models,
60 });
61
62 Self {
63 extension,
64 provider_info,
65 state,
66 }
67 }
68
69 fn provider_id_string(&self) -> String {
70 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
71 }
72
73 /// The credential key used for storing the API key in the system keychain.
74 fn credential_key(&self) -> String {
75 format!("extension-llm-{}", self.provider_id_string())
76 }
77}
78
79impl LanguageModelProvider for ExtensionLanguageModelProvider {
80 fn id(&self) -> LanguageModelProviderId {
81 LanguageModelProviderId::from(self.provider_id_string())
82 }
83
84 fn name(&self) -> LanguageModelProviderName {
85 LanguageModelProviderName::from(self.provider_info.name.clone())
86 }
87
88 fn icon(&self) -> ui::IconName {
89 ui::IconName::ZedAssistant
90 }
91
92 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
93 let state = self.state.read(cx);
94 state
95 .available_models
96 .iter()
97 .find(|m| m.is_default)
98 .or_else(|| state.available_models.first())
99 .map(|model_info| {
100 Arc::new(ExtensionLanguageModel {
101 extension: self.extension.clone(),
102 model_info: model_info.clone(),
103 provider_id: self.id(),
104 provider_name: self.name(),
105 provider_info: self.provider_info.clone(),
106 }) as Arc<dyn LanguageModel>
107 })
108 }
109
110 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
111 let state = self.state.read(cx);
112 state
113 .available_models
114 .iter()
115 .find(|m| m.is_default_fast)
116 .map(|model_info| {
117 Arc::new(ExtensionLanguageModel {
118 extension: self.extension.clone(),
119 model_info: model_info.clone(),
120 provider_id: self.id(),
121 provider_name: self.name(),
122 provider_info: self.provider_info.clone(),
123 }) as Arc<dyn LanguageModel>
124 })
125 }
126
127 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
128 let state = self.state.read(cx);
129 state
130 .available_models
131 .iter()
132 .map(|model_info| {
133 Arc::new(ExtensionLanguageModel {
134 extension: self.extension.clone(),
135 model_info: model_info.clone(),
136 provider_id: self.id(),
137 provider_name: self.name(),
138 provider_info: self.provider_info.clone(),
139 }) as Arc<dyn LanguageModel>
140 })
141 .collect()
142 }
143
144 fn is_authenticated(&self, cx: &App) -> bool {
145 self.state.read(cx).is_authenticated
146 }
147
148 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
149 let extension = self.extension.clone();
150 let provider_id = self.provider_info.id.clone();
151 let state = self.state.clone();
152
153 cx.spawn(async move |cx| {
154 let result = extension
155 .call(|extension, store| {
156 async move {
157 extension
158 .call_llm_provider_authenticate(store, &provider_id)
159 .await
160 }
161 .boxed()
162 })
163 .await;
164
165 match result {
166 Ok(Ok(Ok(()))) => {
167 cx.update(|cx| {
168 state.update(cx, |state, _| {
169 state.is_authenticated = true;
170 });
171 })?;
172 Ok(())
173 }
174 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
175 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
176 Err(e) => Err(AuthenticateError::Other(e)),
177 }
178 })
179 }
180
181 fn configuration_view(
182 &self,
183 _target_agent: ConfigurationViewTargetAgent,
184 window: &mut Window,
185 cx: &mut App,
186 ) -> AnyView {
187 let credential_key = self.credential_key();
188 let extension = self.extension.clone();
189 let extension_provider_id = self.provider_info.id.clone();
190 let state = self.state.clone();
191
192 cx.new(|cx| {
193 ExtensionProviderConfigurationView::new(
194 credential_key,
195 extension,
196 extension_provider_id,
197 state,
198 window,
199 cx,
200 )
201 })
202 .into()
203 }
204
205 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
206 let extension = self.extension.clone();
207 let provider_id = self.provider_info.id.clone();
208 let state = self.state.clone();
209 let credential_key = self.credential_key();
210
211 let credentials_provider = <dyn CredentialsProvider>::global(cx);
212
213 cx.spawn(async move |cx| {
214 // Delete from system keychain
215 credentials_provider
216 .delete_credentials(&credential_key, cx)
217 .await
218 .log_err();
219
220 // Call extension's reset_credentials
221 let result = extension
222 .call(|extension, store| {
223 async move {
224 extension
225 .call_llm_provider_reset_credentials(store, &provider_id)
226 .await
227 }
228 .boxed()
229 })
230 .await;
231
232 // Update state
233 cx.update(|cx| {
234 state.update(cx, |state, _| {
235 state.is_authenticated = false;
236 });
237 })?;
238
239 match result {
240 Ok(Ok(Ok(()))) => Ok(()),
241 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
242 Ok(Err(e)) => Err(e),
243 Err(e) => Err(e),
244 }
245 })
246 }
247}
248
249impl LanguageModelProviderState for ExtensionLanguageModelProvider {
250 type ObservableEntity = ExtensionLlmProviderState;
251
252 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
253 Some(self.state.clone())
254 }
255
256 fn subscribe<T: 'static>(
257 &self,
258 cx: &mut Context<T>,
259 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
260 ) -> Option<Subscription> {
261 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
262 }
263}
264
/// Configuration view for extension-based LLM providers.
///
/// Shows any settings markdown supplied by the extension and lets the user store
/// or reset an API key in the system keychain.
struct ExtensionProviderConfigurationView {
    /// Keychain key under which the API key is stored.
    credential_key: String,
    /// Extension queried for the settings markdown.
    extension: WasmExtension,
    /// The provider's id within the extension (not the `extension:provider` composite).
    extension_provider_id: String,
    /// Shared provider state; its `is_authenticated` flag selects which UI is shown.
    state: Entity<ExtensionLlmProviderState>,
    /// Rendered settings text, if the extension returned any.
    settings_markdown: Option<Entity<Markdown>>,
    /// Single-line editor used to enter an API key.
    api_key_editor: Entity<Editor>,
    /// True until the settings-markdown request has completed.
    loading_settings: bool,
    /// True until the stored credentials have been read.
    loading_credentials: bool,
    /// Held so the subscriptions live as long as the view.
    _subscriptions: Vec<Subscription>,
}
277
278impl ExtensionProviderConfigurationView {
279 fn new(
280 credential_key: String,
281 extension: WasmExtension,
282 extension_provider_id: String,
283 state: Entity<ExtensionLlmProviderState>,
284 window: &mut Window,
285 cx: &mut Context<Self>,
286 ) -> Self {
287 // Subscribe to state changes
288 let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
289 cx.notify();
290 });
291
292 // Create API key editor
293 let api_key_editor = cx.new(|cx| {
294 let mut editor = Editor::single_line(window, cx);
295 editor.set_placeholder_text("Enter API key...", window, cx);
296 editor
297 });
298
299 let mut this = Self {
300 credential_key,
301 extension,
302 extension_provider_id,
303 state,
304 settings_markdown: None,
305 api_key_editor,
306 loading_settings: true,
307 loading_credentials: true,
308 _subscriptions: vec![state_subscription],
309 };
310
311 // Load settings text from extension
312 this.load_settings_text(cx);
313
314 // Load existing credentials
315 this.load_credentials(cx);
316
317 this
318 }
319
320 fn load_settings_text(&mut self, cx: &mut Context<Self>) {
321 let extension = self.extension.clone();
322 let provider_id = self.extension_provider_id.clone();
323
324 cx.spawn(async move |this, cx| {
325 let result = extension
326 .call({
327 let provider_id = provider_id.clone();
328 |ext, store| {
329 async move {
330 ext.call_llm_provider_settings_markdown(store, &provider_id)
331 .await
332 }
333 .boxed()
334 }
335 })
336 .await;
337
338 let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
339
340 this.update(cx, |this, cx| {
341 this.loading_settings = false;
342 if let Some(text) = settings_text {
343 let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
344 this.settings_markdown = Some(markdown);
345 }
346 cx.notify();
347 })
348 .log_err();
349 })
350 .detach();
351 }
352
353 fn load_credentials(&mut self, cx: &mut Context<Self>) {
354 let credential_key = self.credential_key.clone();
355 let credentials_provider = <dyn CredentialsProvider>::global(cx);
356 let state = self.state.clone();
357
358 cx.spawn(async move |this, cx| {
359 let credentials = credentials_provider
360 .read_credentials(&credential_key, cx)
361 .await
362 .log_err()
363 .flatten();
364
365 let has_credentials = credentials.is_some();
366
367 // Update authentication state based on stored credentials
368 let _ = cx.update(|cx| {
369 state.update(cx, |state, cx| {
370 state.is_authenticated = has_credentials;
371 cx.notify();
372 });
373 });
374
375 this.update(cx, |this, cx| {
376 this.loading_credentials = false;
377 cx.notify();
378 })
379 .log_err();
380 })
381 .detach();
382 }
383
384 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
385 let api_key = self.api_key_editor.read(cx).text(cx);
386 if api_key.is_empty() {
387 return;
388 }
389
390 // Clear the editor
391 self.api_key_editor
392 .update(cx, |editor, cx| editor.set_text("", window, cx));
393
394 let credential_key = self.credential_key.clone();
395 let credentials_provider = <dyn CredentialsProvider>::global(cx);
396 let state = self.state.clone();
397
398 cx.spawn(async move |_this, cx| {
399 // Store in system keychain
400 credentials_provider
401 .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
402 .await
403 .log_err();
404
405 // Update state to authenticated
406 let _ = cx.update(|cx| {
407 state.update(cx, |state, cx| {
408 state.is_authenticated = true;
409 cx.notify();
410 });
411 });
412 })
413 .detach();
414 }
415
416 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
417 // Clear the editor
418 self.api_key_editor
419 .update(cx, |editor, cx| editor.set_text("", window, cx));
420
421 let credential_key = self.credential_key.clone();
422 let credentials_provider = <dyn CredentialsProvider>::global(cx);
423 let state = self.state.clone();
424
425 cx.spawn(async move |_this, cx| {
426 // Delete from system keychain
427 credentials_provider
428 .delete_credentials(&credential_key, cx)
429 .await
430 .log_err();
431
432 // Update state to unauthenticated
433 let _ = cx.update(|cx| {
434 state.update(cx, |state, cx| {
435 state.is_authenticated = false;
436 cx.notify();
437 });
438 });
439 })
440 .detach();
441 }
442
443 fn is_authenticated(&self, cx: &Context<Self>) -> bool {
444 self.state.read(cx).is_authenticated
445 }
446}
447
448impl gpui::Render for ExtensionProviderConfigurationView {
449 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
450 let is_loading = self.loading_settings || self.loading_credentials;
451 let is_authenticated = self.is_authenticated(cx);
452
453 if is_loading {
454 return v_flex()
455 .gap_2()
456 .child(Label::new("Loading...").color(Color::Muted))
457 .into_any_element();
458 }
459
460 let mut content = v_flex().gap_4().size_full();
461
462 // Render settings markdown if available
463 if let Some(markdown) = &self.settings_markdown {
464 let style = settings_markdown_style(_window, cx);
465 content = content.child(
466 div()
467 .p_2()
468 .rounded_md()
469 .bg(cx.theme().colors().surface_background)
470 .child(MarkdownElement::new(markdown.clone(), style)),
471 );
472 }
473
474 // Render API key section
475 if is_authenticated {
476 content = content.child(
477 v_flex()
478 .gap_2()
479 .child(
480 h_flex()
481 .gap_2()
482 .child(
483 ui::Icon::new(ui::IconName::Check)
484 .color(Color::Success)
485 .size(ui::IconSize::Small),
486 )
487 .child(Label::new("API key configured").color(Color::Success)),
488 )
489 .child(
490 ui::Button::new("reset-api-key", "Reset API Key")
491 .style(ui::ButtonStyle::Subtle)
492 .on_click(cx.listener(|this, _, window, cx| {
493 this.reset_api_key(window, cx);
494 })),
495 ),
496 );
497 } else {
498 content = content.child(
499 v_flex()
500 .gap_2()
501 .on_action(cx.listener(Self::save_api_key))
502 .child(
503 Label::new("API Key")
504 .size(LabelSize::Small)
505 .color(Color::Muted),
506 )
507 .child(self.api_key_editor.clone())
508 .child(
509 Label::new("Enter your API key and press Enter to save")
510 .size(LabelSize::Small)
511 .color(Color::Muted),
512 ),
513 );
514 }
515
516 content.into_any_element()
517 }
518}
519
// Focusing the configuration view delegates focus to its API key editor.
impl Focusable for ExtensionProviderConfigurationView {
    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
        self.api_key_editor.focus_handle(cx)
    }
}
525
526fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
527 let theme_settings = ThemeSettings::get_global(cx);
528 let colors = cx.theme().colors();
529 let mut text_style = window.text_style();
530 text_style.refine(&TextStyleRefinement {
531 font_family: Some(theme_settings.ui_font.family.clone()),
532 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
533 font_features: Some(theme_settings.ui_font.features.clone()),
534 color: Some(colors.text),
535 ..Default::default()
536 });
537
538 MarkdownStyle {
539 base_text_style: text_style,
540 selection_background_color: colors.element_selection_background,
541 inline_code: TextStyleRefinement {
542 background_color: Some(colors.editor_background),
543 ..Default::default()
544 },
545 link: TextStyleRefinement {
546 color: Some(colors.text_accent),
547 underline: Some(UnderlineStyle {
548 color: Some(colors.text_accent.opacity(0.5)),
549 thickness: px(1.),
550 ..Default::default()
551 }),
552 ..Default::default()
553 },
554 syntax: cx.theme().syntax().clone(),
555 ..Default::default()
556 }
557}
558
/// An extension-based language model.
///
/// Forwards token counting and streaming completions to the owning extension.
pub struct ExtensionLanguageModel {
    /// Extension that serves requests for this model.
    extension: WasmExtension,
    /// Model metadata and capabilities reported by the extension.
    model_info: LlmModelInfo,
    /// Composite `extension:provider` id of the owning provider.
    provider_id: LanguageModelProviderId,
    /// Display name of the owning provider.
    provider_name: LanguageModelProviderName,
    /// Provider metadata; its `id` is what the extension calls receive.
    provider_info: LlmProviderInfo,
}
567
568impl LanguageModel for ExtensionLanguageModel {
569 fn id(&self) -> LanguageModelId {
570 LanguageModelId::from(self.model_info.id.clone())
571 }
572
573 fn name(&self) -> LanguageModelName {
574 LanguageModelName::from(self.model_info.name.clone())
575 }
576
577 fn provider_id(&self) -> LanguageModelProviderId {
578 self.provider_id.clone()
579 }
580
581 fn provider_name(&self) -> LanguageModelProviderName {
582 self.provider_name.clone()
583 }
584
585 fn telemetry_id(&self) -> String {
586 format!("extension-{}", self.model_info.id)
587 }
588
589 fn supports_images(&self) -> bool {
590 self.model_info.capabilities.supports_images
591 }
592
593 fn supports_tools(&self) -> bool {
594 self.model_info.capabilities.supports_tools
595 }
596
597 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
598 match choice {
599 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
600 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
601 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
602 }
603 }
604
605 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
606 match self.model_info.capabilities.tool_input_format {
607 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
608 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
609 }
610 }
611
612 fn max_token_count(&self) -> u64 {
613 self.model_info.max_token_count
614 }
615
616 fn max_output_tokens(&self) -> Option<u64> {
617 self.model_info.max_output_tokens
618 }
619
620 fn count_tokens(
621 &self,
622 request: LanguageModelRequest,
623 cx: &App,
624 ) -> BoxFuture<'static, Result<u64>> {
625 let extension = self.extension.clone();
626 let provider_id = self.provider_info.id.clone();
627 let model_id = self.model_info.id.clone();
628
629 let wit_request = convert_request_to_wit(request);
630
631 cx.background_spawn(async move {
632 extension
633 .call({
634 let provider_id = provider_id.clone();
635 let model_id = model_id.clone();
636 let wit_request = wit_request.clone();
637 |ext, store| {
638 async move {
639 let count = ext
640 .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
641 .await?
642 .map_err(|e| anyhow!("{}", e))?;
643 Ok(count)
644 }
645 .boxed()
646 }
647 })
648 .await?
649 })
650 .boxed()
651 }
652
653 fn stream_completion(
654 &self,
655 request: LanguageModelRequest,
656 _cx: &AsyncApp,
657 ) -> BoxFuture<
658 'static,
659 Result<
660 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
661 LanguageModelCompletionError,
662 >,
663 > {
664 let extension = self.extension.clone();
665 let provider_id = self.provider_info.id.clone();
666 let model_id = self.model_info.id.clone();
667
668 let wit_request = convert_request_to_wit(request);
669
670 async move {
671 // Start the stream
672 let stream_id_result = extension
673 .call({
674 let provider_id = provider_id.clone();
675 let model_id = model_id.clone();
676 let wit_request = wit_request.clone();
677 |ext, store| {
678 async move {
679 let id = ext
680 .call_llm_stream_completion_start(
681 store,
682 &provider_id,
683 &model_id,
684 &wit_request,
685 )
686 .await?
687 .map_err(|e| anyhow!("{}", e))?;
688 Ok(id)
689 }
690 .boxed()
691 }
692 })
693 .await;
694
695 let stream_id = stream_id_result
696 .map_err(LanguageModelCompletionError::Other)?
697 .map_err(LanguageModelCompletionError::Other)?;
698
699 // Create a stream that polls for events
700 let stream = futures::stream::unfold(
701 (extension.clone(), stream_id, false),
702 move |(extension, stream_id, done)| async move {
703 if done {
704 return None;
705 }
706
707 let result = extension
708 .call({
709 let stream_id = stream_id.clone();
710 |ext, store| {
711 async move {
712 let event = ext
713 .call_llm_stream_completion_next(store, &stream_id)
714 .await?
715 .map_err(|e| anyhow!("{}", e))?;
716 Ok(event)
717 }
718 .boxed()
719 }
720 })
721 .await
722 .and_then(|inner| inner);
723
724 match result {
725 Ok(Some(event)) => {
726 let converted = convert_completion_event(event);
727 let is_done =
728 matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
729 Some((converted, (extension, stream_id, is_done)))
730 }
731 Ok(None) => {
732 // Stream complete, close it
733 let _ = extension
734 .call({
735 let stream_id = stream_id.clone();
736 |ext, store| {
737 async move {
738 ext.call_llm_stream_completion_close(store, &stream_id)
739 .await?;
740 Ok::<(), anyhow::Error>(())
741 }
742 .boxed()
743 }
744 })
745 .await;
746 None
747 }
748 Err(e) => Some((
749 Err(LanguageModelCompletionError::Other(e)),
750 (extension, stream_id, true),
751 )),
752 }
753 },
754 );
755
756 Ok(stream.boxed())
757 }
758 .boxed()
759 }
760
761 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
762 // Extensions can implement this via llm_cache_configuration
763 None
764 }
765}
766
767fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
768 use language_model::{MessageContent, Role};
769
770 let messages: Vec<LlmRequestMessage> = request
771 .messages
772 .into_iter()
773 .map(|msg| {
774 let role = match msg.role {
775 Role::User => LlmMessageRole::User,
776 Role::Assistant => LlmMessageRole::Assistant,
777 Role::System => LlmMessageRole::System,
778 };
779
780 let content: Vec<LlmMessageContent> = msg
781 .content
782 .into_iter()
783 .map(|c| match c {
784 MessageContent::Text(text) => LlmMessageContent::Text(text),
785 MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
786 source: image.source.to_string(),
787 width: Some(image.size.width.0 as u32),
788 height: Some(image.size.height.0 as u32),
789 }),
790 MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
791 id: tool_use.id.to_string(),
792 name: tool_use.name.to_string(),
793 input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
794 thought_signature: tool_use.thought_signature,
795 }),
796 MessageContent::ToolResult(tool_result) => {
797 let content = match tool_result.content {
798 language_model::LanguageModelToolResultContent::Text(text) => {
799 LlmToolResultContent::Text(text.to_string())
800 }
801 language_model::LanguageModelToolResultContent::Image(image) => {
802 LlmToolResultContent::Image(LlmImageData {
803 source: image.source.to_string(),
804 width: Some(image.size.width.0 as u32),
805 height: Some(image.size.height.0 as u32),
806 })
807 }
808 };
809 LlmMessageContent::ToolResult(LlmToolResult {
810 tool_use_id: tool_result.tool_use_id.to_string(),
811 tool_name: tool_result.tool_name.to_string(),
812 is_error: tool_result.is_error,
813 content,
814 })
815 }
816 MessageContent::Thinking { text, signature } => {
817 LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
818 }
819 MessageContent::RedactedThinking(data) => {
820 LlmMessageContent::RedactedThinking(data)
821 }
822 })
823 .collect();
824
825 LlmRequestMessage {
826 role,
827 content,
828 cache: msg.cache,
829 }
830 })
831 .collect();
832
833 let tools: Vec<LlmToolDefinition> = request
834 .tools
835 .into_iter()
836 .map(|tool| LlmToolDefinition {
837 name: tool.name,
838 description: tool.description,
839 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
840 })
841 .collect();
842
843 let tool_choice = request.tool_choice.map(|tc| match tc {
844 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
845 LanguageModelToolChoice::Any => LlmToolChoice::Any,
846 LanguageModelToolChoice::None => LlmToolChoice::None,
847 });
848
849 LlmCompletionRequest {
850 messages,
851 tools,
852 tool_choice,
853 stop_sequences: request.stop,
854 temperature: request.temperature,
855 thinking_allowed: false,
856 max_tokens: None,
857 }
858}
859
860fn convert_completion_event(
861 event: LlmCompletionEvent,
862) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
863 match event {
864 LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
865 message_id: String::new(),
866 }),
867 LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
868 LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
869 text: thinking.text,
870 signature: thinking.signature,
871 }),
872 LlmCompletionEvent::RedactedThinking(data) => {
873 Ok(LanguageModelCompletionEvent::RedactedThinking { data })
874 }
875 LlmCompletionEvent::ToolUse(tool_use) => {
876 let raw_input = tool_use.input.clone();
877 let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
878 Ok(LanguageModelCompletionEvent::ToolUse(
879 LanguageModelToolUse {
880 id: LanguageModelToolUseId::from(tool_use.id),
881 name: tool_use.name.into(),
882 raw_input,
883 input,
884 is_input_complete: true,
885 thought_signature: tool_use.thought_signature,
886 },
887 ))
888 }
889 LlmCompletionEvent::ToolUseJsonParseError(error) => {
890 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
891 id: LanguageModelToolUseId::from(error.id),
892 tool_name: error.tool_name.into(),
893 raw_input: error.raw_input.into(),
894 json_parse_error: error.error,
895 })
896 }
897 LlmCompletionEvent::Stop(reason) => {
898 let stop_reason = match reason {
899 LlmStopReason::EndTurn => StopReason::EndTurn,
900 LlmStopReason::MaxTokens => StopReason::MaxTokens,
901 LlmStopReason::ToolUse => StopReason::ToolUse,
902 LlmStopReason::Refusal => StopReason::Refusal,
903 };
904 Ok(LanguageModelCompletionEvent::Stop(stop_reason))
905 }
906 LlmCompletionEvent::Usage(usage) => {
907 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
908 input_tokens: usage.input_tokens,
909 output_tokens: usage.output_tokens,
910 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
911 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
912 }))
913 }
914 LlmCompletionEvent::ReasoningDetails(json) => {
915 Ok(LanguageModelCompletionEvent::ReasoningDetails(
916 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
917 ))
918 }
919 }
920}