1mod api_key;
2mod model;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7mod telemetry;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anthropic::{AnthropicError, parse_prompt_too_long};
14use anyhow::{Result, anyhow};
15use client::Client;
16use client::UserStore;
17use cloud_llm_client::CompletionRequestStatus;
18use futures::FutureExt;
19use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
20use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
21use http_client::{StatusCode, http};
22use icons::IconName;
23use open_router::OpenRouterError;
24use parking_lot::Mutex;
25use serde::{Deserialize, Serialize};
26pub use settings::LanguageModelCacheConfiguration;
27use std::ops::{Add, Sub};
28use std::str::FromStr;
29use std::sync::Arc;
30use std::time::Duration;
31use std::{fmt, io};
32use thiserror::Error;
33use util::serde::is_default;
34
35pub use crate::api_key::{ApiKey, ApiKeyState};
36pub use crate::model::*;
37pub use crate::rate_limiter::*;
38pub use crate::registry::*;
39pub use crate::request::*;
40pub use crate::role::*;
41pub use crate::telemetry::*;
42pub use crate::tool_schema::LanguageModelToolSchemaFormat;
43pub use zed_env_vars::{EnvVar, env_var};
44
/// Stable machine identifier for the Anthropic provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// Human-readable display name for the Anthropic provider.
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Stable machine identifier for the Google provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Human-readable display name for the Google provider.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Stable machine identifier for the OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Human-readable display name for the OpenAI provider.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Stable machine identifier for the xAI provider.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Human-readable display name for the xAI provider.
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Stable machine identifier for Zed's own cloud provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Human-readable display name for Zed's own cloud provider.
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
64
/// Initializes the language-model crate: registers settings and starts the
/// LLM-token refresh listener for the given client and user store.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
69
/// Initializes only the registry/settings portion (no client wiring);
/// called by [`init`] and usable on its own in tests.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
73
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has begun processing the request.
    Started,
    /// Generation finished, with the reason the model stopped.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking output.
    Thinking {
        text: String,
        // NOTE(review): presumably a provider signature over the thinking
        // block (cf. `LanguageModelToolUse::thought_signature`) — confirm.
        signature: Option<String>,
    },
    /// Thinking content the provider returned only in redacted/opaque form.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model emitted tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began; carries the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token-usage counts for the request so far.
    UsageUpdate(TokenUsage),
}
103
104impl LanguageModelCompletionEvent {
105 pub fn from_completion_request_status(
106 status: CompletionRequestStatus,
107 upstream_provider: LanguageModelProviderName,
108 ) -> Result<Option<Self>, LanguageModelCompletionError> {
109 match status {
110 CompletionRequestStatus::Queued { position } => {
111 Ok(Some(LanguageModelCompletionEvent::Queued { position }))
112 }
113 CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
114 CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
115 CompletionRequestStatus::Failed {
116 code,
117 message,
118 request_id: _,
119 retry_after,
120 } => Err(LanguageModelCompletionError::from_cloud_failure(
121 upstream_provider,
122 code,
123 message,
124 retry_after.map(Duration::from_secs_f64),
125 )),
126 }
127 }
128}
129
/// Errors that can occur while streaming a completion, normalized across
/// providers. Variants mostly mirror HTTP status classes plus local
/// request-construction failures.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    // Server-side / capacity errors
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    // Pre-formatted error from an upstream provider, forwarded as-is.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    // Catch-all for HTTP statuses with no dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
220
impl LanguageModelCompletionError {
    /// Extracts `(upstream_status, message)` from a cloud error payload shaped
    /// like `{"upstream_status": 503, "message": "..."}`.
    ///
    /// Returns `None` when the payload is not JSON or lacks a valid HTTP
    /// status; when the inner `message` field is missing, the entire outer
    /// payload string is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a failure reported by Zed's cloud into a typed error.
    ///
    /// Recognized `code` shapes, checked in order:
    /// - any message matching the Anthropic "prompt too long" pattern,
    /// - `upstream_http_error` with a JSON payload carrying the real status,
    /// - `upstream_http_<status>` (error from the upstream provider),
    /// - `http_<status>` (error from Zed's cloud itself).
    /// Anything else falls through to `Self::Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Payload didn't parse; fall back to an opaque error.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // Plain `http_<status>` means the error originated at Zed's cloud,
            // not the upstream provider.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status code from a provider's API to a typed error.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard "overloaded" status used by some
            // providers (notably Anthropic).
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
312
/// Maps Anthropic transport/protocol errors onto the normalized error type,
/// tagging every variant with the Anthropic provider name.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own code; delegated to the
            // `From<anthropic::ApiError>` impl below.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
344
345impl From<anthropic::ApiError> for LanguageModelCompletionError {
346 fn from(error: anthropic::ApiError) -> Self {
347 use anthropic::ApiErrorCode::*;
348 let provider = ANTHROPIC_PROVIDER_NAME;
349 match error.code() {
350 Some(code) => match code {
351 InvalidRequestError => Self::BadRequestFormat {
352 provider,
353 message: error.message,
354 },
355 AuthenticationError => Self::AuthenticationError {
356 provider,
357 message: error.message,
358 },
359 PermissionError => Self::PermissionError {
360 provider,
361 message: error.message,
362 },
363 NotFoundError => Self::ApiEndpointNotFound { provider },
364 RequestTooLarge => Self::PromptTooLarge {
365 tokens: parse_prompt_too_long(&error.message),
366 },
367 RateLimitError => Self::RateLimitExceeded {
368 provider,
369 retry_after: None,
370 },
371 ApiError => Self::ApiInternalServerError {
372 provider,
373 message: error.message,
374 },
375 OverloadedError => Self::ServerOverloaded {
376 provider,
377 retry_after: None,
378 },
379 },
380 None => Self::Other(error.into()),
381 }
382 }
383}
384
385impl From<open_ai::RequestError> for LanguageModelCompletionError {
386 fn from(error: open_ai::RequestError) -> Self {
387 match error {
388 open_ai::RequestError::HttpResponseError {
389 provider,
390 status_code,
391 body,
392 headers,
393 } => {
394 let retry_after = headers
395 .get(http::header::RETRY_AFTER)
396 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
397 .map(Duration::from_secs);
398
399 Self::from_http_status(provider.into(), status_code, body, retry_after)
400 }
401 open_ai::RequestError::Other(e) => Self::Other(e),
402 }
403 }
404}
405
/// Maps OpenRouter transport/protocol errors onto the normalized error type.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors are delegated to `From<open_router::ApiError>`.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
429
/// Maps OpenRouter API error codes onto the normalized error type.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // There is no dedicated "payment required" variant, so billing
            // failures are surfaced as authentication errors with a prefix.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            // Timeouts have no dedicated variant either; reuse the generic
            // HTTP variant with a 408 status.
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
471
/// Why the model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a completion.
    Refusal,
}
480
/// Token counts for a request; all fields default to zero and are omitted
/// from serialized output when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens consumed by the prompt.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens produced in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
492
493impl TokenUsage {
494 pub fn total_tokens(&self) -> u64 {
495 self.input_tokens
496 + self.output_tokens
497 + self.cache_read_input_tokens
498 + self.cache_creation_input_tokens
499 }
500}
501
502impl Add<TokenUsage> for TokenUsage {
503 type Output = Self;
504
505 fn add(self, other: Self) -> Self {
506 Self {
507 input_tokens: self.input_tokens + other.input_tokens,
508 output_tokens: self.output_tokens + other.output_tokens,
509 cache_creation_input_tokens: self.cache_creation_input_tokens
510 + other.cache_creation_input_tokens,
511 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
512 }
513 }
514}
515
516impl Sub<TokenUsage> for TokenUsage {
517 type Output = Self;
518
519 fn sub(self, other: Self) -> Self {
520 Self {
521 input_tokens: self.input_tokens - other.input_tokens,
522 output_tokens: self.output_tokens - other.output_tokens,
523 cache_creation_input_tokens: self.cache_creation_input_tokens
524 - other.cache_creation_input_tokens,
525 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
526 }
527 }
528}
529
/// Opaque identifier a model assigns to one tool invocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
532
533impl fmt::Display for LanguageModelToolUseId {
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 write!(f, "{}", self.0)
536 }
537}
538
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    /// Builds an id from anything convertible to `Arc<str>` (`&str`, `String`, …).
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
547
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    // Model-assigned id for correlating this call with its result.
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The tool input exactly as the model emitted it (possibly partial JSON
    // while streaming).
    pub raw_input: String,
    // Parsed form of `raw_input`.
    pub input: serde_json::Value,
    // False while the input is still streaming in.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
559
/// A completion reduced to a plain text stream, produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    // Provider-assigned message id, when the stream began with `StartMessage`.
    pub message_id: Option<String>,
    // Text chunks only; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
566
567impl Default for LanguageModelTextStream {
568 fn default() -> Self {
569 Self {
570 message_id: None,
571 stream: Box::pin(futures::stream::empty()),
572 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
573 }
574 }
575}
576
/// One selectable "effort" setting for models that support thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // Display name shown in the UI.
    pub name: SharedString,
    // Value sent to the provider.
    pub value: SharedString,
    // Whether this is the level picked by `default_effort_level`.
    pub is_default: bool,
}
583
/// Interface implemented by every completion model Zed can talk to.
///
/// Required methods describe the model's identity and capabilities plus the
/// two core operations (`count_tokens`, `stream_completion`); everything else
/// has a conservative default.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// The provider actually serving requests; defaults to `provider_id`.
    /// Overridden when requests are proxied (e.g. through Zed's cloud).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Identifier reported in telemetry events.
    fn telemetry_id(&self) -> String;

    /// The configured API key for this model, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// Schema dialect expected for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Context-window size in tokens.
    fn max_token_count(&self) -> u64;
    /// Output-token cap, when the provider imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion and returns the full event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper over `stream_completion` that reduces the event
    /// stream to text chunks, capturing the message id from a leading
    /// `StartMessage` event and recording usage updates into
    /// `last_token_usage`.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture a message id or the first
            // text chunk before handing the rest of the stream over.
            // NOTE(review): any other first event — including an `Err` or a
            // `UsageUpdate` — is silently discarded here; confirm that is
            // intended.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend the consumed text chunk (if any), then keep only
            // text chunks and errors; usage updates are stored as a side
            // effect rather than yielded.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    // Overwrite rather than accumulate: each
                                    // update is treated as the latest totals.
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Convenience wrapper over `stream_completion` that drives the stream
    /// until the first *complete* tool use and returns it; errors out if the
    /// stream ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching parameters, for providers that support it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcast helper for tests; panics on non-fake models.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
788
789impl std::fmt::Debug for dyn LanguageModel {
790 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
791 f.debug_struct("<dyn LanguageModel>")
792 .field("id", &self.id())
793 .field("name", &self.name())
794 .field("provider_id", &self.provider_id())
795 .field("provider_name", &self.provider_name())
796 .field("upstream_provider_name", &self.upstream_provider_name())
797 .field("upstream_provider_id", &self.upstream_provider_id())
798 .field("upstream_provider_id", &self.upstream_provider_id())
799 .field("supports_streaming_tools", &self.supports_streaming_tools())
800 .finish()
801 }
802}
803
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored API key / credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
814
/// Either a built-in icon name or a path to an external SVG.
/// Used by [`LanguageModelProvider::icon`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
823
impl Default for IconOrSvg {
    /// Falls back to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
829
/// A source of language models (Anthropic, OpenAI, Zed cloud, …): exposes the
/// models it offers plus authentication and configuration-UI hooks.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown next to the provider in the UI.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The model to preselect for regular requests, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The model to preselect for cheap/fast requests, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the settings view used to configure this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
852
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// A named external agent.
    Other(SharedString),
}
859
/// Gives observers access to a provider's observable state entity, so UI can
/// re-render when the provider changes (e.g. after authentication).
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    /// The entity to observe, or `None` if this provider has no observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Observes the provider's entity, invoking `callback` on every change.
    /// Returns `None` when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
876
/// Machine identifier for a model (e.g. `"claude-3-5-sonnet"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable display name for a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Machine identifier for a provider (e.g. `"anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable display name for a provider (e.g. `"Anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
888
/// Pricing model for a language model.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (field names are `per_1m`).
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
899
900impl LanguageModelCostInfo {
901 pub fn to_shared_string(&self) -> SharedString {
902 match self {
903 LanguageModelCostInfo::RequestCost { cost_per_request } => {
904 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
905 SharedString::from(cost_str)
906 }
907 LanguageModelCostInfo::TokenCost {
908 input_token_cost_per_1m,
909 output_token_cost_per_1m,
910 } => {
911 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
912 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
913 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
914 }
915 }
916 }
917
918 fn cost_value_to_string(cost: &f64) -> SharedString {
919 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
920 SharedString::from(format!("{:.0}", cost))
921 } else {
922 SharedString::from(format!("{:.2}", cost))
923 }
924 }
925}
926
impl LanguageModelProviderId {
    /// Const constructor, usable in `pub const` provider-id declarations.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor, usable in `pub const` provider-name declarations.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
938
939impl fmt::Display for LanguageModelProviderId {
940 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
941 write!(f, "{}", self.0)
942 }
943}
944
945impl fmt::Display for LanguageModelProviderName {
946 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
947 write!(f, "{}", self.0)
948 }
949}
950
951impl From<String> for LanguageModelId {
952 fn from(value: String) -> Self {
953 Self(SharedString::from(value))
954 }
955}
956
957impl From<String> for LanguageModelName {
958 fn from(value: String) -> Self {
959 Self(SharedString::from(value))
960 }
961}
962
963impl From<String> for LanguageModelProviderId {
964 fn from(value: String) -> Self {
965 Self(SharedString::from(value))
966 }
967}
968
969impl From<String> for LanguageModelProviderName {
970 fn from(value: String) -> Self {
971 Self(SharedString::from(value))
972 }
973}
974
975impl From<Arc<str>> for LanguageModelProviderId {
976 fn from(value: Arc<str>) -> Self {
977 Self(SharedString::from(value))
978 }
979}
980
981impl From<Arc<str>> for LanguageModelProviderName {
982 fn from(value: Arc<str>) -> Self {
983 Self(SharedString::from(value))
984 }
985}
986
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The upstream message the cloud proxy relays when Anthropic's edge
    /// times out before sending headers; shared by several tests below.
    const TIMEOUT_MESSAGE: &str = "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout";

    /// Builds a `from_cloud_failure` error for the "anthropic" provider with
    /// the given error code and raw body, mirroring what the cloud proxy sends.
    fn anthropic_cloud_failure(code: &str, body: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            body.to_string(),
            None,
        )
    }

    /// Builds the JSON body of an `upstream_http_error` payload. Using `json!`
    /// instead of a hand-written raw string guarantees the body is valid JSON.
    fn upstream_http_error_body(message: &str, upstream_status: u32) -> String {
        json!({
            "code": "upstream_http_error",
            "message": message,
            "upstream_status": upstream_status,
        })
        .to_string()
    }

    /// Asserts that `error` is `ServerOverloaded` for the "anthropic" provider.
    /// `context` describes the scenario for the panic message on mismatch.
    fn expect_server_overloaded(error: LanguageModelCompletionError, context: &str) {
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error {context}, got: {error:?}"),
        }
    }

    /// Asserts that `error` is `ApiInternalServerError` for the "anthropic"
    /// provider, carrying exactly `expected_message`.
    fn expect_internal_server_error(error: LanguageModelCompletionError, expected_message: &str) {
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, expected_message);
            }
            _ => panic!("Expected ApiInternalServerError, got: {error:?}"),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        // An upstream 503 should surface as ServerOverloaded.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_http_error_body(TIMEOUT_MESSAGE, 503),
        );
        expect_server_overloaded(error, "for 503 status");

        // An upstream 500 should surface as ApiInternalServerError with the
        // upstream message passed through.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_http_error_body("Internal server error", 500),
        );
        expect_internal_server_error(error, "Internal server error");
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        // The status can also arrive encoded in the error code itself.
        let error = anthropic_cloud_failure("upstream_http_503", "Service unavailable");
        expect_server_overloaded(error, "for upstream_http_503");
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        // A connection-timeout message with a 503 upstream status maps to
        // ServerOverloaded...
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_http_error_body(TIMEOUT_MESSAGE, 503),
        );
        expect_server_overloaded(error, "for connection timeout with 503 status");

        // ...while the same message with a 500 status maps to
        // ApiInternalServerError, preserving the message verbatim.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_http_error_body(TIMEOUT_MESSAGE, 500),
        );
        expect_internal_server_error(error, TIMEOUT_MESSAGE);
    }

    /// Constructs a complete `LanguageModelToolUse` fixture with an optional
    /// thought signature; the input payload is irrelevant to these tests.
    fn tool_use(id: &str, name: &str, signature: Option<&str>) -> LanguageModelToolUse {
        LanguageModelToolUse {
            id: LanguageModelToolUseId::from(id),
            name: name.to_owned().into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: signature.map(str::to_string),
        }
    }

    /// Serializes `original` to JSON and back, asserting that id, name, and
    /// thought signature survive the round trip.
    fn assert_round_trip(original: LanguageModelToolUse) {
        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let serialized =
            serde_json::to_value(tool_use("test_id", "test_tool", Some("test_signature"))).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        // `thought_signature` is absent from the payload and must default to None.
        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        assert_round_trip(tool_use(
            "round_trip_id",
            "round_trip_tool",
            Some("round_trip_sig"),
        ));
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        assert_round_trip(tool_use("no_sig_id", "no_sig_tool", None));
    }
}