// llm_provider.rs

  1use crate::wasm_host::WasmExtension;
  2
  3use crate::wasm_host::wit::{
  4    LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
  5    LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
  6    LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
  7    LlmToolUse,
  8};
  9use anyhow::{Result, anyhow};
 10use futures::future::BoxFuture;
 11use futures::stream::BoxStream;
 12use futures::{FutureExt, StreamExt};
 13use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
 14use language_model::tool_schema::LanguageModelToolSchemaFormat;
 15use language_model::{
 16    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
 17    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 18    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 19    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 20    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
 21};
 22use std::sync::Arc;
 23
/// An extension-based language model provider.
///
/// Bridges an LLM provider implemented by a WASM extension into the host's
/// `LanguageModelProvider` trait.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the loaded WASM extension that services this provider.
    pub extension: WasmExtension,
    /// Provider metadata (id, name, ...) reported by the extension.
    pub provider_info: LlmProviderInfo,
    /// Observable authentication / model-list state (see
    /// `ExtensionLlmProviderState`).
    state: Entity<ExtensionLlmProviderState>,
}
 30
/// Mutable provider state held in a GPUI entity so that interested parties
/// can observe it.
pub struct ExtensionLlmProviderState {
    // Set to `true` after a successful `authenticate` call and back to
    // `false` by `reset_credentials`.
    is_authenticated: bool,
    // Models supplied when the provider was constructed.
    available_models: Vec<LlmModelInfo>,
}
 35
// Unit-event emitter so subscribers (via `LanguageModelProviderState::subscribe`)
// can be notified when the state entity changes.
impl EventEmitter<()> for ExtensionLlmProviderState {}
 37
impl ExtensionLanguageModelProvider {
    /// Creates a provider backed by `extension`, seeding the observable state
    /// with the model list and authentication status reported by the
    /// extension at construction time.
    pub fn new(
        extension: WasmExtension,
        provider_info: LlmProviderInfo,
        models: Vec<LlmModelInfo>,
        is_authenticated: bool,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|_| ExtensionLlmProviderState {
            is_authenticated,
            available_models: models,
        });

        Self {
            extension,
            provider_info,
            state,
        }
    }

    /// Id string of the form `<extension-id>:<provider-id>`, namespacing the
    /// provider id under its owning extension.
    fn provider_id_string(&self) -> String {
        format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
    }
}
 62
 63impl LanguageModelProvider for ExtensionLanguageModelProvider {
 64    fn id(&self) -> LanguageModelProviderId {
 65        let id = LanguageModelProviderId::from(self.provider_id_string());
 66        eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id);
 67        id
 68    }
 69
 70    fn name(&self) -> LanguageModelProviderName {
 71        LanguageModelProviderName::from(self.provider_info.name.clone())
 72    }
 73
 74    fn icon(&self) -> ui::IconName {
 75        ui::IconName::ZedAssistant
 76    }
 77
 78    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 79        let state = self.state.read(cx);
 80        state
 81            .available_models
 82            .iter()
 83            .find(|m| m.is_default)
 84            .or_else(|| state.available_models.first())
 85            .map(|model_info| {
 86                Arc::new(ExtensionLanguageModel {
 87                    extension: self.extension.clone(),
 88                    model_info: model_info.clone(),
 89                    provider_id: self.id(),
 90                    provider_name: self.name(),
 91                    provider_info: self.provider_info.clone(),
 92                }) as Arc<dyn LanguageModel>
 93            })
 94    }
 95
 96    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 97        let state = self.state.read(cx);
 98        state
 99            .available_models
100            .iter()
101            .find(|m| m.is_default_fast)
102            .or_else(|| state.available_models.iter().find(|m| m.is_default))
103            .or_else(|| state.available_models.first())
104            .map(|model_info| {
105                Arc::new(ExtensionLanguageModel {
106                    extension: self.extension.clone(),
107                    model_info: model_info.clone(),
108                    provider_id: self.id(),
109                    provider_name: self.name(),
110                    provider_info: self.provider_info.clone(),
111                }) as Arc<dyn LanguageModel>
112            })
113    }
114
115    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
116        let state = self.state.read(cx);
117        eprintln!(
118            "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models",
119            self.provider_info.name,
120            state.available_models.len()
121        );
122        state
123            .available_models
124            .iter()
125            .map(|model_info| {
126                eprintln!("  - model: {}", model_info.name);
127                Arc::new(ExtensionLanguageModel {
128                    extension: self.extension.clone(),
129                    model_info: model_info.clone(),
130                    provider_id: self.id(),
131                    provider_name: self.name(),
132                    provider_info: self.provider_info.clone(),
133                }) as Arc<dyn LanguageModel>
134            })
135            .collect()
136    }
137
138    fn is_authenticated(&self, cx: &App) -> bool {
139        self.state.read(cx).is_authenticated
140    }
141
142    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
143        let extension = self.extension.clone();
144        let provider_id = self.provider_info.id.clone();
145        let state = self.state.clone();
146
147        cx.spawn(async move |cx| {
148            let result = extension
149                .call(|extension, store| {
150                    async move {
151                        extension
152                            .call_llm_provider_authenticate(store, &provider_id)
153                            .await
154                    }
155                    .boxed()
156                })
157                .await;
158
159            match result {
160                Ok(Ok(Ok(()))) => {
161                    cx.update(|cx| {
162                        state.update(cx, |state, _| {
163                            state.is_authenticated = true;
164                        });
165                    })?;
166                    Ok(())
167                }
168                Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
169                Ok(Err(e)) => Err(AuthenticateError::Other(e)),
170                Err(e) => Err(AuthenticateError::Other(e)),
171            }
172        })
173    }
174
175    fn configuration_view(
176        &self,
177        _target_agent: ConfigurationViewTargetAgent,
178        _window: &mut Window,
179        cx: &mut App,
180    ) -> AnyView {
181        cx.new(|_| EmptyConfigView).into()
182    }
183
184    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
185        let extension = self.extension.clone();
186        let provider_id = self.provider_info.id.clone();
187        let state = self.state.clone();
188
189        cx.spawn(async move |cx| {
190            let result = extension
191                .call(|extension, store| {
192                    async move {
193                        extension
194                            .call_llm_provider_reset_credentials(store, &provider_id)
195                            .await
196                    }
197                    .boxed()
198                })
199                .await;
200
201            match result {
202                Ok(Ok(Ok(()))) => {
203                    cx.update(|cx| {
204                        state.update(cx, |state, _| {
205                            state.is_authenticated = false;
206                        });
207                    })?;
208                    Ok(())
209                }
210                Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
211                Ok(Err(e)) => Err(e),
212                Err(e) => Err(e),
213            }
214        })
215    }
216}
217
impl LanguageModelProviderState for ExtensionLanguageModelProvider {
    type ObservableEntity = ExtensionLlmProviderState;

    /// Exposes the state entity so callers can observe provider changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }

    /// Invokes `callback` whenever the state entity emits an event.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut Context<T>,
        callback: impl Fn(&mut T, &mut Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
    }
}
233
/// Placeholder configuration view used by `configuration_view`; extension
/// providers currently render nothing.
struct EmptyConfigView;

impl gpui::Render for EmptyConfigView {
    fn render(
        &mut self,
        _window: &mut Window,
        _cx: &mut gpui::Context<Self>,
    ) -> impl gpui::IntoElement {
        // Intentionally empty.
        gpui::Empty
    }
}
245
/// An extension-based language model.
///
/// One selectable model exposed by an extension LLM provider; completion and
/// token-counting calls are forwarded to the owning WASM extension.
pub struct ExtensionLanguageModel {
    // Extension that services requests for this model.
    extension: WasmExtension,
    // Model metadata (id, name, capabilities, token limits) as reported by
    // the extension.
    model_info: LlmModelInfo,
    // Cached provider id/name so trait queries don't need the provider.
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    // Provider metadata; `provider_info.id` identifies the provider when
    // calling back into the extension.
    provider_info: LlmProviderInfo,
}
254
impl LanguageModel for ExtensionLanguageModel {
    /// Unique id combining the (extension-scoped) provider id and model id.
    fn id(&self) -> LanguageModelId {
        LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model_info.name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    /// Telemetry id uses only the model id (no provider component) behind an
    /// `extension:` prefix.
    fn telemetry_id(&self) -> String {
        format!("extension:{}", self.model_info.id)
    }

    fn supports_images(&self) -> bool {
        self.model_info.capabilities.supports_images
    }

    fn supports_tools(&self) -> bool {
        self.model_info.capabilities.supports_tools
    }

    /// Tool-choice support is reported per-mode by the extension's
    /// capability flags.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
            LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
            LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model_info.capabilities.tool_input_format {
            LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
            // NOTE(review): `Simplified` currently falls back to `JsonSchema`
            // as well — confirm whether a distinct host-side format is
            // intended here.
            LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model_info.max_token_count
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model_info.max_output_tokens
    }

    /// Counts tokens by converting the request to its WIT form and forwarding
    /// it to the extension's `llm_count_tokens` export.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let model_id = self.model_info.id.clone();

        async move {
            let wit_request = convert_request_to_wit(&request);

            let result = extension
                .call(|ext, store| {
                    async move {
                        ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
                            .await
                    }
                    .boxed()
                })
                .await?;

            // Unwrap the WASM call result, then the extension's own
            // string-error result.
            match result {
                Ok(Ok(count)) => Ok(count),
                Ok(Err(e)) => Err(anyhow!("{}", e)),
                Err(e) => Err(e),
            }
        }
        .boxed()
    }

    /// Streams a completion from the extension.
    ///
    /// Protocol: `stream_completion_start` returns a stream id; the returned
    /// stream then repeatedly calls `stream_completion_next` (converting each
    /// WIT event to a host event) until `None` or an error, and closes the
    /// stream via `stream_completion_close` on every terminal path. The
    /// unfold state is `(extension, stream_id, done)`, where `done` makes the
    /// stream end after the item carrying an error has been yielded.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let extension = self.extension.clone();
        let provider_id = self.provider_info.id.clone();
        let model_id = self.model_info.id.clone();

        async move {
            let wit_request = convert_request_to_wit(&request);

            // Start the stream and get a stream ID
            let outer_result = extension
                .call(|ext, store| {
                    async move {
                        ext.call_llm_stream_completion_start(
                            store,
                            &provider_id,
                            &model_id,
                            &wit_request,
                        )
                        .await
                    }
                    .boxed()
                })
                .await
                .map_err(|e| LanguageModelCompletionError::Other(e))?;

            // Unwrap the inner Result<Result<String, String>>
            let inner_result =
                outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;

            // Get the stream ID
            let stream_id =
                inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;

            // Create a stream that polls for events
            let stream = futures::stream::unfold(
                (extension, stream_id, false),
                |(ext, stream_id, done)| async move {
                    if done {
                        return None;
                    }

                    // Poll the extension for the next event on this stream.
                    let result = ext
                        .call({
                            let stream_id = stream_id.clone();
                            move |ext, store| {
                                async move {
                                    ext.call_llm_stream_completion_next(store, &stream_id).await
                                }
                                .boxed()
                            }
                        })
                        .await;

                    match result {
                        Ok(Ok(Ok(Some(event)))) => {
                            let converted = convert_completion_event(event);
                            Some((Ok(converted), (ext, stream_id, false)))
                        }
                        Ok(Ok(Ok(None))) => {
                            // Stream complete - close it
                            let _ = ext
                                .call({
                                    let stream_id = stream_id.clone();
                                    move |ext, store| {
                                        async move {
                                            ext.call_llm_stream_completion_close(store, &stream_id)
                                                .await
                                        }
                                        .boxed()
                                    }
                                })
                                .await;
                            None
                        }
                        Ok(Ok(Err(e))) => {
                            // Extension returned an error - close stream and return error
                            let _ = ext
                                .call({
                                    let stream_id = stream_id.clone();
                                    move |ext, store| {
                                        async move {
                                            ext.call_llm_stream_completion_close(store, &stream_id)
                                                .await
                                        }
                                        .boxed()
                                    }
                                })
                                .await;
                            Some((
                                Err(LanguageModelCompletionError::Other(anyhow!("{}", e))),
                                (ext, stream_id, true),
                            ))
                        }
                        Ok(Err(e)) => {
                            // WASM call error - close stream and return error
                            let _ = ext
                                .call({
                                    let stream_id = stream_id.clone();
                                    move |ext, store| {
                                        async move {
                                            ext.call_llm_stream_completion_close(store, &stream_id)
                                                .await
                                        }
                                        .boxed()
                                    }
                                })
                                .await;
                            Some((
                                Err(LanguageModelCompletionError::Other(e)),
                                (ext, stream_id, true),
                            ))
                        }
                        Err(e) => {
                            // Channel error - close stream and return error
                            let _ = ext
                                .call({
                                    let stream_id = stream_id.clone();
                                    move |ext, store| {
                                        async move {
                                            ext.call_llm_stream_completion_close(store, &stream_id)
                                                .await
                                        }
                                        .boxed()
                                    }
                                })
                                .await;
                            Some((
                                Err(LanguageModelCompletionError::Other(e)),
                                (ext, stream_id, true),
                            ))
                        }
                    }
                },
            );

            Ok(stream.boxed())
        }
        .boxed()
    }

    /// No prompt-cache configuration is exposed for extension models.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }
}
492
493fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
494    let messages = request
495        .messages
496        .iter()
497        .map(|msg| LlmRequestMessage {
498            role: match msg.role {
499                language_model::Role::User => LlmMessageRole::User,
500                language_model::Role::Assistant => LlmMessageRole::Assistant,
501                language_model::Role::System => LlmMessageRole::System,
502            },
503            content: msg
504                .content
505                .iter()
506                .map(|content| match content {
507                    language_model::MessageContent::Text(text) => {
508                        LlmMessageContent::Text(text.clone())
509                    }
510                    language_model::MessageContent::Image(image) => {
511                        LlmMessageContent::Image(LlmImageData {
512                            source: image.source.to_string(),
513                            width: Some(image.size.width.0 as u32),
514                            height: Some(image.size.height.0 as u32),
515                        })
516                    }
517                    language_model::MessageContent::ToolUse(tool_use) => {
518                        LlmMessageContent::ToolUse(LlmToolUse {
519                            id: tool_use.id.to_string(),
520                            name: tool_use.name.to_string(),
521                            input: tool_use.raw_input.clone(),
522                            thought_signature: tool_use.thought_signature.clone(),
523                        })
524                    }
525                    language_model::MessageContent::ToolResult(result) => {
526                        LlmMessageContent::ToolResult(LlmToolResult {
527                            tool_use_id: result.tool_use_id.to_string(),
528                            tool_name: result.tool_name.to_string(),
529                            is_error: result.is_error,
530                            content: match &result.content {
531                                language_model::LanguageModelToolResultContent::Text(t) => {
532                                    LlmToolResultContent::Text(t.to_string())
533                                }
534                                language_model::LanguageModelToolResultContent::Image(img) => {
535                                    LlmToolResultContent::Image(LlmImageData {
536                                        source: img.source.to_string(),
537                                        width: Some(img.size.width.0 as u32),
538                                        height: Some(img.size.height.0 as u32),
539                                    })
540                                }
541                            },
542                        })
543                    }
544                    language_model::MessageContent::Thinking { text, signature } => {
545                        LlmMessageContent::Thinking(LlmThinkingContent {
546                            text: text.clone(),
547                            signature: signature.clone(),
548                        })
549                    }
550                    language_model::MessageContent::RedactedThinking(data) => {
551                        LlmMessageContent::RedactedThinking(data.clone())
552                    }
553                })
554                .collect(),
555            cache: msg.cache,
556        })
557        .collect();
558
559    let tools = request
560        .tools
561        .iter()
562        .map(|tool| LlmToolDefinition {
563            name: tool.name.clone(),
564            description: tool.description.clone(),
565            input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
566        })
567        .collect();
568
569    let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
570        LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
571        LanguageModelToolChoice::Any => LlmToolChoice::Any,
572        LanguageModelToolChoice::None => LlmToolChoice::None,
573    });
574
575    LlmCompletionRequest {
576        messages,
577        tools,
578        tool_choice,
579        stop_sequences: request.stop.clone(),
580        temperature: request.temperature,
581        thinking_allowed: request.thinking_allowed,
582        max_tokens: None,
583    }
584}
585
586fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
587    match event {
588        LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
589        LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
590        LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
591            text: thinking.text,
592            signature: thinking.signature,
593        },
594        LlmCompletionEvent::RedactedThinking(data) => {
595            LanguageModelCompletionEvent::RedactedThinking { data }
596        }
597        LlmCompletionEvent::ToolUse(tool_use) => {
598            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
599                id: LanguageModelToolUseId::from(tool_use.id),
600                name: tool_use.name.into(),
601                raw_input: tool_use.input.clone(),
602                input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
603                is_input_complete: true,
604                thought_signature: tool_use.thought_signature,
605            })
606        }
607        LlmCompletionEvent::ToolUseJsonParseError(error) => {
608            LanguageModelCompletionEvent::ToolUseJsonParseError {
609                id: LanguageModelToolUseId::from(error.id),
610                tool_name: error.tool_name.into(),
611                raw_input: error.raw_input.into(),
612                json_parse_error: error.error,
613            }
614        }
615        LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
616            LlmStopReason::EndTurn => StopReason::EndTurn,
617            LlmStopReason::MaxTokens => StopReason::MaxTokens,
618            LlmStopReason::ToolUse => StopReason::ToolUse,
619            LlmStopReason::Refusal => StopReason::Refusal,
620        }),
621        LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
622            input_tokens: usage.input_tokens,
623            output_tokens: usage.output_tokens,
624            cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
625            cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
626        }),
627        LlmCompletionEvent::ReasoningDetails(json) => {
628            LanguageModelCompletionEvent::ReasoningDetails(
629                serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
630            )
631        }
632    }
633}