llm_provider.rs

  1use crate::wasm_host::WasmExtension;
  2
  3use crate::wasm_host::wit::{
  4    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
  5    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
  6    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  7    LlmToolUse,
  8};
  9use anyhow::{Result, anyhow};
 10use credentials_provider::CredentialsProvider;
 11use editor::Editor;
 12use futures::future::BoxFuture;
 13use futures::stream::BoxStream;
 14use futures::{FutureExt, StreamExt};
 15use gpui::Focusable;
 16use gpui::{
 17    AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
 18    TextStyleRefinement, UnderlineStyle, Window, px,
 19};
 20use language_model::tool_schema::LanguageModelToolSchemaFormat;
 21use language_model::{
 22    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
 23    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 24    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 25    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 26    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
 27};
 28use markdown::{Markdown, MarkdownElement, MarkdownStyle};
 29use settings::Settings;
 30use std::sync::Arc;
 31use theme::ThemeSettings;
 32use ui::{Label, LabelSize, prelude::*};
 33use util::ResultExt as _;
 34
 35/// An extension-based language model provider.
 36pub struct ExtensionLanguageModelProvider {
 37    pub extension: WasmExtension,
 38    pub provider_info: LlmProviderInfo,
 39    icon_path: Option<SharedString>,
 40    state: Entity<ExtensionLlmProviderState>,
 41}
 42
 43pub struct ExtensionLlmProviderState {
 44    is_authenticated: bool,
 45    available_models: Vec<LlmModelInfo>,
 46}
 47
 48impl EventEmitter<()> for ExtensionLlmProviderState {}
 49
 50impl ExtensionLanguageModelProvider {
 51    pub fn new(
 52        extension: WasmExtension,
 53        provider_info: LlmProviderInfo,
 54        models: Vec<LlmModelInfo>,
 55        is_authenticated: bool,
 56        icon_path: Option<SharedString>,
 57        cx: &mut App,
 58    ) -> Self {
 59        let state = cx.new(|_| ExtensionLlmProviderState {
 60            is_authenticated,
 61            available_models: models,
 62        });
 63
 64        Self {
 65            extension,
 66            provider_info,
 67            icon_path,
 68            state,
 69        }
 70    }
 71
 72    fn provider_id_string(&self) -> String {
 73        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
 74    }
 75
 76    /// The credential key used for storing the API key in the system keychain.
 77    fn credential_key(&self) -> String {
 78        format!("extension-llm-{}", self.provider_id_string())
 79    }
 80}
 81
 82impl LanguageModelProvider for ExtensionLanguageModelProvider {
 83    fn id(&self) -> LanguageModelProviderId {
 84        LanguageModelProviderId::from(self.provider_id_string())
 85    }
 86
 87    fn name(&self) -> LanguageModelProviderName {
 88        LanguageModelProviderName::from(self.provider_info.name.clone())
 89    }
 90
 91    fn icon(&self) -> ui::IconName {
 92        ui::IconName::ZedAssistant
 93    }
 94
 95    fn icon_path(&self) -> Option<SharedString> {
 96        self.icon_path.clone()
 97    }
 98
 99    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
100        let state = self.state.read(cx);
101        state
102            .available_models
103            .iter()
104            .find(|m| m.is_default)
105            .or_else(|| state.available_models.first())
106            .map(|model_info| {
107                Arc::new(ExtensionLanguageModel {
108                    extension: self.extension.clone(),
109                    model_info: model_info.clone(),
110                    provider_id: self.id(),
111                    provider_name: self.name(),
112                    provider_info: self.provider_info.clone(),
113                }) as Arc<dyn LanguageModel>
114            })
115    }
116
117    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
118        let state = self.state.read(cx);
119        state
120            .available_models
121            .iter()
122            .find(|m| m.is_default_fast)
123            .map(|model_info| {
124                Arc::new(ExtensionLanguageModel {
125                    extension: self.extension.clone(),
126                    model_info: model_info.clone(),
127                    provider_id: self.id(),
128                    provider_name: self.name(),
129                    provider_info: self.provider_info.clone(),
130                }) as Arc<dyn LanguageModel>
131            })
132    }
133
134    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
135        let state = self.state.read(cx);
136        state
137            .available_models
138            .iter()
139            .map(|model_info| {
140                Arc::new(ExtensionLanguageModel {
141                    extension: self.extension.clone(),
142                    model_info: model_info.clone(),
143                    provider_id: self.id(),
144                    provider_name: self.name(),
145                    provider_info: self.provider_info.clone(),
146                }) as Arc<dyn LanguageModel>
147            })
148            .collect()
149    }
150
151    fn is_authenticated(&self, cx: &App) -> bool {
152        self.state.read(cx).is_authenticated
153    }
154
155    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
156        let extension = self.extension.clone();
157        let provider_id = self.provider_info.id.clone();
158        let state = self.state.clone();
159
160        cx.spawn(async move |cx| {
161            let result = extension
162                .call(|extension, store| {
163                    async move {
164                        extension
165                            .call_llm_provider_authenticate(store, &provider_id)
166                            .await
167                    }
168                    .boxed()
169                })
170                .await;
171
172            match result {
173                Ok(Ok(Ok(()))) => {
174                    cx.update(|cx| {
175                        state.update(cx, |state, _| {
176                            state.is_authenticated = true;
177                        });
178                    })?;
179                    Ok(())
180                }
181                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
182                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
183                Err(e) => Err(AuthenticateError::Other(e)),
184            }
185        })
186    }
187
188    fn configuration_view(
189        &self,
190        _target_agent: ConfigurationViewTargetAgent,
191        window: &mut Window,
192        cx: &mut App,
193    ) -> AnyView {
194        let credential_key = self.credential_key();
195        let extension = self.extension.clone();
196        let extension_provider_id = self.provider_info.id.clone();
197        let state = self.state.clone();
198
199        cx.new(|cx| {
200            ExtensionProviderConfigurationView::new(
201                credential_key,
202                extension,
203                extension_provider_id,
204                state,
205                window,
206                cx,
207            )
208        })
209        .into()
210    }
211
212    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
213        let extension = self.extension.clone();
214        let provider_id = self.provider_info.id.clone();
215        let state = self.state.clone();
216        let credential_key = self.credential_key();
217
218        let credentials_provider = <dyn CredentialsProvider>::global(cx);
219
220        cx.spawn(async move |cx| {
221            // Delete from system keychain
222            credentials_provider
223                .delete_credentials(&credential_key, cx)
224                .await
225                .log_err();
226
227            // Call extension's reset_credentials
228            let result = extension
229                .call(|extension, store| {
230                    async move {
231                        extension
232                            .call_llm_provider_reset_credentials(store, &provider_id)
233                            .await
234                    }
235                    .boxed()
236                })
237                .await;
238
239            // Update state
240            cx.update(|cx| {
241                state.update(cx, |state, _| {
242                    state.is_authenticated = false;
243                });
244            })?;
245
246            match result {
247                Ok(Ok(Ok(()))) => Ok(()),
248                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
249                Ok(Err(e)) => Err(e),
250                Err(e) => Err(e),
251            }
252        })
253    }
254}
255
256impl LanguageModelProviderState for ExtensionLanguageModelProvider {
257    type ObservableEntity = ExtensionLlmProviderState;
258
259    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
260        Some(self.state.clone())
261    }
262
263    fn subscribe<T: 'static>(
264        &self,
265        cx: &mut Context<T>,
266        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
267    ) -> Option<Subscription> {
268        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
269    }
270}
271
272/// Configuration view for extension-based LLM providers.
273struct ExtensionProviderConfigurationView {
274    credential_key: String,
275    extension: WasmExtension,
276    extension_provider_id: String,
277    state: Entity<ExtensionLlmProviderState>,
278    settings_markdown: Option<Entity<Markdown>>,
279    api_key_editor: Entity<Editor>,
280    loading_settings: bool,
281    loading_credentials: bool,
282    _subscriptions: Vec<Subscription>,
283}
284
285impl ExtensionProviderConfigurationView {
286    fn new(
287        credential_key: String,
288        extension: WasmExtension,
289        extension_provider_id: String,
290        state: Entity<ExtensionLlmProviderState>,
291        window: &mut Window,
292        cx: &mut Context<Self>,
293    ) -> Self {
294        // Subscribe to state changes
295        let state_subscription = cx.subscribe(&state, |_, _, _, cx| {
296            cx.notify();
297        });
298
299        // Create API key editor
300        let api_key_editor = cx.new(|cx| {
301            let mut editor = Editor::single_line(window, cx);
302            editor.set_placeholder_text("Enter API key...", window, cx);
303            editor
304        });
305
306        let mut this = Self {
307            credential_key,
308            extension,
309            extension_provider_id,
310            state,
311            settings_markdown: None,
312            api_key_editor,
313            loading_settings: true,
314            loading_credentials: true,
315            _subscriptions: vec![state_subscription],
316        };
317
318        // Load settings text from extension
319        this.load_settings_text(cx);
320
321        // Load existing credentials
322        this.load_credentials(cx);
323
324        this
325    }
326
327    fn load_settings_text(&mut self, cx: &mut Context<Self>) {
328        let extension = self.extension.clone();
329        let provider_id = self.extension_provider_id.clone();
330
331        cx.spawn(async move |this, cx| {
332            let result = extension
333                .call({
334                    let provider_id = provider_id.clone();
335                    |ext, store| {
336                        async move {
337                            ext.call_llm_provider_settings_markdown(store, &provider_id)
338                                .await
339                        }
340                        .boxed()
341                    }
342                })
343                .await;
344
345            let settings_text = result.ok().and_then(|inner| inner.ok()).flatten();
346
347            this.update(cx, |this, cx| {
348                this.loading_settings = false;
349                if let Some(text) = settings_text {
350                    let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx));
351                    this.settings_markdown = Some(markdown);
352                }
353                cx.notify();
354            })
355            .log_err();
356        })
357        .detach();
358    }
359
360    fn load_credentials(&mut self, cx: &mut Context<Self>) {
361        let credential_key = self.credential_key.clone();
362        let credentials_provider = <dyn CredentialsProvider>::global(cx);
363        let state = self.state.clone();
364
365        cx.spawn(async move |this, cx| {
366            let credentials = credentials_provider
367                .read_credentials(&credential_key, cx)
368                .await
369                .log_err()
370                .flatten();
371
372            let has_credentials = credentials.is_some();
373
374            // Update authentication state based on stored credentials
375            let _ = cx.update(|cx| {
376                state.update(cx, |state, cx| {
377                    state.is_authenticated = has_credentials;
378                    cx.notify();
379                });
380            });
381
382            this.update(cx, |this, cx| {
383                this.loading_credentials = false;
384                cx.notify();
385            })
386            .log_err();
387        })
388        .detach();
389    }
390
391    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
392        let api_key = self.api_key_editor.read(cx).text(cx);
393        if api_key.is_empty() {
394            return;
395        }
396
397        // Clear the editor
398        self.api_key_editor
399            .update(cx, |editor, cx| editor.set_text("", window, cx));
400
401        let credential_key = self.credential_key.clone();
402        let credentials_provider = <dyn CredentialsProvider>::global(cx);
403        let state = self.state.clone();
404
405        cx.spawn(async move |_this, cx| {
406            // Store in system keychain
407            credentials_provider
408                .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx)
409                .await
410                .log_err();
411
412            // Update state to authenticated
413            let _ = cx.update(|cx| {
414                state.update(cx, |state, cx| {
415                    state.is_authenticated = true;
416                    cx.notify();
417                });
418            });
419        })
420        .detach();
421    }
422
423    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
424        // Clear the editor
425        self.api_key_editor
426            .update(cx, |editor, cx| editor.set_text("", window, cx));
427
428        let credential_key = self.credential_key.clone();
429        let credentials_provider = <dyn CredentialsProvider>::global(cx);
430        let state = self.state.clone();
431
432        cx.spawn(async move |_this, cx| {
433            // Delete from system keychain
434            credentials_provider
435                .delete_credentials(&credential_key, cx)
436                .await
437                .log_err();
438
439            // Update state to unauthenticated
440            let _ = cx.update(|cx| {
441                state.update(cx, |state, cx| {
442                    state.is_authenticated = false;
443                    cx.notify();
444                });
445            });
446        })
447        .detach();
448    }
449
450    fn is_authenticated(&self, cx: &Context<Self>) -> bool {
451        self.state.read(cx).is_authenticated
452    }
453}
454
455impl gpui::Render for ExtensionProviderConfigurationView {
456    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
457        let is_loading = self.loading_settings || self.loading_credentials;
458        let is_authenticated = self.is_authenticated(cx);
459
460        if is_loading {
461            return v_flex()
462                .gap_2()
463                .child(Label::new("Loading...").color(Color::Muted))
464                .into_any_element();
465        }
466
467        let mut content = v_flex().gap_4().size_full();
468
469        // Render settings markdown if available
470        if let Some(markdown) = &self.settings_markdown {
471            let style = settings_markdown_style(_window, cx);
472            content = content.child(
473                div()
474                    .p_2()
475                    .rounded_md()
476                    .bg(cx.theme().colors().surface_background)
477                    .child(MarkdownElement::new(markdown.clone(), style)),
478            );
479        }
480
481        // Render API key section
482        if is_authenticated {
483            content = content.child(
484                v_flex()
485                    .gap_2()
486                    .child(
487                        h_flex()
488                            .gap_2()
489                            .child(
490                                ui::Icon::new(ui::IconName::Check)
491                                    .color(Color::Success)
492                                    .size(ui::IconSize::Small),
493                            )
494                            .child(Label::new("API key configured").color(Color::Success)),
495                    )
496                    .child(
497                        ui::Button::new("reset-api-key", "Reset API Key")
498                            .style(ui::ButtonStyle::Subtle)
499                            .on_click(cx.listener(|this, _, window, cx| {
500                                this.reset_api_key(window, cx);
501                            })),
502                    ),
503            );
504        } else {
505            content = content.child(
506                v_flex()
507                    .gap_2()
508                    .on_action(cx.listener(Self::save_api_key))
509                    .child(
510                        Label::new("API Key")
511                            .size(LabelSize::Small)
512                            .color(Color::Muted),
513                    )
514                    .child(self.api_key_editor.clone())
515                    .child(
516                        Label::new("Enter your API key and press Enter to save")
517                            .size(LabelSize::Small)
518                            .color(Color::Muted),
519                    ),
520            );
521        }
522
523        content.into_any_element()
524    }
525}
526
527impl Focusable for ExtensionProviderConfigurationView {
528    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
529        self.api_key_editor.focus_handle(cx)
530    }
531}
532
533fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
534    let theme_settings = ThemeSettings::get_global(cx);
535    let colors = cx.theme().colors();
536    let mut text_style = window.text_style();
537    text_style.refine(&TextStyleRefinement {
538        font_family: Some(theme_settings.ui_font.family.clone()),
539        font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
540        font_features: Some(theme_settings.ui_font.features.clone()),
541        color: Some(colors.text),
542        ..Default::default()
543    });
544
545    MarkdownStyle {
546        base_text_style: text_style,
547        selection_background_color: colors.element_selection_background,
548        inline_code: TextStyleRefinement {
549            background_color: Some(colors.editor_background),
550            ..Default::default()
551        },
552        link: TextStyleRefinement {
553            color: Some(colors.text_accent),
554            underline: Some(UnderlineStyle {
555                color: Some(colors.text_accent.opacity(0.5)),
556                thickness: px(1.),
557                ..Default::default()
558            }),
559            ..Default::default()
560        },
561        syntax: cx.theme().syntax().clone(),
562        ..Default::default()
563    }
564}
565
566/// An extension-based language model.
567pub struct ExtensionLanguageModel {
568    extension: WasmExtension,
569    model_info: LlmModelInfo,
570    provider_id: LanguageModelProviderId,
571    provider_name: LanguageModelProviderName,
572    provider_info: LlmProviderInfo,
573}
574
575impl LanguageModel for ExtensionLanguageModel {
576    fn id(&self) -> LanguageModelId {
577        LanguageModelId::from(self.model_info.id.clone())
578    }
579
580    fn name(&self) -> LanguageModelName {
581        LanguageModelName::from(self.model_info.name.clone())
582    }
583
584    fn provider_id(&self) -> LanguageModelProviderId {
585        self.provider_id.clone()
586    }
587
588    fn provider_name(&self) -> LanguageModelProviderName {
589        self.provider_name.clone()
590    }
591
592    fn telemetry_id(&self) -> String {
593        format!("extension-{}", self.model_info.id)
594    }
595
596    fn supports_images(&self) -> bool {
597        self.model_info.capabilities.supports_images
598    }
599
600    fn supports_tools(&self) -> bool {
601        self.model_info.capabilities.supports_tools
602    }
603
604    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
605        match choice {
606            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
607            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
608            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
609        }
610    }
611
612    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
613        match self.model_info.capabilities.tool_input_format {
614            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
615            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
616        }
617    }
618
619    fn max_token_count(&self) -> u64 {
620        self.model_info.max_token_count
621    }
622
623    fn max_output_tokens(&self) -> Option<u64> {
624        self.model_info.max_output_tokens
625    }
626
627    fn count_tokens(
628        &self,
629        request: LanguageModelRequest,
630        cx: &App,
631    ) -> BoxFuture<'static, Result<u64>> {
632        let extension = self.extension.clone();
633        let provider_id = self.provider_info.id.clone();
634        let model_id = self.model_info.id.clone();
635
636        let wit_request = convert_request_to_wit(request);
637
638        cx.background_spawn(async move {
639            extension
640                .call({
641                    let provider_id = provider_id.clone();
642                    let model_id = model_id.clone();
643                    let wit_request = wit_request.clone();
644                    |ext, store| {
645                        async move {
646                            let count = ext
647                                .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
648                                .await?
649                                .map_err(|e| anyhow!("{}", e))?;
650                            Ok(count)
651                        }
652                        .boxed()
653                    }
654                })
655                .await?
656        })
657        .boxed()
658    }
659
660    fn stream_completion(
661        &self,
662        request: LanguageModelRequest,
663        _cx: &AsyncApp,
664    ) -> BoxFuture<
665        'static,
666        Result<
667            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
668            LanguageModelCompletionError,
669        >,
670    > {
671        let extension = self.extension.clone();
672        let provider_id = self.provider_info.id.clone();
673        let model_id = self.model_info.id.clone();
674
675        let wit_request = convert_request_to_wit(request);
676
677        async move {
678            // Start the stream
679            let stream_id_result = extension
680                .call({
681                    let provider_id = provider_id.clone();
682                    let model_id = model_id.clone();
683                    let wit_request = wit_request.clone();
684                    |ext, store| {
685                        async move {
686                            let id = ext
687                                .call_llm_stream_completion_start(
688                                    store,
689                                    &provider_id,
690                                    &model_id,
691                                    &wit_request,
692                                )
693                                .await?
694                                .map_err(|e| anyhow!("{}", e))?;
695                            Ok(id)
696                        }
697                        .boxed()
698                    }
699                })
700                .await;
701
702            let stream_id = stream_id_result
703                .map_err(LanguageModelCompletionError::Other)?
704                .map_err(LanguageModelCompletionError::Other)?;
705
706            // Create a stream that polls for events
707            let stream = futures::stream::unfold(
708                (extension.clone(), stream_id, false),
709                move |(extension, stream_id, done)| async move {
710                    if done {
711                        return None;
712                    }
713
714                    let result = extension
715                        .call({
716                            let stream_id = stream_id.clone();
717                            |ext, store| {
718                                async move {
719                                    let event = ext
720                                        .call_llm_stream_completion_next(store, &stream_id)
721                                        .await?
722                                        .map_err(|e| anyhow!("{}", e))?;
723                                    Ok(event)
724                                }
725                                .boxed()
726                            }
727                        })
728                        .await
729                        .and_then(|inner| inner);
730
731                    match result {
732                        Ok(Some(event)) => {
733                            let converted = convert_completion_event(event);
734                            let is_done =
735                                matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_)));
736                            Some((converted, (extension, stream_id, is_done)))
737                        }
738                        Ok(None) => {
739                            // Stream complete, close it
740                            let _ = extension
741                                .call({
742                                    let stream_id = stream_id.clone();
743                                    |ext, store| {
744                                        async move {
745                                            ext.call_llm_stream_completion_close(store, &stream_id)
746                                                .await?;
747                                            Ok::<(), anyhow::Error>(())
748                                        }
749                                        .boxed()
750                                    }
751                                })
752                                .await;
753                            None
754                        }
755                        Err(e) => Some((
756                            Err(LanguageModelCompletionError::Other(e)),
757                            (extension, stream_id, true),
758                        )),
759                    }
760                },
761            );
762
763            Ok(stream.boxed())
764        }
765        .boxed()
766    }
767
768    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
769        // Extensions can implement this via llm_cache_configuration
770        None
771    }
772}
773
774fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest {
775    use language_model::{MessageContent, Role};
776
777    let messages: Vec<LlmRequestMessage> = request
778        .messages
779        .into_iter()
780        .map(|msg| {
781            let role = match msg.role {
782                Role::User => LlmMessageRole::User,
783                Role::Assistant => LlmMessageRole::Assistant,
784                Role::System => LlmMessageRole::System,
785            };
786
787            let content: Vec<LlmMessageContent> = msg
788                .content
789                .into_iter()
790                .map(|c| match c {
791                    MessageContent::Text(text) => LlmMessageContent::Text(text),
792                    MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData {
793                        source: image.source.to_string(),
794                        width: Some(image.size.width.0 as u32),
795                        height: Some(image.size.height.0 as u32),
796                    }),
797                    MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse {
798                        id: tool_use.id.to_string(),
799                        name: tool_use.name.to_string(),
800                        input: serde_json::to_string(&tool_use.input).unwrap_or_default(),
801                        thought_signature: tool_use.thought_signature,
802                    }),
803                    MessageContent::ToolResult(tool_result) => {
804                        let content = match tool_result.content {
805                            language_model::LanguageModelToolResultContent::Text(text) => {
806                                LlmToolResultContent::Text(text.to_string())
807                            }
808                            language_model::LanguageModelToolResultContent::Image(image) => {
809                                LlmToolResultContent::Image(LlmImageData {
810                                    source: image.source.to_string(),
811                                    width: Some(image.size.width.0 as u32),
812                                    height: Some(image.size.height.0 as u32),
813                                })
814                            }
815                        };
816                        LlmMessageContent::ToolResult(LlmToolResult {
817                            tool_use_id: tool_result.tool_use_id.to_string(),
818                            tool_name: tool_result.tool_name.to_string(),
819                            is_error: tool_result.is_error,
820                            content,
821                        })
822                    }
823                    MessageContent::Thinking { text, signature } => {
824                        LlmMessageContent::Thinking(LlmThinkingContent { text, signature })
825                    }
826                    MessageContent::RedactedThinking(data) => {
827                        LlmMessageContent::RedactedThinking(data)
828                    }
829                })
830                .collect();
831
832            LlmRequestMessage {
833                role,
834                content,
835                cache: msg.cache,
836            }
837        })
838        .collect();
839
840    let tools: Vec<LlmToolDefinition> = request
841        .tools
842        .into_iter()
843        .map(|tool| LlmToolDefinition {
844            name: tool.name,
845            description: tool.description,
846            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
847        })
848        .collect();
849
850    let tool_choice = request.tool_choice.map(|tc| match tc {
851        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
852        LanguageModelToolChoice::Any => LlmToolChoice::Any,
853        LanguageModelToolChoice::None => LlmToolChoice::None,
854    });
855
856    LlmCompletionRequest {
857        messages,
858        tools,
859        tool_choice,
860        stop_sequences: request.stop,
861        temperature: request.temperature,
862        thinking_allowed: false,
863        max_tokens: None,
864    }
865}
866
867fn convert_completion_event(
868    event: LlmCompletionEvent,
869) -> Result<LanguageModelCompletionEvent, LanguageModelCompletionError> {
870    match event {
871        LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage {
872            message_id: String::new(),
873        }),
874        LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)),
875        LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking {
876            text: thinking.text,
877            signature: thinking.signature,
878        }),
879        LlmCompletionEvent::RedactedThinking(data) => {
880            Ok(LanguageModelCompletionEvent::RedactedThinking { data })
881        }
882        LlmCompletionEvent::ToolUse(tool_use) => {
883            let raw_input = tool_use.input.clone();
884            let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null);
885            Ok(LanguageModelCompletionEvent::ToolUse(
886                LanguageModelToolUse {
887                    id: LanguageModelToolUseId::from(tool_use.id),
888                    name: tool_use.name.into(),
889                    raw_input,
890                    input,
891                    is_input_complete: true,
892                    thought_signature: tool_use.thought_signature,
893                },
894            ))
895        }
896        LlmCompletionEvent::ToolUseJsonParseError(error) => {
897            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
898                id: LanguageModelToolUseId::from(error.id),
899                tool_name: error.tool_name.into(),
900                raw_input: error.raw_input.into(),
901                json_parse_error: error.error,
902            })
903        }
904        LlmCompletionEvent::Stop(reason) => {
905            let stop_reason = match reason {
906                LlmStopReason::EndTurn => StopReason::EndTurn,
907                LlmStopReason::MaxTokens => StopReason::MaxTokens,
908                LlmStopReason::ToolUse => StopReason::ToolUse,
909                LlmStopReason::Refusal => StopReason::Refusal,
910            };
911            Ok(LanguageModelCompletionEvent::Stop(stop_reason))
912        }
913        LlmCompletionEvent::Usage(usage) => {
914            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
915                input_tokens: usage.input_tokens,
916                output_tokens: usage.output_tokens,
917                cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
918                cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
919            }))
920        }
921        LlmCompletionEvent::ReasoningDetails(json) => {
922            Ok(LanguageModelCompletionEvent::ReasoningDetails(
923                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
924            ))
925        }
926    }
927}