llm_provider.rs
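//! Exposes language model providers implemented by WASM extensions as
//! implementations of the `language_model` crate's [`LanguageModelProvider`]
//! and [`LanguageModel`] traits.
//!
//! A minimal sketch of how a provider might be constructed (the surrounding
//! registration code is assumed, not shown in this file):
//!
//! ```ignore
//! // Hypothetical call site: `extension`, `provider_info`, and `models` come
//! // from the loaded WASM extension; `cx` is a `&mut App`.
//! let provider = ExtensionLanguageModelProvider::new(extension, provider_info, models, cx);
//! ```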

use crate::wasm_host::WasmExtension;
use crate::wasm_host::wit::{
    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
    LlmToolUse,
};
use anyhow::{Result, anyhow};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
use language_model::tool_schema::LanguageModelToolSchemaFormat;
use language_model::{
    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
};
use std::sync::Arc;

/// An extension-based language model provider.
pub struct ExtensionLanguageModelProvider {
    pub extension: WasmExtension,
    pub provider_info: LlmProviderInfo,
    state: Entity<ExtensionLlmProviderState>,
}

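/// Observable state for an extension-backed provider: whether it has
/// authenticated and which models it currently advertises.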
pub struct ExtensionLlmProviderState {
    is_authenticated: bool,
    available_models: Vec<LlmModelInfo>,
}

impl EventEmitter<()> for ExtensionLlmProviderState {}

impl ExtensionLanguageModelProvider {
    pub fn new(
        extension: WasmExtension,
        provider_info: LlmProviderInfo,
        models: Vec<LlmModelInfo>,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|_| ExtensionLlmProviderState {
            is_authenticated: false,
            available_models: models,
        });

        Self {
            extension,
            provider_info,
            state,
        }
    }

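    /// Builds a provider ID that is unique across extensions by combining the
    /// extension ID with the extension-local provider ID, e.g. a hypothetical
    /// `my-llm-extension:my-provider`.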
    fn provider_id_string(&self) -> String {
        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
    }
}

impl LanguageModelProvider for ExtensionLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId::from(self.provider_id_string())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName::from(self.provider_info.name.clone())
    }

    fn icon(&self) -> ui::IconName {
        ui::IconName::ZedAssistant
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        state
            .available_models
            .iter()
            .find(|m| m.is_default)
            .or_else(|| state.available_models.first())
            .map(|model_info| {
                Arc::new(ExtensionLanguageModel {
                    extension: self.extension.clone(),
                    model_info: model_info.clone(),
                    provider_id: self.id(),
                    provider_name: self.name(),
                    provider_info: self.provider_info.clone(),
                }) as Arc<dyn LanguageModel>
            })
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        state
            .available_models
            .iter()
            .find(|m| m.is_default_fast)
            .or_else(|| state.available_models.iter().find(|m| m.is_default))
            .or_else(|| state.available_models.first())
            .map(|model_info| {
                Arc::new(ExtensionLanguageModel {
                    extension: self.extension.clone(),
                    model_info: model_info.clone(),
                    provider_id: self.id(),
                    provider_name: self.name(),
                    provider_info: self.provider_info.clone(),
                }) as Arc<dyn LanguageModel>
            })
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        state
            .available_models
            .iter()
            .map(|model_info| {
                Arc::new(ExtensionLanguageModel {
                    extension: self.extension.clone(),
                    model_info: model_info.clone(),
                    provider_id: self.id(),
                    provider_name: self.name(),
                    provider_info: self.provider_info.clone(),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let state = self.state.clone();

        cx.spawn(async move |cx| {
            let result = extension
                .call(|extension, store| {
                    async move {
                        extension
                            .call_llm_provider_authenticate(store, &provider_id)
                            .await
                    }
                    .boxed()
                })
                .await;

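            // Three nested results: the outer layer is a failure to reach the
            // extension task, the middle layer is an error from the WASM call
            // itself, and the inner layer is an error reported by the
            // extension's `authenticate` implementation.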
            match result {
                Ok(Ok(Ok(()))) => {
                    cx.update(|cx| {
                        state.update(cx, |state, cx| {
                            state.is_authenticated = true;
                            // Notify subscribers that the authentication state
                            // changed so dependent UI can refresh.
                            cx.emit(());
                        });
                    })?;
                    Ok(())
                }
                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
                Ok(Err(e)) | Err(e) => Err(AuthenticateError::Other(e)),
            }
        })
    }

    fn configuration_view(
        &self,
        _target_agent: ConfigurationViewTargetAgent,
        _window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| EmptyConfigView).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let state = self.state.clone();

        cx.spawn(async move |cx| {
            let result = extension
                .call(|extension, store| {
                    async move {
                        extension
                            .call_llm_provider_reset_credentials(store, &provider_id)
                            .await
                    }
                    .boxed()
                })
                .await;

            match result {
                Ok(Ok(Ok(()))) => {
                    cx.update(|cx| {
                        state.update(cx, |state, cx| {
                            state.is_authenticated = false;
                            // Mirror `authenticate`: notify subscribers that the
                            // authentication state changed.
                            cx.emit(());
                        });
                    })?;
                    Ok(())
                }
                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
                Ok(Err(e)) | Err(e) => Err(e),
            }
        })
    }
}

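/// Exposes the provider's state entity to gpui: `observable_entity` hands out
/// the entity itself, and `subscribe` forwards its `()` events to the given
/// callback so callers can react when the state changes.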
impl LanguageModelProviderState for ExtensionLanguageModelProvider {
    type ObservableEntity = ExtensionLlmProviderState;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }

    fn subscribe<T: 'static>(
        &self,
        cx: &mut Context<T>,
        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
    }
}

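/// A configuration view that renders nothing; extension-based providers do not
/// currently contribute their own configuration UI.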
struct EmptyConfigView;

impl gpui::Render for EmptyConfigView {
    fn render(
        &mut self,
        _window: &mut Window,
        _cx: &mut gpui::Context<Self>,
    ) -> impl gpui::IntoElement {
        gpui::Empty
    }
}

/// An extension-based language model.
pub struct ExtensionLanguageModel {
    extension: WasmExtension,
    model_info: LlmModelInfo,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    provider_info: LlmProviderInfo,
}

impl LanguageModel for ExtensionLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model_info.name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    fn telemetry_id(&self) -> String {
        format!("extension:{}", self.model_info.id)
    }

    fn supports_images(&self) -> bool {
        self.model_info.capabilities.supports_images
    }

    fn supports_tools(&self) -> bool {
        self.model_info.capabilities.supports_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
        }
    }

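    /// Both WIT tool-input formats are currently mapped to
    /// `LanguageModelToolSchemaFormat::JsonSchema` on the host side.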
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model_info.capabilities.tool_input_format {
            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model_info.max_token_count
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model_info.max_output_tokens
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let model_id = self.model_info.id.clone();

        async move {
            let wit_request = convert_request_to_wit(&request);

            let result = extension
                .call(|ext, store| {
                    async move {
                        ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
                            .await
                    }
                    .boxed()
                })
                .await?;

            match result {
                Ok(Ok(count)) => Ok(count),
                Ok(Err(e)) => Err(anyhow!("{}", e)),
                Err(e) => Err(e),
            }
        }
        .boxed()
    }

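    /// Streams a completion from the extension.
    ///
    /// The extension-side interface is pull-based: `call_llm_stream_completion_start`
    /// returns an opaque stream ID, `call_llm_stream_completion_next` is polled until
    /// it yields `None`, and `call_llm_stream_completion_close` releases the stream on
    /// both the success and error paths.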
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let model_id = self.model_info.id.clone();

        async move {
            let wit_request = convert_request_to_wit(&request);

            // Start the stream and get a stream ID
            let outer_result = extension
                .call(|ext, store| {
                    async move {
                        ext.call_llm_stream_completion_start(
                            store,
                            &provider_id,
                            &model_id,
                            &wit_request,
                        )
                        .await
                    }
                    .boxed()
                })
                .await
                .map_err(LanguageModelCompletionError::Other)?;

            // Unwrap the WASM call result, then the extension's own result,
            // leaving the stream ID.
            let inner_result =
                outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
            let stream_id =
                inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;

            // Create a stream that polls the extension for events until it
            // reports the end of the stream or an error.
            let stream = futures::stream::unfold(
                (extension, stream_id, false),
                |(ext, stream_id, done)| async move {
                    if done {
                        return None;
                    }

                    let result = ext
                        .call({
                            let stream_id = stream_id.clone();
                            move |ext, store| {
                                async move {
                                    ext.call_llm_stream_completion_next(store, &stream_id).await
                                }
                                .boxed()
                            }
                        })
                        .await;

                    // A successful event is forwarded immediately; every other
                    // outcome ends the stream, so close it in the extension
                    // before finishing.
                    let error = match result {
                        Ok(Ok(Ok(Some(event)))) => {
                            let converted = convert_completion_event(event);
                            return Some((Ok(converted), (ext, stream_id, false)));
                        }
                        // Stream complete
                        Ok(Ok(Ok(None))) => None,
                        // Error reported by the extension
                        Ok(Ok(Err(e))) => {
                            Some(LanguageModelCompletionError::Other(anyhow!("{}", e)))
                        }
                        // WASM call or channel error
                        Ok(Err(e)) | Err(e) => Some(LanguageModelCompletionError::Other(e)),
                    };

                    let _ = ext
                        .call({
                            let stream_id = stream_id.clone();
                            move |ext, store| {
                                async move {
                                    ext.call_llm_stream_completion_close(store, &stream_id)
                                        .await
                                }
                                .boxed()
                            }
                        })
                        .await;

                    error.map(|error| (Err(error), (ext, stream_id, true)))
                },
            );

            Ok(stream.boxed())
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }
}

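/// Converts a host-side [`LanguageModelRequest`] into the WIT representation
/// passed to the extension. Tool input schemas are serialized to JSON strings,
/// and `max_tokens` is left unset.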
fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
    let messages = request
        .messages
        .iter()
        .map(|msg| LlmRequestMessage {
            role: match msg.role {
                language_model::Role::User => LlmMessageRole::User,
                language_model::Role::Assistant => LlmMessageRole::Assistant,
                language_model::Role::System => LlmMessageRole::System,
            },
            content: msg
                .content
                .iter()
                .map(|content| match content {
                    language_model::MessageContent::Text(text) => {
                        LlmMessageContent::Text(text.clone())
                    }
                    language_model::MessageContent::Image(image) => {
                        LlmMessageContent::Image(LlmImageData {
                            source: image.source.to_string(),
                            width: Some(image.size.width.0 as u32),
                            height: Some(image.size.height.0 as u32),
                        })
                    }
                    language_model::MessageContent::ToolUse(tool_use) => {
                        LlmMessageContent::ToolUse(LlmToolUse {
                            id: tool_use.id.to_string(),
                            name: tool_use.name.to_string(),
                            input: tool_use.raw_input.clone(),
                            thought_signature: tool_use.thought_signature.clone(),
                        })
                    }
                    language_model::MessageContent::ToolResult(result) => {
                        LlmMessageContent::ToolResult(LlmToolResult {
                            tool_use_id: result.tool_use_id.to_string(),
                            tool_name: result.tool_name.to_string(),
                            is_error: result.is_error,
                            content: match &result.content {
                                language_model::LanguageModelToolResultContent::Text(t) => {
                                    LlmToolResultContent::Text(t.to_string())
                                }
                                language_model::LanguageModelToolResultContent::Image(img) => {
                                    LlmToolResultContent::Image(LlmImageData {
                                        source: img.source.to_string(),
                                        width: Some(img.size.width.0 as u32),
                                        height: Some(img.size.height.0 as u32),
                                    })
                                }
                            },
                        })
                    }
                    language_model::MessageContent::Thinking { text, signature } => {
                        LlmMessageContent::Thinking(LlmThinkingContent {
                            text: text.clone(),
                            signature: signature.clone(),
                        })
                    }
                    language_model::MessageContent::RedactedThinking(data) => {
                        LlmMessageContent::RedactedThinking(data.clone())
                    }
                })
                .collect(),
            cache: msg.cache,
        })
        .collect();

    let tools = request
        .tools
        .iter()
        .map(|tool| LlmToolDefinition {
            name: tool.name.clone(),
            description: tool.description.clone(),
            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
        LanguageModelToolChoice::Any => LlmToolChoice::Any,
        LanguageModelToolChoice::None => LlmToolChoice::None,
    });

    LlmCompletionRequest {
        messages,
        tools,
        tool_choice,
        stop_sequences: request.stop.clone(),
        temperature: request.temperature,
        thinking_allowed: request.thinking_allowed,
        max_tokens: None,
    }
}

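/// Converts a WIT completion event from the extension into the host-side
/// [`LanguageModelCompletionEvent`]. Tool-use input and reasoning details
/// arrive as JSON strings and fall back to `serde_json::Value::Null` when they
/// fail to parse.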
fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
    match event {
        LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
        LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
        LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
            text: thinking.text,
            signature: thinking.signature,
        },
        LlmCompletionEvent::RedactedThinking(data) => {
            LanguageModelCompletionEvent::RedactedThinking { data }
        }
        LlmCompletionEvent::ToolUse(tool_use) => {
            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                id: LanguageModelToolUseId::from(tool_use.id),
                name: tool_use.name.into(),
                raw_input: tool_use.input.clone(),
                input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
                is_input_complete: true,
                thought_signature: tool_use.thought_signature,
            })
        }
        LlmCompletionEvent::ToolUseJsonParseError(error) => {
            LanguageModelCompletionEvent::ToolUseJsonParseError {
                id: LanguageModelToolUseId::from(error.id),
                tool_name: error.tool_name.into(),
                raw_input: error.raw_input.into(),
                json_parse_error: error.error,
            }
        }
        LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
            LlmStopReason::EndTurn => StopReason::EndTurn,
            LlmStopReason::MaxTokens => StopReason::MaxTokens,
            LlmStopReason::ToolUse => StopReason::ToolUse,
            LlmStopReason::Refusal => StopReason::Refusal,
        }),
        LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
            input_tokens: usage.input_tokens,
            output_tokens: usage.output_tokens,
            cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
            cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
        }),
        LlmCompletionEvent::ReasoningDetails(json) => {
            LanguageModelCompletionEvent::ReasoningDetails(
                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
            )
        }
    }
}