1use crate::wasm_host::WasmExtension;
2
3use crate::wasm_host::wit::{
4 LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
5 LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
6 LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
7 LlmToolUse,
8};
9use anyhow::{Result, anyhow};
10use futures::future::BoxFuture;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use gpui::{AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, Window};
14use language_model::tool_schema::LanguageModelToolSchemaFormat;
15use language_model::{
16 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
17 LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
18 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
19 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
20 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage,
21};
22use std::sync::Arc;
23
/// An extension-based language model provider.
///
/// Bridges one LLM provider declared by a WASM extension into the host's
/// `LanguageModelProvider` trait. Each instance is scoped to a single
/// (extension, provider) pair.
pub struct ExtensionLanguageModelProvider {
    // Handle used to invoke the WASM extension that implements this provider.
    pub extension: WasmExtension,
    // Provider metadata (id, name, ...) as reported by the extension.
    pub provider_info: LlmProviderInfo,
    // Shared, observable auth/model state; see `ExtensionLlmProviderState`.
    state: Entity<ExtensionLlmProviderState>,
}
30
/// Observable state for an extension-backed LLM provider: whether the user
/// is authenticated and which models the extension currently advertises.
pub struct ExtensionLlmProviderState {
    // Flipped by `authenticate` (true) and `reset_credentials` (false).
    is_authenticated: bool,
    // Models reported by the extension when the provider was registered.
    available_models: Vec<LlmModelInfo>,
}

// Emits the unit event so observers can react to any state change.
impl EventEmitter<()> for ExtensionLlmProviderState {}
37
38impl ExtensionLanguageModelProvider {
39 pub fn new(
40 extension: WasmExtension,
41 provider_info: LlmProviderInfo,
42 models: Vec<LlmModelInfo>,
43 is_authenticated: bool,
44 cx: &mut App,
45 ) -> Self {
46 let state = cx.new(|_| ExtensionLlmProviderState {
47 is_authenticated,
48 available_models: models,
49 });
50
51 Self {
52 extension,
53 provider_info,
54 state,
55 }
56 }
57
58 fn provider_id_string(&self) -> String {
59 format!("{}:{}", self.extension.manifest.id, self.provider_info.id)
60 }
61}
62
63impl LanguageModelProvider for ExtensionLanguageModelProvider {
64 fn id(&self) -> LanguageModelProviderId {
65 let id = LanguageModelProviderId::from(self.provider_id_string());
66 eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id);
67 id
68 }
69
70 fn name(&self) -> LanguageModelProviderName {
71 LanguageModelProviderName::from(self.provider_info.name.clone())
72 }
73
74 fn icon(&self) -> ui::IconName {
75 ui::IconName::ZedAssistant
76 }
77
78 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
79 let state = self.state.read(cx);
80 state
81 .available_models
82 .iter()
83 .find(|m| m.is_default)
84 .or_else(|| state.available_models.first())
85 .map(|model_info| {
86 Arc::new(ExtensionLanguageModel {
87 extension: self.extension.clone(),
88 model_info: model_info.clone(),
89 provider_id: self.id(),
90 provider_name: self.name(),
91 provider_info: self.provider_info.clone(),
92 }) as Arc<dyn LanguageModel>
93 })
94 }
95
96 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
97 let state = self.state.read(cx);
98 state
99 .available_models
100 .iter()
101 .find(|m| m.is_default_fast)
102 .or_else(|| state.available_models.iter().find(|m| m.is_default))
103 .or_else(|| state.available_models.first())
104 .map(|model_info| {
105 Arc::new(ExtensionLanguageModel {
106 extension: self.extension.clone(),
107 model_info: model_info.clone(),
108 provider_id: self.id(),
109 provider_name: self.name(),
110 provider_info: self.provider_info.clone(),
111 }) as Arc<dyn LanguageModel>
112 })
113 }
114
115 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
116 let state = self.state.read(cx);
117 eprintln!(
118 "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models",
119 self.provider_info.name,
120 state.available_models.len()
121 );
122 state
123 .available_models
124 .iter()
125 .map(|model_info| {
126 eprintln!(" - model: {}", model_info.name);
127 Arc::new(ExtensionLanguageModel {
128 extension: self.extension.clone(),
129 model_info: model_info.clone(),
130 provider_id: self.id(),
131 provider_name: self.name(),
132 provider_info: self.provider_info.clone(),
133 }) as Arc<dyn LanguageModel>
134 })
135 .collect()
136 }
137
138 fn is_authenticated(&self, cx: &App) -> bool {
139 self.state.read(cx).is_authenticated
140 }
141
142 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
143 let extension = self.extension.clone();
144 let provider_id = self.provider_info.id.clone();
145 let state = self.state.clone();
146
147 cx.spawn(async move |cx| {
148 let result = extension
149 .call(|extension, store| {
150 async move {
151 extension
152 .call_llm_provider_authenticate(store, &provider_id)
153 .await
154 }
155 .boxed()
156 })
157 .await;
158
159 match result {
160 Ok(Ok(Ok(()))) => {
161 cx.update(|cx| {
162 state.update(cx, |state, _| {
163 state.is_authenticated = true;
164 });
165 })?;
166 Ok(())
167 }
168 Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
169 Ok(Err(e)) => Err(AuthenticateError::Other(e)),
170 Err(e) => Err(AuthenticateError::Other(e)),
171 }
172 })
173 }
174
175 fn configuration_view(
176 &self,
177 _target_agent: ConfigurationViewTargetAgent,
178 _window: &mut Window,
179 cx: &mut App,
180 ) -> AnyView {
181 cx.new(|_| EmptyConfigView).into()
182 }
183
184 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
185 let extension = self.extension.clone();
186 let provider_id = self.provider_info.id.clone();
187 let state = self.state.clone();
188
189 cx.spawn(async move |cx| {
190 let result = extension
191 .call(|extension, store| {
192 async move {
193 extension
194 .call_llm_provider_reset_credentials(store, &provider_id)
195 .await
196 }
197 .boxed()
198 })
199 .await;
200
201 match result {
202 Ok(Ok(Ok(()))) => {
203 cx.update(|cx| {
204 state.update(cx, |state, _| {
205 state.is_authenticated = false;
206 });
207 })?;
208 Ok(())
209 }
210 Ok(Ok(Err(e))) => Err(anyhow!("{}", e)),
211 Ok(Err(e)) => Err(e),
212 Err(e) => Err(e),
213 }
214 })
215 }
216}
217
218impl LanguageModelProviderState for ExtensionLanguageModelProvider {
219 type ObservableEntity = ExtensionLlmProviderState;
220
221 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
222 Some(self.state.clone())
223 }
224
225 fn subscribe<T: 'static>(
226 &self,
227 cx: &mut Context<T>,
228 callback: impl Fn(&mut T, &mut Context<T>) + 'static,
229 ) -> Option<gpui::Subscription> {
230 Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx)))
231 }
232}
233
// Placeholder configuration view: extension providers do not currently
// supply a settings UI, so this renders nothing.
struct EmptyConfigView;

impl gpui::Render for EmptyConfigView {
    fn render(
        &mut self,
        _window: &mut Window,
        _cx: &mut gpui::Context<Self>,
    ) -> impl gpui::IntoElement {
        // Intentionally empty.
        gpui::Empty
    }
}
245
/// An extension-based language model.
///
/// A single model advertised by an extension's LLM provider. Token counting
/// and completion streaming are forwarded to the owning WASM extension.
pub struct ExtensionLanguageModel {
    // Handle used to invoke the owning WASM extension.
    extension: WasmExtension,
    // Model metadata (id, name, capabilities, limits) from the extension.
    model_info: LlmModelInfo,
    // Host-side provider id, already namespaced as `<extension>:<provider>`.
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    // Raw provider info; `provider_info.id` is the extension-local id used
    // when calling back into the extension.
    provider_info: LlmProviderInfo,
}
254
255impl LanguageModel for ExtensionLanguageModel {
256 fn id(&self) -> LanguageModelId {
257 LanguageModelId::from(format!("{}:{}", self.provider_id.0, self.model_info.id))
258 }
259
260 fn name(&self) -> LanguageModelName {
261 LanguageModelName::from(self.model_info.name.clone())
262 }
263
264 fn provider_id(&self) -> LanguageModelProviderId {
265 self.provider_id.clone()
266 }
267
268 fn provider_name(&self) -> LanguageModelProviderName {
269 self.provider_name.clone()
270 }
271
272 fn telemetry_id(&self) -> String {
273 format!("extension:{}", self.model_info.id)
274 }
275
276 fn supports_images(&self) -> bool {
277 self.model_info.capabilities.supports_images
278 }
279
280 fn supports_tools(&self) -> bool {
281 self.model_info.capabilities.supports_tools
282 }
283
284 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
285 match choice {
286 LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto,
287 LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any,
288 LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none,
289 }
290 }
291
292 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
293 match self.model_info.capabilities.tool_input_format {
294 LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema,
295 LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema,
296 }
297 }
298
299 fn max_token_count(&self) -> u64 {
300 self.model_info.max_token_count
301 }
302
303 fn max_output_tokens(&self) -> Option<u64> {
304 self.model_info.max_output_tokens
305 }
306
307 fn count_tokens(
308 &self,
309 request: LanguageModelRequest,
310 _cx: &App,
311 ) -> BoxFuture<'static, Result<u64>> {
312 let extension = self.extension.clone();
313 let provider_id = self.provider_info.id.clone();
314 let model_id = self.model_info.id.clone();
315
316 async move {
317 let wit_request = convert_request_to_wit(&request);
318
319 let result = extension
320 .call(|ext, store| {
321 async move {
322 ext.call_llm_count_tokens(store, &provider_id, &model_id, &wit_request)
323 .await
324 }
325 .boxed()
326 })
327 .await?;
328
329 match result {
330 Ok(Ok(count)) => Ok(count),
331 Ok(Err(e)) => Err(anyhow!("{}", e)),
332 Err(e) => Err(e),
333 }
334 }
335 .boxed()
336 }
337
338 fn stream_completion(
339 &self,
340 request: LanguageModelRequest,
341 _cx: &AsyncApp,
342 ) -> BoxFuture<
343 'static,
344 Result<
345 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
346 LanguageModelCompletionError,
347 >,
348 > {
349 let extension = self.extension.clone();
350 let provider_id = self.provider_info.id.clone();
351 let model_id = self.model_info.id.clone();
352
353 async move {
354 let wit_request = convert_request_to_wit(&request);
355
356 // Start the stream and get a stream ID
357 let outer_result = extension
358 .call(|ext, store| {
359 async move {
360 ext.call_llm_stream_completion_start(
361 store,
362 &provider_id,
363 &model_id,
364 &wit_request,
365 )
366 .await
367 }
368 .boxed()
369 })
370 .await
371 .map_err(|e| LanguageModelCompletionError::Other(e))?;
372
373 // Unwrap the inner Result<Result<String, String>>
374 let inner_result =
375 outer_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
376
377 // Get the stream ID
378 let stream_id =
379 inner_result.map_err(|e| LanguageModelCompletionError::Other(anyhow!("{}", e)))?;
380
381 // Create a stream that polls for events
382 let stream = futures::stream::unfold(
383 (extension, stream_id, false),
384 |(ext, stream_id, done)| async move {
385 if done {
386 return None;
387 }
388
389 let result = ext
390 .call({
391 let stream_id = stream_id.clone();
392 move |ext, store| {
393 async move {
394 ext.call_llm_stream_completion_next(store, &stream_id).await
395 }
396 .boxed()
397 }
398 })
399 .await;
400
401 match result {
402 Ok(Ok(Ok(Some(event)))) => {
403 let converted = convert_completion_event(event);
404 Some((Ok(converted), (ext, stream_id, false)))
405 }
406 Ok(Ok(Ok(None))) => {
407 // Stream complete - close it
408 let _ = ext
409 .call({
410 let stream_id = stream_id.clone();
411 move |ext, store| {
412 async move {
413 ext.call_llm_stream_completion_close(store, &stream_id)
414 .await
415 }
416 .boxed()
417 }
418 })
419 .await;
420 None
421 }
422 Ok(Ok(Err(e))) => {
423 // Extension returned an error - close stream and return error
424 let _ = ext
425 .call({
426 let stream_id = stream_id.clone();
427 move |ext, store| {
428 async move {
429 ext.call_llm_stream_completion_close(store, &stream_id)
430 .await
431 }
432 .boxed()
433 }
434 })
435 .await;
436 Some((
437 Err(LanguageModelCompletionError::Other(anyhow!("{}", e))),
438 (ext, stream_id, true),
439 ))
440 }
441 Ok(Err(e)) => {
442 // WASM call error - close stream and return error
443 let _ = ext
444 .call({
445 let stream_id = stream_id.clone();
446 move |ext, store| {
447 async move {
448 ext.call_llm_stream_completion_close(store, &stream_id)
449 .await
450 }
451 .boxed()
452 }
453 })
454 .await;
455 Some((
456 Err(LanguageModelCompletionError::Other(e)),
457 (ext, stream_id, true),
458 ))
459 }
460 Err(e) => {
461 // Channel error - close stream and return error
462 let _ = ext
463 .call({
464 let stream_id = stream_id.clone();
465 move |ext, store| {
466 async move {
467 ext.call_llm_stream_completion_close(store, &stream_id)
468 .await
469 }
470 .boxed()
471 }
472 })
473 .await;
474 Some((
475 Err(LanguageModelCompletionError::Other(e)),
476 (ext, stream_id, true),
477 ))
478 }
479 }
480 },
481 );
482
483 Ok(stream.boxed())
484 }
485 .boxed()
486 }
487
488 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
489 None
490 }
491}
492
493fn convert_request_to_wit(request: &LanguageModelRequest) -> LlmCompletionRequest {
494 let messages = request
495 .messages
496 .iter()
497 .map(|msg| LlmRequestMessage {
498 role: match msg.role {
499 language_model::Role::User => LlmMessageRole::User,
500 language_model::Role::Assistant => LlmMessageRole::Assistant,
501 language_model::Role::System => LlmMessageRole::System,
502 },
503 content: msg
504 .content
505 .iter()
506 .map(|content| match content {
507 language_model::MessageContent::Text(text) => {
508 LlmMessageContent::Text(text.clone())
509 }
510 language_model::MessageContent::Image(image) => {
511 LlmMessageContent::Image(LlmImageData {
512 source: image.source.to_string(),
513 width: Some(image.size.width.0 as u32),
514 height: Some(image.size.height.0 as u32),
515 })
516 }
517 language_model::MessageContent::ToolUse(tool_use) => {
518 LlmMessageContent::ToolUse(LlmToolUse {
519 id: tool_use.id.to_string(),
520 name: tool_use.name.to_string(),
521 input: tool_use.raw_input.clone(),
522 thought_signature: tool_use.thought_signature.clone(),
523 })
524 }
525 language_model::MessageContent::ToolResult(result) => {
526 LlmMessageContent::ToolResult(LlmToolResult {
527 tool_use_id: result.tool_use_id.to_string(),
528 tool_name: result.tool_name.to_string(),
529 is_error: result.is_error,
530 content: match &result.content {
531 language_model::LanguageModelToolResultContent::Text(t) => {
532 LlmToolResultContent::Text(t.to_string())
533 }
534 language_model::LanguageModelToolResultContent::Image(img) => {
535 LlmToolResultContent::Image(LlmImageData {
536 source: img.source.to_string(),
537 width: Some(img.size.width.0 as u32),
538 height: Some(img.size.height.0 as u32),
539 })
540 }
541 },
542 })
543 }
544 language_model::MessageContent::Thinking { text, signature } => {
545 LlmMessageContent::Thinking(LlmThinkingContent {
546 text: text.clone(),
547 signature: signature.clone(),
548 })
549 }
550 language_model::MessageContent::RedactedThinking(data) => {
551 LlmMessageContent::RedactedThinking(data.clone())
552 }
553 })
554 .collect(),
555 cache: msg.cache,
556 })
557 .collect();
558
559 let tools = request
560 .tools
561 .iter()
562 .map(|tool| LlmToolDefinition {
563 name: tool.name.clone(),
564 description: tool.description.clone(),
565 input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(),
566 })
567 .collect();
568
569 let tool_choice = request.tool_choice.as_ref().map(|choice| match choice {
570 LanguageModelToolChoice::Auto => LlmToolChoice::Auto,
571 LanguageModelToolChoice::Any => LlmToolChoice::Any,
572 LanguageModelToolChoice::None => LlmToolChoice::None,
573 });
574
575 LlmCompletionRequest {
576 messages,
577 tools,
578 tool_choice,
579 stop_sequences: request.stop.clone(),
580 temperature: request.temperature,
581 thinking_allowed: request.thinking_allowed,
582 max_tokens: None,
583 }
584}
585
586fn convert_completion_event(event: LlmCompletionEvent) -> LanguageModelCompletionEvent {
587 match event {
588 LlmCompletionEvent::Started => LanguageModelCompletionEvent::Started,
589 LlmCompletionEvent::Text(text) => LanguageModelCompletionEvent::Text(text),
590 LlmCompletionEvent::Thinking(thinking) => LanguageModelCompletionEvent::Thinking {
591 text: thinking.text,
592 signature: thinking.signature,
593 },
594 LlmCompletionEvent::RedactedThinking(data) => {
595 LanguageModelCompletionEvent::RedactedThinking { data }
596 }
597 LlmCompletionEvent::ToolUse(tool_use) => {
598 LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
599 id: LanguageModelToolUseId::from(tool_use.id),
600 name: tool_use.name.into(),
601 raw_input: tool_use.input.clone(),
602 input: serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null),
603 is_input_complete: true,
604 thought_signature: tool_use.thought_signature,
605 })
606 }
607 LlmCompletionEvent::ToolUseJsonParseError(error) => {
608 LanguageModelCompletionEvent::ToolUseJsonParseError {
609 id: LanguageModelToolUseId::from(error.id),
610 tool_name: error.tool_name.into(),
611 raw_input: error.raw_input.into(),
612 json_parse_error: error.error,
613 }
614 }
615 LlmCompletionEvent::Stop(reason) => LanguageModelCompletionEvent::Stop(match reason {
616 LlmStopReason::EndTurn => StopReason::EndTurn,
617 LlmStopReason::MaxTokens => StopReason::MaxTokens,
618 LlmStopReason::ToolUse => StopReason::ToolUse,
619 LlmStopReason::Refusal => StopReason::Refusal,
620 }),
621 LlmCompletionEvent::Usage(usage) => LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
622 input_tokens: usage.input_tokens,
623 output_tokens: usage.output_tokens,
624 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
625 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
626 }),
627 LlmCompletionEvent::ReasoningDetails(json) => {
628 LanguageModelCompletionEvent::ReasoningDetails(
629 serde_json::from_str(&json).unwrap_or(serde_json::Value::Null),
630 )
631 }
632 }
633}