llm_provider.rs

  1use crate::wasm_host::WasmExtension;
  2
  3use crate::wasm_host::wit::{
  4    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
  5    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
  6    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  7    LlmToolUse,
  8};
  9use anyhow::{Result, anyhow};
 10use credentials_provider::CredentialsProvider;
 11use editor::Editor;
 12use futures::future::BoxFuture;
 13use futures::stream::BoxStream;
 14use futures::{FutureExt, StreamExt};
 15use gpui::Focusable;
 16use gpui::{
 17    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
 18    TextStyleRefinement, UnderlineStyle, Window, px,
 19};
 20use language_model::tool_schema::LanguageModelToolSchemaFormat;
 21use language_model::{
 22    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
 23    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 24    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 25    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 26    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
 27};
 28use markdown::{Markdown, MarkdownElement, MarkdownStyle};
 29use settings::Settings;
 30use std::sync::Arc;
 31use theme::ThemeSettings;
 32use ui::{Label, LabelSize, prelude::*};
 33use util::ResultExt as _;
 34
 35/// An extension-based language model provider.
 36pub struct ExtensionLanguageModelProvider {
 37    pub extension: WasmExtension,
 38    pub provider_info: LlmProviderInfo,
 39    state: Entity<ExtensionLlmProviderState>,
 40}
 41
 42pub struct ExtensionLlmProviderState {
 43    is_authenticated: bool,
 44    available_models: Vec<LlmModelInfo>,
 45}
 46
 47impl EventEmitter<()> for ExtensionLlmProviderState {}
 48
 49impl ExtensionLanguageModelProvider {
 50    pub fn new(
 51        extension: WasmExtension,
 52        provider_info: LlmProviderInfo,
 53        models: Vec<LlmModelInfo>,
 54        is_authenticated: bool,
 55        cx: &mut App,
 56    ) -> Self {
 57        let state = cx.new(|_| ExtensionLlmProviderState {
 58            is_authenticated,
 59            available_models: models,
 60        });
 61
 62        Self {
 63            extension,
 64            provider_info,
 65            state,
 66        }
 67    }
 68
 69    fn provider_id_string(&self) -> String {
 70        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 71    }
 72
 73    /// The credential key used for storing the API key in the system keychain.
 74    fn credential_key(&self) -> String {
 75        format!("extension-llm-{}", self.provider_id_string())
 76    }
 77}
 78
 79impl LanguageModelProvider for ExtensionLanguageModelProvider {
 80    fn id(&self) -> LanguageModelProviderId {
 81        LanguageModelProviderId::from(self.provider_id_string())
 82    }
 83
 84    fn name(&self) -> LanguageModelProviderName {
 85        LanguageModelProviderName::from(self.provider_info.name.clone())
 86    }
 87
 88    fn icon(&self) -> ui::IconName {
 89        ui::IconName::ZedAssistant
 90    }
 91
 92    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 93        let state = self.state.read(cx);
 94        state
 95            .available_models
 96            .iter()
 97            .find(|m| m.is_default)
 98            .or_else(|| state.available_models.first())
 99            .map(|model_info| {
100                Arc::new(ExtensionLanguageModel {
101                    extension: self.extension.clone(),
102                    model_info: model_info.clone(),
103                    provider_id: self.id(),
104                    provider_name: self.name(),
105                    provider_info: self.provider_info.clone(),
106                }) as Arc<dyn LanguageModel>
107            })
108    }
109
110    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
111        let state = self.state.read(cx);
112        state
113            .available_models
114            .iter()
115            .find(|m| m.is_default_fast)
116            .map(|model_info| {
117                Arc::new(ExtensionLanguageModel {
118                    extension: self.extension.clone(),
119                    model_info: model_info.clone(),
120                    provider_id: self.id(),
121                    provider_name: self.name(),
122                    provider_info: self.provider_info.clone(),
123                }) as Arc<dyn LanguageModel>
124            })
125    }
126
127    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
128        let state = self.state.read(cx);
129        state
130            .available_models
131            .iter()
132            .map(|model_info| {
133                Arc::new(ExtensionLanguageModel {
134                    extension: self.extension.clone(),
135                    model_info: model_info.clone(),
136                    provider_id: self.id(),
137                    provider_name: self.name(),
138                    provider_info: self.provider_info.clone(),
139                }) as Arc<dyn LanguageModel>
140            })
141            .collect()
142    }
143
144    fn is_authenticated(&self, cx: &App) -> bool {
145        self.state.read(cx).is_authenticated
146    }
147
148    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
149        let extension = self.extension.clone();
150        let provider_id = self.provider_info.id.clone();
151        let state = self.state.clone();
152
153        cx.spawn(async move |cx| {
154            let result = extension
155                .call(|extension, store| {
156                    async move {
157                        extension
158                            .call_llm_provider_authenticate(store, &provider_id)
159                            .await
160                    }
161                    .boxed()
162                })
163                .await;
164
165            match result {
166                Ok(Ok(Ok(()))) => {
167                    cx.update(|cx| {
168                        state.update(cx, |state, _| {
169                            state.is_authenticated = true;
170                        });
171                    })?;
172                    Ok(())
173                }
174                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
175                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
176                Err(e) => Err(AuthenticateError::Other(e)),
177            }
178        })
179    }
180
181    fn configuration_view(
182        &self,
183        _target_agent: ConfigurationViewTargetAgent,
184        window: &mut Window,
185        cx: &mut App,
186    ) -> AnyView {
187        let credential_key = self.credential_key();
188        let extension = self.extension.clone();
189        let extension_provider_id = self.provider_info.id.clone();
190        let state = self.state.clone();
191
192        cx.new(|cx| {
193            ExtensionProviderConfigurationView::new(
194                credential_key,
195                extension,
196                extension_provider_id,
197                state,
198                window,
199                cx,
200            )
201        })
202        .into()
203    }
204
205    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
206        let extension = self.extension.clone();
207        let provider_id = self.provider_info.id.clone();
208        let state = self.state.clone();
209        let credential_key = self.credential_key();
210
211        let credentials_provider = <dyn CredentialsProvider>::global(cx);
212
213        cx.spawn(async move |cx| {
214            // Delete from system keychain
215            credentials_provider
216                .delete_credentials(&credential_key, cx)
217                .await
218                .log_err();
219
220            // Call extension's reset_credentials
221            let result = extension
222                .call(|extension, store| {
223                    async move {
224                        extension
225                            .call_llm_provider_reset_credentials(store, &provider_id)
226                            .await
227                    }
228                    .boxed()
229                })
230                .await;
231
232            // Update state
233            cx.update(|cx| {
234                state.update(cx, |state, _| {
235                    state.is_authenticated = false;
236                });
237            })?;
238
239            match result {
240                Ok(Ok(Ok(()))) => Ok(()),
241                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
242                Ok(Err(e)) => Err(e),
243                Err(e) => Err(e),
244            }
245        })
246    }
247}
248
249impl LanguageModelProviderState for ExtensionLanguageModelProvider {
250    type ObservableEntity = ExtensionLlmProviderState;
251
252    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
253        Some(self.state.clone())
254    }
255
256    fn subscribe<T: 'static>(
257        &self,
258        cx: &mut Context<T>,
259        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
260    ) -> Option<Subscription> {
261        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
262    }
263}
264
265/// Configuration view for extension-based LLM providers.
266struct ExtensionProviderConfigurationView {
267    credential_key: String,
268    extension: WasmExtension,
269    extension_provider_id: String,
270    state: Entity<ExtensionLlmProviderState>,
271    settings_markdown: Option<Entity<Markdown>>,
272    api_key_editor: Entity<Editor>,
273    loading_settings: bool,
274    loading_credentials: bool,
275    _subscriptions: Vec<Subscription>,
276}
277
278impl ExtensionProviderConfigurationView {
279    fn new(
280        credential_key: String,
281        extension: WasmExtension,
282        extension_provider_id: String,
283        state: Entity<ExtensionLlmProviderState>,
284        window: &mut Window,
285        cx: &mut Context<Self>,
286    ) -> Self {
287        // Subscribe to state changes
288        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
289            cx.notify();
290        });
291
292        // Create API key editor
293        let api_key_editor = cx.new(|cx| {
294            let mut editor = Editor::single_line(window, cx);
295            editor.set_placeholder_text("Enter API key...", window, cx);
296            editor
297        });
298
299        let mut this = Self {
300            credential_key,
301            extension,
302            extension_provider_id,
303            state,
304            settings_markdown: None,
305            api_key_editor,
306            loading_settings: true,
307            loading_credentials: true,
308            _subscriptions: vec![state_subscription],
309        };
310
311        // Load settings text from extension
312        this.load_settings_text(cx);
313
314        // Load existing credentials
315        this.load_credentials(cx);
316
317        this
318    }
319
320    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
321        let extension = self.extension.clone();
322        let provider_id = self.extension_provider_id.clone();
323
324        cx.spawn(async move |this, cx| {
325            let result = extension
326                .call({
327                    let provider_id = provider_id.clone();
328                    |ext, store| {
329                        async move {
330                            ext.call_llm_provider_settings_markdown(store, &provider_id)
331                                .await
332                        }
333                        .boxed()
334                    }
335                })
336                .await;
337
338            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
339
340            this.update(cx, |this, cx| {
341                this.loading_settings = false;
342                if let Some(text) = settings_text {
343                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
344                    this.settings_markdown = Some(markdown);
345                }
346                cx.notify();
347            })
348            .log_err();
349        })
350        .detach();
351    }
352
353    fn load_credentials(&mut self, cx: &mut Context<Self>) {
354        let credential_key = self.credential_key.clone();
355        let credentials_provider = <dyn CredentialsProvider>::global(cx);
356        let state = self.state.clone();
357
358        cx.spawn(async move |this, cx| {
359            let credentials = credentials_provider
360                .read_credentials(&credential_key, cx)
361                .await
362                .log_err()
363                .flatten();
364
365            let has_credentials = credentials.is_some();
366
367            // Update authentication state based on stored credentials
368            let _ = cx.update(|cx| {
369                state.update(cx, |state, cx| {
370                    state.is_authenticated = has_credentials;
371                    cx.notify();
372                });
373            });
374
375            this.update(cx, |this, cx| {
376                this.loading_credentials = false;
377                cx.notify();
378            })
379            .log_err();
380        })
381        .detach();
382    }
383
384    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
385        let api_key = self.api_key_editor.read(cx).text(cx);
386        if api_key.is_empty() {
387            return;
388        }
389
390        // Clear the editor
391        self.api_key_editor
392            .update(cx, |editor, cx| editor.set_text("", window, cx));
393
394        let credential_key = self.credential_key.clone();
395        let credentials_provider = <dyn CredentialsProvider>::global(cx);
396        let state = self.state.clone();
397
398        cx.spawn(async move |_this, cx| {
399            // Store in system keychain
400            credentials_provider
401                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
402                .await
403                .log_err();
404
405            // Update state to authenticated
406            let _ = cx.update(|cx| {
407                state.update(cx, |state, cx| {
408                    state.is_authenticated = true;
409                    cx.notify();
410                });
411            });
412        })
413        .detach();
414    }
415
416    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
417        // Clear the editor
418        self.api_key_editor
419            .update(cx, |editor, cx| editor.set_text("", window, cx));
420
421        let credential_key = self.credential_key.clone();
422        let credentials_provider = <dyn CredentialsProvider>::global(cx);
423        let state = self.state.clone();
424
425        cx.spawn(async move |_this, cx| {
426            // Delete from system keychain
427            credentials_provider
428                .delete_credentials(&credential_key, cx)
429                .await
430                .log_err();
431
432            // Update state to unauthenticated
433            let _ = cx.update(|cx| {
434                state.update(cx, |state, cx| {
435                    state.is_authenticated = false;
436                    cx.notify();
437                });
438            });
439        })
440        .detach();
441    }
442
443    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
444        self.state.read(cx).is_authenticated
445    }
446}
447
448impl gpui::Render for ExtensionProviderConfigurationView {
449    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
450        let is_loading = self.loading_settings || self.loading_credentials;
451        let is_authenticated = self.is_authenticated(cx);
452
453        if is_loading {
454            return v_flex()
455                .gap_2()
456                .child(Label::new("Loading...").color(Color::Muted))
457                .into_any_element();
458        }
459
460        let mut content = v_flex().gap_4().size_full();
461
462        // Render settings markdown if available
463        if let Some(markdown) = &self.settings_markdown {
464            let style = settings_markdown_style(_window, cx);
465            content = content.child(
466                div()
467                    .p_2()
468                    .rounded_md()
469                    .bg(cx.theme().colors().surface_background)
470                    .child(MarkdownElement::new(markdown.clone(), style)),
471            );
472        }
473
474        // Render API key section
475        if is_authenticated {
476            content = content.child(
477                v_flex()
478                    .gap_2()
479                    .child(
480                        h_flex()
481                            .gap_2()
482                            .child(
483                                ui::Icon::new(ui::IconName::Check)
484                                    .color(Color::Success)
485                                    .size(ui::IconSize::Small),
486                            )
487                            .child(Label::new("API key configured").color(Color::Success)),
488                    )
489                    .child(
490                        ui::Button::new("reset-api-key", "Reset API Key")
491                            .style(ui::ButtonStyle::Subtle)
492                            .on_click(cx.listener(|this, _, window, cx| {
493                                this.reset_api_key(window, cx);
494                            })),
495                    ),
496            );
497        } else {
498            content = content.child(
499                v_flex()
500                    .gap_2()
501                    .on_action(cx.listener(Self::save_api_key))
502                    .child(
503                        Label::new("API Key")
504                            .size(LabelSize::Small)
505                            .color(Color::Muted),
506                    )
507                    .child(self.api_key_editor.clone())
508                    .child(
509                        Label::new("Enter your API key and press Enter to save")
510                            .size(LabelSize::Small)
511                            .color(Color::Muted),
512                    ),
513            );
514        }
515
516        content.into_any_element()
517    }
518}
519
520impl Focusable for ExtensionProviderConfigurationView {
521    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
522        self.api_key_editor.focus_handle(cx)
523    }
524}
525
526fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
527    let theme_settings = ThemeSettings::get_global(cx);
528    let colors = cx.theme().colors();
529    let mut text_style = window.text_style();
530    text_style.refine(&TextStyleRefinement {
531        font_family: Some(theme_settings.ui_font.family.clone()),
532        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
533        font_features: Some(theme_settings.ui_font.features.clone()),
534        color: Some(colors.text),
535        ..Default::default()
536    });
537
538    MarkdownStyle {
539        base_text_style: text_style,
540        selection_background_color: colors.element_selection_background,
541        inline_code: TextStyleRefinement {
542            background_color: Some(colors.editor_background),
543            ..Default::default()
544        },
545        link: TextStyleRefinement {
546            color: Some(colors.text_accent),
547            underline: Some(UnderlineStyle {
548                color: Some(colors.text_accent.opacity(0.5)),
549                thickness: px(1.),
550                ..Default::default()
551            }),
552            ..Default::default()
553        },
554        syntax: cx.theme().syntax().clone(),
555        ..Default::default()
556    }
557}
558
559/// An extension-based language model.
560pub struct ExtensionLanguageModel {
561    extension: WasmExtension,
562    model_info: LlmModelInfo,
563    provider_id: LanguageModelProviderId,
564    provider_name: LanguageModelProviderName,
565    provider_info: LlmProviderInfo,
566}
567
568impl LanguageModel for ExtensionLanguageModel {
569    fn id(&self) -> LanguageModelId {
570        LanguageModelId::from(self.model_info.id.clone())
571    }
572
573    fn name(&self) -> LanguageModelName {
574        LanguageModelName::from(self.model_info.name.clone())
575    }
576
577    fn provider_id(&self) -> LanguageModelProviderId {
578        self.provider_id.clone()
579    }
580
581    fn provider_name(&self) -> LanguageModelProviderName {
582        self.provider_name.clone()
583    }
584
585    fn telemetry_id(&self) -> String {
586        format!("extension-{}", self.model_info.id)
587    }
588
589    fn supports_images(&self) -> bool {
590        self.model_info.capabilities.supports_images
591    }
592
593    fn supports_tools(&self) -> bool {
594        self.model_info.capabilities.supports_tools
595    }
596
597    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
598        match choice {
599            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
600            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
601            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
602        }
603    }
604
605    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
606        match self.model_info.capabilities.tool_input_format {
607            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
608            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
609        }
610    }
611
612    fn max_token_count(&self) -> u64 {
613        self.model_info.max_token_count
614    }
615
616    fn max_output_tokens(&self) -> Option<u64> {
617        self.model_info.max_output_tokens
618    }
619
620    fn count_tokens(
621        &self,
622        request: LanguageModelRequest,
623        cx: &App,
624    ) -> BoxFuture<'static, Result<u64>> {
625        let extension = self.extension.clone();
626        let provider_id = self.provider_info.id.clone();
627        let model_id = self.model_info.id.clone();
628
629        let wit_request = convert_request_to_wit(request);
630
631        cx.background_spawn(async move {
632            extension
633                .call({
634                    let provider_id = provider_id.clone();
635                    let model_id = model_id.clone();
636                    let wit_request = wit_request.clone();
637                    |ext, store| {
638                        async move {
639                            let count = ext
640                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
641                                .await?
642                                .map_err(|e| anyhow!("{}", e))?;
643                            Ok(count)
644                        }
645                        .boxed()
646                    }
647                })
648                .await?
649        })
650        .boxed()
651    }
652
653    fn stream_completion(
654        &self,
655        request: LanguageModelRequest,
656        _cx: &AsyncApp,
657    ) -> BoxFuture<
658        'static,
659        Result<
660            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
661            LanguageModelCompletionError,
662        >,
663    > {
664        let extension = self.extension.clone();
665        let provider_id = self.provider_info.id.clone();
666        let model_id = self.model_info.id.clone();
667
668        let wit_request = convert_request_to_wit(request);
669
670        async move {
671            // Start the stream
672            let stream_id_result = extension
673                .call({
674                    let provider_id = provider_id.clone();
675                    let model_id = model_id.clone();
676                    let wit_request = wit_request.clone();
677                    |ext, store| {
678                        async move {
679                            let id = ext
680                                .call_llm_stream_completion_start(
681                                    store,
682                                    &provider_id,
683                                    &model_id,
684                                    &wit_request,
685                                )
686                                .await?
687                                .map_err(|e| anyhow!("{}", e))?;
688                            Ok(id)
689                        }
690                        .boxed()
691                    }
692                })
693                .await;
694
695            let stream_id = stream_id_result
696                .map_err(LanguageModelCompletionError::Other)?
697                .map_err(LanguageModelCompletionError::Other)?;
698
699            // Create a stream that polls for events
700            let stream = futures::stream::unfold(
701                (extension.clone(), stream_id, false),
702                move |(extension, stream_id, done)| async move {
703                    if done {
704                        return None;
705                    }
706
707                    let result = extension
708                        .call({
709                            let stream_id = stream_id.clone();
710                            |ext, store| {
711                                async move {
712                                    let event = ext
713                                        .call_llm_stream_completion_next(store, &stream_id)
714                                        .await?
715                                        .map_err(|e| anyhow!("{}", e))?;
716                                    Ok(event)
717                                }
718                                .boxed()
719                            }
720                        })
721                        .await
722                        .and_then(|inner| inner);
723
724                    match result {
725                        Ok(Some(event)) => {
726                            let converted = convert_completion_event(event);
727                            let is_done =
728                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
729                            Some((converted, (extension, stream_id, is_done)))
730                        }
731                        Ok(None) => {
732                            // Stream complete, close it
733                            let _ = extension
734                                .call({
735                                    let stream_id = stream_id.clone();
736                                    |ext, store| {
737                                        async move {
738                                            ext.call_llm_stream_completion_close(store, &stream_id)
739                                                .await?;
740                                            Ok::<(), anyhow::Error>(())
741                                        }
742                                        .boxed()
743                                    }
744                                })
745                                .await;
746                            None
747                        }
748                        Err(e) => Some((
749                            Err(LanguageModelCompletionError::Other(e)),
750                            (extension, stream_id, true),
751                        )),
752                    }
753                },
754            );
755
756            Ok(stream.boxed())
757        }
758        .boxed()
759    }
760
761    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
762        // Extensions can implement this via llm_cache_configuration
763        None
764    }
765}
766
767fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
768    use language_model::{MessageContent, Role};
769
770    let messages: Vec<LlmRequestMessage> = request
771        .messages
772        .into_iter()
773        .map(|msg| {
774            let role = match msg.role {
775                Role::User => LlmMessageRole::User,
776                Role::Assistant => LlmMessageRole::Assistant,
777                Role::System => LlmMessageRole::System,
778            };
779
780            let content: Vec<LlmMessageContent> = msg
781                .content
782                .into_iter()
783                .map(|c| match c {
784                    MessageContent::Text(text) => LlmMessageContent::Text(text),
785                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
786                        source: image.source.to_string(),
787                        width: Some(image.size.width.0 as u32),
788                        height: Some(image.size.height.0 as u32),
789                    }),
790                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
791                        id: tool_use.id.to_string(),
792                        name: tool_use.name.to_string(),
793                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
794                        thought_signature: tool_use.thought_signature,
795                    }),
796                    MessageContent::ToolResult(tool_result) => {
797                        let content = match tool_result.content {
798                            language_model::LanguageModelToolResultContent::Text(text) => {
799                                LlmToolResultContent::Text(text.to_string())
800                            }
801                            language_model::LanguageModelToolResultContent::Image(image) => {
802                                LlmToolResultContent::Image(LlmImageData {
803                                    source: image.source.to_string(),
804                                    width: Some(image.size.width.0 as u32),
805                                    height: Some(image.size.height.0 as u32),
806                                })
807                            }
808                        };
809                        LlmMessageContent::ToolResult(LlmToolResult {
810                            tool_use_id: tool_result.tool_use_id.to_string(),
811                            tool_name: tool_result.tool_name.to_string(),
812                            is_error: tool_result.is_error,
813                            content,
814                        })
815                    }
816                    MessageContent::Thinking { text, signature } => {
817                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
818                    }
819                    MessageContent::RedactedThinking(data) => {
820                        LlmMessageContent::RedactedThinking(data)
821                    }
822                })
823                .collect();
824
825            LlmRequestMessage {
826                role,
827                content,
828                cache: msg.cache,
829            }
830        })
831        .collect();
832
833    let tools: Vec<LlmToolDefinition> = request
834        .tools
835        .into_iter()
836        .map(|tool| LlmToolDefinition {
837            name: tool.name,
838            description: tool.description,
839            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
840        })
841        .collect();
842
843    let tool_choice = request.tool_choice.map(|tc| match tc {
844        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
845        LanguageModelToolChoice::Any => LlmToolChoice::Any,
846        LanguageModelToolChoice::None => LlmToolChoice::None,
847    });
848
849    LlmCompletionRequest {
850        messages,
851        tools,
852        tool_choice,
853        stop_sequences: request.stop,
854        temperature: request.temperature,
855        thinking_allowed: false,
856        max_tokens: None,
857    }
858}
859
860fn convert_completion_event(
861    event: LlmCompletionEvent,
862) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
863    match event {
864        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
865            message_id: String::new(),
866        }),
867        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
868        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
869            text: thinking.text,
870            signature: thinking.signature,
871        }),
872        LlmCompletionEvent::RedactedThinking(data) => {
873            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
874        }
875        LlmCompletionEvent::ToolUse(tool_use) => {
876            let raw_input = tool_use.input.clone();
877            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
878            Ok(LanguageModelCompletionEvent::ToolUse(
879                LanguageModelToolUse {
880                    id: LanguageModelToolUseId::from(tool_use.id),
881                    name: tool_use.name.into(),
882                    raw_input,
883                    input,
884                    is_input_complete: true,
885                    thought_signature: tool_use.thought_signature,
886                },
887            ))
888        }
889        LlmCompletionEvent::ToolUseJsonParseError(error) => {
890            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
891                id: LanguageModelToolUseId::from(error.id),
892                tool_name: error.tool_name.into(),
893                raw_input: error.raw_input.into(),
894                json_parse_error: error.error,
895            })
896        }
897        LlmCompletionEvent::Stop(reason) => {
898            let stop_reason = match reason {
899                LlmStopReason::EndTurn => StopReason::EndTurn,
900                LlmStopReason::MaxTokens => StopReason::MaxTokens,
901                LlmStopReason::ToolUse => StopReason::ToolUse,
902                LlmStopReason::Refusal => StopReason::Refusal,
903            };
904            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
905        }
906        LlmCompletionEvent::Usage(usage) => {
907            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
908                input_tokens: usage.input_tokens,
909                output_tokens: usage.output_tokens,
910                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
911                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
912            }))
913        }
914        LlmCompletionEvent::ReasoningDetails(json) => {
915            Ok(LanguageModelCompletionEvent::ReasoningDetails(
916                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
917            ))
918        }
919    }
920}