1use crate::wasm_host::WasmExtension;
2
3use crate::wasm_host::wit::{
4 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
5 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
6 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
7 LlmToolUse,
8};
9use anyhow::{Result, anyhow};
10use futures::future::BoxFuture;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
14use language_model::tool_schema::LanguageModelToolSchemaFormat;
15use language_model::{
16 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
17 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
18 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
19 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
20 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
21};
22use std::sync::Arc;
23
/// An extension-based language model provider.
///
/// Bridges an LLM provider implemented by a WASM extension (described by
/// `provider_info`) into the host's `LanguageModelProvider` trait. Mutable
/// state (auth status, model list) lives in a gpui entity so UI can observe it.
pub struct ExtensionLanguageModelProvider {
    /// Handle to the WASM extension that implements this provider.
    pub extension: WasmExtension,
    /// Provider metadata (id, name, …) as reported by the extension.
    pub provider_info: LlmProviderInfo,
    /// Observable authentication + available-model state.
    state: Entity<ExtensionLlmProviderState>,
}
30
/// Observable state shared by an extension LLM provider and its models list.
pub struct ExtensionLlmProviderState {
    // Set to true after the extension reports a successful authentication,
    // and back to false when credentials are reset.
    is_authenticated: bool,
    // Models advertised by the extension; list order is the fallback order
    // used when no model is flagged as default.
    available_models: Vec<LlmModelInfo>,
}

// Emits unit events; observers simply re-read the state on any notification.
impl EventEmitter<()> for ExtensionLlmProviderState {}
37
38impl ExtensionLanguageModelProvider {
39 pub fn new(
40 extension: WasmExtension,
41 provider_info: LlmProviderInfo,
42 models: Vec<LlmModelInfo>,
43 cx: &mut App,
44 ) -> Self {
45 let state = cx.new(|_| ExtensionLlmProviderState {
46 is_authenticated: false,
47 available_models: models,
48 });
49
50 Self {
51 extension,
52 provider_info,
53 state,
54 }
55 }
56
57 fn provider_id_string(&self) -> String {
58 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
59 }
60}
61
62impl LanguageModelProvider for ExtensionLanguageModelProvider {
63 fn id(&self) -> LanguageModelProviderId {
64 LanguageModelProviderId::from(self.provider_id_string())
65 }
66
67 fn name(&self) -> LanguageModelProviderName {
68 LanguageModelProviderName::from(self.provider_info.name.clone())
69 }
70
71 fn icon(&self) -> ui::IconName {
72 ui::IconName::ZedAssistant
73 }
74
75 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
76 let state = self.state.read(cx);
77 state
78 .available_models
79 .iter()
80 .find(|m| m.is_default)
81 .or_else(|| state.available_models.first())
82 .map(|model_info| {
83 Arc::new(ExtensionLanguageModel {
84 extension: self.extension.clone(),
85 model_info: model_info.clone(),
86 provider_id: self.id(),
87 provider_name: self.name(),
88 provider_info: self.provider_info.clone(),
89 }) as Arc<dyn LanguageModel>
90 })
91 }
92
93 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
94 let state = self.state.read(cx);
95 state
96 .available_models
97 .iter()
98 .find(|m| m.is_default_fast)
99 .or_else(|| state.available_models.iter().find(|m| m.is_default))
100 .or_else(|| state.available_models.first())
101 .map(|model_info| {
102 Arc::new(ExtensionLanguageModel {
103 extension: self.extension.clone(),
104 model_info: model_info.clone(),
105 provider_id: self.id(),
106 provider_name: self.name(),
107 provider_info: self.provider_info.clone(),
108 }) as Arc<dyn LanguageModel>
109 })
110 }
111
112 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
113 let state = self.state.read(cx);
114 state
115 .available_models
116 .iter()
117 .map(|model_info| {
118 Arc::new(ExtensionLanguageModel {
119 extension: self.extension.clone(),
120 model_info: model_info.clone(),
121 provider_id: self.id(),
122 provider_name: self.name(),
123 provider_info: self.provider_info.clone(),
124 }) as Arc<dyn LanguageModel>
125 })
126 .collect()
127 }
128
129 fn is_authenticated(&self, cx: &App) -> bool {
130 self.state.read(cx).is_authenticated
131 }
132
133 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
134 let extension = self.extension.clone();
135 let provider_id = self.provider_info.id.clone();
136 let state = self.state.clone();
137
138 cx.spawn(async move |cx| {
139 let result = extension
140 .call(|extension, store| {
141 async move {
142 extension
143 .call_llm_provider_authenticate(store, &provider_id)
144 .await
145 }
146 .boxed()
147 })
148 .await;
149
150 match result {
151 Ok(Ok(Ok(()))) => {
152 cx.update(|cx| {
153 state.update(cx, |state, _| {
154 state.is_authenticated = true;
155 });
156 })?;
157 Ok(())
158 }
159 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
160 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
161 Err(e) => Err(AuthenticateError::Other(e)),
162 }
163 })
164 }
165
166 fn configuration_view(
167 &self,
168 _target_agent: ConfigurationViewTargetAgent,
169 _window: &mut Window,
170 cx: &mut App,
171 ) -> AnyView {
172 cx.new(|_| EmptyConfigView).into()
173 }
174
175 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
176 let extension = self.extension.clone();
177 let provider_id = self.provider_info.id.clone();
178 let state = self.state.clone();
179
180 cx.spawn(async move |cx| {
181 let result = extension
182 .call(|extension, store| {
183 async move {
184 extension
185 .call_llm_provider_reset_credentials(store, &provider_id)
186 .await
187 }
188 .boxed()
189 })
190 .await;
191
192 match result {
193 Ok(Ok(Ok(()))) => {
194 cx.update(|cx| {
195 state.update(cx, |state, _| {
196 state.is_authenticated = false;
197 });
198 })?;
199 Ok(())
200 }
201 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
202 Ok(Err(e)) => Err(e),
203 Err(e) => Err(e),
204 }
205 })
206 }
207}
208
209impl LanguageModelProviderState for ExtensionLanguageModelProvider {
210 type ObservableEntity = ExtensionLlmProviderState;
211
212 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
213 Some(self.state.clone())
214 }
215
216 fn subscribe<T: 'static>(
217 &self,
218 cx: &mut Context<T>,
219 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
220 ) -> Option<gpui::Subscription> {
221 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
222 }
223}
224
/// Placeholder configuration view: extension LLM providers currently have no
/// host-rendered configuration UI.
struct EmptyConfigView;

impl gpui::Render for EmptyConfigView {
    fn render(
        &mut self,
        _window: &mut Window,
        _cx: &mut gpui::Context<Self>,
    ) -> impl gpui::IntoElement {
        // Renders nothing.
        gpui::Empty
    }
}
236
/// An extension-based language model.
///
/// A single model exposed by an extension LLM provider. All operations are
/// forwarded to the owning WASM extension, addressed by `provider_info.id`
/// and `model_info.id`.
pub struct ExtensionLanguageModel {
    // Handle to the WASM extension that serves this model.
    extension: WasmExtension,
    // Model metadata (id, name, capabilities, token limits).
    model_info: LlmModelInfo,
    // Cached provider identity, captured at construction time.
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    provider_info: LlmProviderInfo,
}
245
246impl LanguageModel for ExtensionLanguageModel {
247 fn id(&self) -> LanguageModelId {
248 LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
249 }
250
251 fn name(&self) -> LanguageModelName {
252 LanguageModelName::from(self.model_info.name.clone())
253 }
254
255 fn provider_id(&self) -> LanguageModelProviderId {
256 self.provider_id.clone()
257 }
258
259 fn provider_name(&self) -> LanguageModelProviderName {
260 self.provider_name.clone()
261 }
262
263 fn telemetry_id(&self) -> String {
264 format!("extension:{}", self.model_info.id)
265 }
266
267 fn supports_images(&self) -> bool {
268 self.model_info.capabilities.supports_images
269 }
270
271 fn supports_tools(&self) -> bool {
272 self.model_info.capabilities.supports_tools
273 }
274
275 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
276 match choice {
277 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
278 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
279 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
280 }
281 }
282
283 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
284 match self.model_info.capabilities.tool_input_format {
285 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
286 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
287 }
288 }
289
290 fn max_token_count(&self) -> u64 {
291 self.model_info.max_token_count
292 }
293
294 fn max_output_tokens(&self) -> Option<u64> {
295 self.model_info.max_output_tokens
296 }
297
298 fn count_tokens(
299 &self,
300 request: LanguageModelRequest,
301 _cx: &App,
302 ) -> BoxFuture<'static, Result<u64>> {
303 let extension = self.extension.clone();
304 let provider_id = self.provider_info.id.clone();
305 let model_id = self.model_info.id.clone();
306
307 async move {
308 let wit_request = convert_request_to_wit(&request);
309
310 let result = extension
311 .call(|ext, store| {
312 async move {
313 ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
314 .await
315 }
316 .boxed()
317 })
318 .await?;
319
320 match result {
321 Ok(Ok(count)) => Ok(count),
322 Ok(Err(e)) => Err(anyhow!("{}", e)),
323 Err(e) => Err(e),
324 }
325 }
326 .boxed()
327 }
328
329 fn stream_completion(
330 &self,
331 request: LanguageModelRequest,
332 _cx: &AsyncApp,
333 ) -> BoxFuture<
334 'static,
335 Result<
336 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
337 LanguageModelCompletionError,
338 >,
339 > {
340 let extension = self.extension.clone();
341 let provider_id = self.provider_info.id.clone();
342 let model_id = self.model_info.id.clone();
343
344 async move {
345 let wit_request = convert_request_to_wit(&request);
346
347 // Start the stream and get a stream ID
348 let outer_result = extension
349 .call(|ext, store| {
350 async move {
351 ext.call_llm_stream_completion_start(
352 store,
353 &provider_id,
354 &model_id,
355 &wit_request,
356 )
357 .await
358 }
359 .boxed()
360 })
361 .await
362 .map_err(|e| LanguageModelCompletionError::Other(e))?;
363
364 // Unwrap the inner Result<Result<String, String>>
365 let inner_result =
366 outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
367
368 // Get the stream ID
369 let stream_id =
370 inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
371
372 // Create a stream that polls for events
373 let stream = futures::stream::unfold(
374 (extension, stream_id, false),
375 |(ext, stream_id, done)| async move {
376 if done {
377 return None;
378 }
379
380 let result = ext
381 .call({
382 let stream_id = stream_id.clone();
383 move |ext, store| {
384 async move {
385 ext.call_llm_stream_completion_next(store, &stream_id).await
386 }
387 .boxed()
388 }
389 })
390 .await;
391
392 match result {
393 Ok(Ok(Ok(Some(event)))) => {
394 let converted = convert_completion_event(event);
395 Some((Ok(converted), (ext, stream_id, false)))
396 }
397 Ok(Ok(Ok(None))) => {
398 // Stream complete - close it
399 let _ = ext
400 .call({
401 let stream_id = stream_id.clone();
402 move |ext, store| {
403 async move {
404 ext.call_llm_stream_completion_close(store, &stream_id)
405 .await
406 }
407 .boxed()
408 }
409 })
410 .await;
411 None
412 }
413 Ok(Ok(Err(e))) => {
414 // Extension returned an error - close stream and return error
415 let _ = ext
416 .call({
417 let stream_id = stream_id.clone();
418 move |ext, store| {
419 async move {
420 ext.call_llm_stream_completion_close(store, &stream_id)
421 .await
422 }
423 .boxed()
424 }
425 })
426 .await;
427 Some((
428 Err(LanguageModelCompletionError::Other(anyhow!("{}", e))),
429 (ext, stream_id, true),
430 ))
431 }
432 Ok(Err(e)) => {
433 // WASM call error - close stream and return error
434 let _ = ext
435 .call({
436 let stream_id = stream_id.clone();
437 move |ext, store| {
438 async move {
439 ext.call_llm_stream_completion_close(store, &stream_id)
440 .await
441 }
442 .boxed()
443 }
444 })
445 .await;
446 Some((
447 Err(LanguageModelCompletionError::Other(e)),
448 (ext, stream_id, true),
449 ))
450 }
451 Err(e) => {
452 // Channel error - close stream and return error
453 let _ = ext
454 .call({
455 let stream_id = stream_id.clone();
456 move |ext, store| {
457 async move {
458 ext.call_llm_stream_completion_close(store, &stream_id)
459 .await
460 }
461 .boxed()
462 }
463 })
464 .await;
465 Some((
466 Err(LanguageModelCompletionError::Other(e)),
467 (ext, stream_id, true),
468 ))
469 }
470 }
471 },
472 );
473
474 Ok(stream.boxed())
475 }
476 .boxed()
477 }
478
479 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
480 None
481 }
482}
483
484fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
485 let messages = request
486 .messages
487 .iter()
488 .map(|msg| LlmRequestMessage {
489 role: match msg.role {
490 language_model::Role::User => LlmMessageRole::User,
491 language_model::Role::Assistant => LlmMessageRole::Assistant,
492 language_model::Role::System => LlmMessageRole::System,
493 },
494 content: msg
495 .content
496 .iter()
497 .map(|content| match content {
498 language_model::MessageContent::Text(text) => {
499 LlmMessageContent::Text(text.clone())
500 }
501 language_model::MessageContent::Image(image) => {
502 LlmMessageContent::Image(LlmImageData {
503 source: image.source.to_string(),
504 width: Some(image.size.width.0 as u32),
505 height: Some(image.size.height.0 as u32),
506 })
507 }
508 language_model::MessageContent::ToolUse(tool_use) => {
509 LlmMessageContent::ToolUse(LlmToolUse {
510 id: tool_use.id.to_string(),
511 name: tool_use.name.to_string(),
512 input: tool_use.raw_input.clone(),
513 thought_signature: tool_use.thought_signature.clone(),
514 })
515 }
516 language_model::MessageContent::ToolResult(result) => {
517 LlmMessageContent::ToolResult(LlmToolResult {
518 tool_use_id: result.tool_use_id.to_string(),
519 tool_name: result.tool_name.to_string(),
520 is_error: result.is_error,
521 content: match &result.content {
522 language_model::LanguageModelToolResultContent::Text(t) => {
523 LlmToolResultContent::Text(t.to_string())
524 }
525 language_model::LanguageModelToolResultContent::Image(img) => {
526 LlmToolResultContent::Image(LlmImageData {
527 source: img.source.to_string(),
528 width: Some(img.size.width.0 as u32),
529 height: Some(img.size.height.0 as u32),
530 })
531 }
532 },
533 })
534 }
535 language_model::MessageContent::Thinking { text, signature } => {
536 LlmMessageContent::Thinking(LlmThinkingContent {
537 text: text.clone(),
538 signature: signature.clone(),
539 })
540 }
541 language_model::MessageContent::RedactedThinking(data) => {
542 LlmMessageContent::RedactedThinking(data.clone())
543 }
544 })
545 .collect(),
546 cache: msg.cache,
547 })
548 .collect();
549
550 let tools = request
551 .tools
552 .iter()
553 .map(|tool| LlmToolDefinition {
554 name: tool.name.clone(),
555 description: tool.description.clone(),
556 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
557 })
558 .collect();
559
560 let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
561 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
562 LanguageModelToolChoice::Any => LlmToolChoice::Any,
563 LanguageModelToolChoice::None => LlmToolChoice::None,
564 });
565
566 LlmCompletionRequest {
567 messages,
568 tools,
569 tool_choice,
570 stop_sequences: request.stop.clone(),
571 temperature: request.temperature,
572 thinking_allowed: request.thinking_allowed,
573 max_tokens: None,
574 }
575}
576
577fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
578 match event {
579 LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
580 LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
581 LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
582 text: thinking.text,
583 signature: thinking.signature,
584 },
585 LlmCompletionEvent::RedactedThinking(data) => {
586 LanguageModelCompletionEvent::RedactedThinking { data }
587 }
588 LlmCompletionEvent::ToolUse(tool_use) => {
589 LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
590 id: LanguageModelToolUseId::from(tool_use.id),
591 name: tool_use.name.into(),
592 raw_input: tool_use.input.clone(),
593 input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
594 is_input_complete: true,
595 thought_signature: tool_use.thought_signature,
596 })
597 }
598 LlmCompletionEvent::ToolUseJsonParseError(error) => {
599 LanguageModelCompletionEvent::ToolUseJsonParseError {
600 id: LanguageModelToolUseId::from(error.id),
601 tool_name: error.tool_name.into(),
602 raw_input: error.raw_input.into(),
603 json_parse_error: error.error,
604 }
605 }
606 LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
607 LlmStopReason::EndTurn => StopReason::EndTurn,
608 LlmStopReason::MaxTokens => StopReason::MaxTokens,
609 LlmStopReason::ToolUse => StopReason::ToolUse,
610 LlmStopReason::Refusal => StopReason::Refusal,
611 }),
612 LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
613 input_tokens: usage.input_tokens,
614 output_tokens: usage.output_tokens,
615 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
616 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
617 }),
618 LlmCompletionEvent::ReasoningDetails(json) => {
619 LanguageModelCompletionEvent::ReasoningDetails(
620 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
621 )
622 }
623 }
624}