1mod api_key;
2mod model;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7mod telemetry;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anthropic::{AnthropicError, parse_prompt_too_long};
14use anyhow::{Result, anyhow};
15use client::Client;
16use client::UserStore;
17use cloud_llm_client::CompletionRequestStatus;
18use futures::FutureExt;
19use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
20use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
21use http_client::{StatusCode, http};
22use icons::IconName;
23use open_router::OpenRouterError;
24use parking_lot::Mutex;
25use serde::{Deserialize, Serialize};
26pub use settings::LanguageModelCacheConfiguration;
27use std::ops::{Add, Sub};
28use std::str::FromStr;
29use std::sync::Arc;
30use std::time::Duration;
31use std::{fmt, io};
32use thiserror::Error;
33use util::serde::is_default;
34
35pub use crate::api_key::{ApiKey, ApiKeyState};
36pub use crate::model::*;
37pub use crate::rate_limiter::*;
38pub use crate::registry::*;
39pub use crate::request::*;
40pub use crate::role::*;
41pub use crate::telemetry::*;
42pub use crate::tool_schema::LanguageModelToolSchemaFormat;
43pub use zed_env_vars::{EnvVar, env_var};
44
// Stable identifiers and user-facing display names for the built-in
// providers. The `..._ID` values are machine-readable keys; the
// `..._NAME` values are shown in the UI.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
64
/// Initializes the crate: registers settings state and the LLM token
/// refresh listener on the given client.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}

/// Initializes only the registry/settings state; called by [`init`] and
/// usable on its own when no client is available.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
73
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking output, with an optional signature
    /// (some providers require signatures to be echoed back in history).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking content; `data` is opaque provider data —
    /// presumably meant to be round-tripped, not displayed. TODO confirm.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation (possibly streamed partially).
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began; carries the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through verbatim.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request so far.
    UsageUpdate(TokenUsage),
}
103
104impl LanguageModelCompletionEvent {
105 pub fn from_completion_request_status(
106 status: CompletionRequestStatus,
107 upstream_provider: LanguageModelProviderName,
108 ) -> Result<Option<Self>, LanguageModelCompletionError> {
109 match status {
110 CompletionRequestStatus::Queued { position } => {
111 Ok(Some(LanguageModelCompletionEvent::Queued { position }))
112 }
113 CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
114 CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
115 CompletionRequestStatus::Failed {
116 code,
117 message,
118 request_id: _,
119 retry_after,
120 } => Err(LanguageModelCompletionError::from_cloud_failure(
121 upstream_provider,
122 code,
123 message,
124 retry_after.map(Duration::from_secs_f64),
125 )),
126 }
127 }
128}
129
/// Errors that can occur while streaming a completion from a language model.
///
/// Display strings come from the `#[error(...)]` attributes and are shown to
/// users, so variants carry the provider name for attribution.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    // Error forwarded verbatim from an upstream provider via Zed's cloud.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    // Catch-all for HTTP statuses not mapped to a dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
220
221impl LanguageModelCompletionError {
222 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
223 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
224 let upstream_status = error_json
225 .get("upstream_status")
226 .and_then(|v| v.as_u64())
227 .and_then(|status| u16::try_from(status).ok())
228 .and_then(|status| StatusCode::from_u16(status).ok())?;
229 let inner_message = error_json
230 .get("message")
231 .and_then(|v| v.as_str())
232 .unwrap_or(message)
233 .to_string();
234 Some((upstream_status, inner_message))
235 }
236
237 pub fn from_cloud_failure(
238 upstream_provider: LanguageModelProviderName,
239 code: String,
240 message: String,
241 retry_after: Option<Duration>,
242 ) -> Self {
243 if let Some(tokens) = parse_prompt_too_long(&message) {
244 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
245 // to be reported. This is a temporary workaround to handle this in the case where the
246 // token limit has been exceeded.
247 Self::PromptTooLarge {
248 tokens: Some(tokens),
249 }
250 } else if code == "upstream_http_error" {
251 if let Some((upstream_status, inner_message)) =
252 Self::parse_upstream_error_json(&message)
253 {
254 return Self::from_http_status(
255 upstream_provider,
256 upstream_status,
257 inner_message,
258 retry_after,
259 );
260 }
261 anyhow!("completion request failed, code: {code}, message: {message}").into()
262 } else if let Some(status_code) = code
263 .strip_prefix("upstream_http_")
264 .and_then(|code| StatusCode::from_str(code).ok())
265 {
266 Self::from_http_status(upstream_provider, status_code, message, retry_after)
267 } else if let Some(status_code) = code
268 .strip_prefix("http_")
269 .and_then(|code| StatusCode::from_str(code).ok())
270 {
271 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
272 } else {
273 anyhow!("completion request failed, code: {code}, message: {message}").into()
274 }
275 }
276
277 pub fn from_http_status(
278 provider: LanguageModelProviderName,
279 status_code: StatusCode,
280 message: String,
281 retry_after: Option<Duration>,
282 ) -> Self {
283 match status_code {
284 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
285 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
286 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
287 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
288 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
289 tokens: parse_prompt_too_long(&message),
290 },
291 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
292 provider,
293 retry_after,
294 },
295 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
296 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
297 provider,
298 retry_after,
299 },
300 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
301 provider,
302 retry_after,
303 },
304 _ => Self::HttpResponseError {
305 provider,
306 status_code,
307 message,
308 },
309 }
310 }
311}
312
/// Maps low-level Anthropic client errors onto the shared completion-error
/// type, attributing them all to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            // Transport/serialization failures map 1:1 to their generic
            // counterparts.
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own error-code mapping.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
344
345impl From<anthropic::ApiError> for LanguageModelCompletionError {
346 fn from(error: anthropic::ApiError) -> Self {
347 use anthropic::ApiErrorCode::*;
348 let provider = ANTHROPIC_PROVIDER_NAME;
349 match error.code() {
350 Some(code) => match code {
351 InvalidRequestError => Self::BadRequestFormat {
352 provider,
353 message: error.message,
354 },
355 AuthenticationError => Self::AuthenticationError {
356 provider,
357 message: error.message,
358 },
359 PermissionError => Self::PermissionError {
360 provider,
361 message: error.message,
362 },
363 NotFoundError => Self::ApiEndpointNotFound { provider },
364 RequestTooLarge => Self::PromptTooLarge {
365 tokens: parse_prompt_too_long(&error.message),
366 },
367 RateLimitError => Self::RateLimitExceeded {
368 provider,
369 retry_after: None,
370 },
371 ApiError => Self::ApiInternalServerError {
372 provider,
373 message: error.message,
374 },
375 OverloadedError => Self::ServerOverloaded {
376 provider,
377 retry_after: None,
378 },
379 },
380 None => Self::Other(error.into()),
381 }
382 }
383}
384
/// Maps OpenAI client request errors onto the shared completion-error type,
/// honoring the `Retry-After` response header when present.
impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                // NOTE(review): only the integer-seconds form of Retry-After
                // is parsed; the HTTP-date form is silently ignored — confirm
                // that is acceptable.
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}
405
/// Maps low-level OpenRouter client errors onto the shared completion-error
/// type, attributing them all to the "OpenRouter" provider.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own error-code mapping below.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

/// Maps OpenRouter API error codes onto the shared completion-error type.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // Billing problems surface as authentication errors so the user
            // is directed at their account configuration.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
471
/// Why the model stopped generating a response.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}

/// Token counts for a request, split into fresh input/output tokens and
/// cache-creation / cache-read input tokens.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
492
impl TokenUsage {
    /// Sum of all four token categories.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

/// Element-wise sum of two usage reports.
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

/// Element-wise difference of two usage reports.
///
/// NOTE(review): subtraction is unchecked, so this panics in debug builds
/// (and wraps in release) if any field of `other` exceeds the matching field
/// of `self` — confirm callers only subtract earlier snapshots from later
/// ones.
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}
529
/// Identifier for a tool invocation, as assigned by the model/provider.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Blanket conversion from anything string-like (&str, String, Arc<str>, ...).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
547
/// A tool invocation requested by the model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this invocation.
    pub id: LanguageModelToolUseId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// The tool input exactly as streamed from the model, before parsing.
    pub raw_input: String,
    /// The tool input parsed as JSON.
    pub input: serde_json::Value,
    /// Whether `input` has been fully received (streaming may deliver it in
    /// increments).
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
559
/// A text-only view of a completion: just the streamed text chunks, plus the
/// message id and final token usage.
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when the stream began with one.
    pub message_id: Option<String>,
    /// Stream of text chunks (non-text events are filtered out upstream).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
576
/// A selectable "effort" level for models that support thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Display name shown in the UI.
    pub name: SharedString,
    /// Value sent to the provider API.
    pub value: SharedString,
    /// Whether this level is the model's default (see
    /// `LanguageModel::default_effort_level`).
    pub is_default: bool,
}
583
/// A language model that can count tokens and stream completions.
///
/// Implementors provide `stream_completion`; the trait supplies derived
/// text-only and tool-only streaming helpers on top of it.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Id of the provider actually serving requests; defaults to this
    /// model's own provider (overridden e.g. when proxying through a cloud).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name counterpart of [`Self::upstream_provider_id`].
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String;

    /// API key for this model, if one is configured; defaults to `None`.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether this model supports a "fast mode"; defaults to `false`.
    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// Schema format tools should be described in; defaults to JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, if the model has a known limit.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens the given request would consume.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for the given request. This is the one
    /// required streaming entry point; the helpers below are derived from it.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Streams only the text of a completion, capturing the message id (if
    /// the stream opens with `StartMessage`) and recording token usage into
    /// `last_token_usage` as `UsageUpdate` events arrive.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture the message id before the
            // stream is handed to the caller.
            // NOTE(review): any other first event — including an `Err` or a
            // `UsageUpdate` — is silently discarded by this match; confirm
            // that is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Forward text chunks and errors; record usage updates as a side
            // effect; drop every other event type.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Runs a completion and returns the first tool use whose input is
    /// complete, erroring if the stream ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration, if this model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcasts to the fake test model; panics for real implementations.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
788
789impl std::fmt::Debug for dyn LanguageModel {
790 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
791 f.debug_struct("<dyn LanguageModel>")
792 .field("id", &self.id())
793 .field("name", &self.name())
794 .field("provider_id", &self.provider_id())
795 .field("provider_name", &self.provider_name())
796 .field("upstream_provider_name", &self.upstream_provider_name())
797 .field("upstream_provider_id", &self.upstream_provider_id())
798 .field("upstream_provider_id", &self.upstream_provider_id())
799 .field("supports_streaming_tools", &self.supports_streaming_tools())
800 .finish()
801 }
802}
803
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    // Catch-all for provider-specific authentication failures.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
814
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}

impl Default for IconOrSvg {
    /// Defaults to the Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
829
/// A source of language models: exposes its models, handles authentication,
/// and renders its configuration UI.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon to display for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// Default model for tasks preferring speed over capability.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to surface prominently; defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
852
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other, named agent.
    Other(SharedString),
}

/// Where a provider's terms-of-service notice is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
869
/// Observable state for a provider, allowing UI to react to changes.
pub trait LanguageModelProviderState: 'static {
    /// The entity whose changes observers are notified about.
    type ObservableEntity;

    /// The entity to observe, if this provider exposes one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Subscribes `callback` to changes of the observable entity; returns
    /// `None` when the provider has no observable entity.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
886
// Newtype wrappers around SharedString, keeping the different kinds of
// identifiers and display names from being mixed up in signatures.

/// Machine-readable identifier for a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// User-facing display name for a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Machine-readable identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// User-facing display name for a provider (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
898
/// How a model's usage is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (matching the `_per_1m`
    /// field names; the previous "per 1,000" wording was incorrect).
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
909
910impl LanguageModelCostInfo {
911 pub fn to_shared_string(&self) -> SharedString {
912 match self {
913 LanguageModelCostInfo::RequestCost { cost_per_request } => {
914 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
915 SharedString::from(cost_str)
916 }
917 LanguageModelCostInfo::TokenCost {
918 input_token_cost_per_1m,
919 output_token_cost_per_1m,
920 } => {
921 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
922 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
923 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
924 }
925 }
926 }
927
928 fn cost_value_to_string(cost: &f64) -> SharedString {
929 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
930 SharedString::from(format!("{:.0}", cost))
931 } else {
932 SharedString::from(format!("{:.2}", cost))
933 }
934 }
935}
936
impl LanguageModelProviderId {
    /// Const constructor from a static string, usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor from a static string, usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
948
// Display passes through to the wrapped string.
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
960
// Conversions from owned strings (and Arc<str>) into the newtype wrappers.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
996
#[cfg(test)]
mod tests {
    use super::*;
    // Hoisted: previously re-imported inside four separate test fns.
    use serde_json::json;

    /// `upstream_http_error` payloads carry the real status in a JSON body;
    /// 503 must map to `ServerOverloaded` and 500 to `ApiInternalServerError`.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    /// The alternate `upstream_http_NNN` code format encodes the status in the
    /// code itself; a 503 suffix must map to `ServerOverloaded`.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            // Include the error so a failure here is as diagnosable as the
            // sibling tests (previously this panic omitted `error`).
            _ => panic!(
                "Expected ServerOverloaded error for upstream_http_503, got: {:?}",
                error
            ),
        }
    }

    /// Connection-timeout messages must not override the upstream status:
    /// classification still follows the `upstream_status` field (503 vs 500).
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    /// A present `thought_signature` must survive serialization.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    /// Older payloads without `thought_signature` must still deserialize
    /// (the field defaults to `None`).
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    /// Serialize → deserialize round trip preserves `Some(signature)`.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    /// Serialize → deserialize round trip preserves `None` signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}