1mod api_key;
2mod model;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7mod telemetry;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anthropic::{AnthropicError, parse_prompt_too_long};
14use anyhow::{Result, anyhow};
15use client::Client;
16use cloud_llm_client::CompletionRequestStatus;
17use futures::FutureExt;
18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
19use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
20use http_client::{StatusCode, http};
21use icons::IconName;
22use open_router::OpenRouterError;
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize};
25pub use settings::LanguageModelCacheConfiguration;
26use std::ops::{Add, Sub};
27use std::str::FromStr;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{fmt, io};
31use thiserror::Error;
32use util::serde::is_default;
33
34pub use crate::api_key::{ApiKey, ApiKeyState};
35pub use crate::model::*;
36pub use crate::rate_limiter::*;
37pub use crate::registry::*;
38pub use crate::request::*;
39pub use crate::role::*;
40pub use crate::telemetry::*;
41pub use crate::tool_schema::LanguageModelToolSchemaFormat;
42pub use zed_env_vars::{EnvVar, env_var};
43
// Well-known provider identifiers (stable machine-readable keys) and their
// corresponding human-readable display names.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

// Zed's own cloud offering; `http_<status>` failure codes are attributed to it
// in `LanguageModelCompletionError::from_cloud_failure`.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
63
/// Initializes the language-model crate: registers settings and the
/// LLM-token refresh listener against the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
68
/// Initializes only the settings/registry portion (no client required).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
72
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// Generation stopped, with the reason why.
    Stop(StopReason),
    /// A chunk of assistant response text.
    Text(String),
    /// A chunk of the model's thinking output, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking content, carried as an opaque string.
    RedactedThinking {
        data: String,
    },
    /// A (possibly partial) tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The model emitted tool input that could not be parsed as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message began, carrying its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token-usage accounting for the request.
    UsageUpdate(TokenUsage),
}
102
103impl LanguageModelCompletionEvent {
104 pub fn from_completion_request_status(
105 status: CompletionRequestStatus,
106 upstream_provider: LanguageModelProviderName,
107 ) -> Result<Option<Self>, LanguageModelCompletionError> {
108 match status {
109 CompletionRequestStatus::Queued { position } => {
110 Ok(Some(LanguageModelCompletionEvent::Queued { position }))
111 }
112 CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
113 CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
114 CompletionRequestStatus::Failed {
115 code,
116 message,
117 request_id: _,
118 retry_after,
119 } => Err(LanguageModelCompletionError::from_cloud_failure(
120 upstream_provider,
121 code,
122 message,
123 retry_after.map(Duration::from_secs_f64),
124 )),
125 }
126 }
127}
128
/// Errors that can occur while requesting or streaming a completion.
///
/// Variants are grouped roughly by origin: capacity/server-side problems,
/// client/request problems, transport problems, and a catch-all `Other`.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the parsed
    /// token count when the provider's message included one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed from an upstream provider, preserving its HTTP status.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Fallback for HTTP statuses without a dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
219
impl LanguageModelCompletionError {
    /// Parses a cloud error payload of the shape
    /// `{"message": "...", "upstream_status": <u16>}`.
    ///
    /// Returns `None` unless the payload is valid JSON with an
    /// `upstream_status` that is a valid HTTP status code. When the JSON has
    /// no `message` field, the whole raw payload is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Classifies a cloud completion failure from its `code`/`message` pair.
    ///
    /// Recognized shapes, checked in order:
    /// 1. a "prompt too long" message (regardless of code);
    /// 2. `upstream_http_error` with a JSON body carrying `upstream_status`;
    /// 3. `upstream_http_<status>` codes, attributed to `upstream_provider`;
    /// 4. `http_<status>` codes, attributed to Zed's own cloud provider;
    /// 5. anything else becomes an opaque `Other` error.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // JSON body was unparsable; fall back to an opaque error.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // Plain `http_<status>` codes originate from Zed's cloud itself.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from a provider's API to a structured error.
    /// Statuses without a dedicated variant fall through to `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard status; treated here as server overload.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
311
312impl From<AnthropicError> for LanguageModelCompletionError {
313 fn from(error: AnthropicError) -> Self {
314 let provider = ANTHROPIC_PROVIDER_NAME;
315 match error {
316 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
317 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
318 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
319 AnthropicError::DeserializeResponse(error) => {
320 Self::DeserializeResponse { provider, error }
321 }
322 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
323 AnthropicError::HttpResponseError {
324 status_code,
325 message,
326 } => Self::HttpResponseError {
327 provider,
328 status_code,
329 message,
330 },
331 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
332 provider,
333 retry_after: Some(retry_after),
334 },
335 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
336 provider,
337 retry_after,
338 },
339 AnthropicError::ApiError(api_error) => api_error.into(),
340 }
341 }
342}
343
344impl From<anthropic::ApiError> for LanguageModelCompletionError {
345 fn from(error: anthropic::ApiError) -> Self {
346 use anthropic::ApiErrorCode::*;
347 let provider = ANTHROPIC_PROVIDER_NAME;
348 match error.code() {
349 Some(code) => match code {
350 InvalidRequestError => Self::BadRequestFormat {
351 provider,
352 message: error.message,
353 },
354 AuthenticationError => Self::AuthenticationError {
355 provider,
356 message: error.message,
357 },
358 PermissionError => Self::PermissionError {
359 provider,
360 message: error.message,
361 },
362 NotFoundError => Self::ApiEndpointNotFound { provider },
363 RequestTooLarge => Self::PromptTooLarge {
364 tokens: parse_prompt_too_long(&error.message),
365 },
366 RateLimitError => Self::RateLimitExceeded {
367 provider,
368 retry_after: None,
369 },
370 ApiError => Self::ApiInternalServerError {
371 provider,
372 message: error.message,
373 },
374 OverloadedError => Self::ServerOverloaded {
375 provider,
376 retry_after: None,
377 },
378 },
379 None => Self::Other(error.into()),
380 }
381 }
382}
383
384impl From<open_ai::RequestError> for LanguageModelCompletionError {
385 fn from(error: open_ai::RequestError) -> Self {
386 match error {
387 open_ai::RequestError::HttpResponseError {
388 provider,
389 status_code,
390 body,
391 headers,
392 } => {
393 let retry_after = headers
394 .get(http::header::RETRY_AFTER)
395 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
396 .map(Duration::from_secs);
397
398 Self::from_http_status(provider.into(), status_code, body, retry_after)
399 }
400 open_ai::RequestError::Other(e) => Self::Other(e),
401 }
402 }
403}
404
405impl From<OpenRouterError> for LanguageModelCompletionError {
406 fn from(error: OpenRouterError) -> Self {
407 let provider = LanguageModelProviderName::new("OpenRouter");
408 match error {
409 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
410 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
411 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
412 OpenRouterError::DeserializeResponse(error) => {
413 Self::DeserializeResponse { provider, error }
414 }
415 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
416 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
417 provider,
418 retry_after: Some(retry_after),
419 },
420 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
421 provider,
422 retry_after,
423 },
424 OpenRouterError::ApiError(api_error) => api_error.into(),
425 }
426 }
427}
428
429impl From<open_router::ApiError> for LanguageModelCompletionError {
430 fn from(error: open_router::ApiError) -> Self {
431 use open_router::ApiErrorCode::*;
432 let provider = LanguageModelProviderName::new("OpenRouter");
433 match error.code {
434 InvalidRequestError => Self::BadRequestFormat {
435 provider,
436 message: error.message,
437 },
438 AuthenticationError => Self::AuthenticationError {
439 provider,
440 message: error.message,
441 },
442 PaymentRequiredError => Self::AuthenticationError {
443 provider,
444 message: format!("Payment required: {}", error.message),
445 },
446 PermissionError => Self::PermissionError {
447 provider,
448 message: error.message,
449 },
450 RequestTimedOut => Self::HttpResponseError {
451 provider,
452 status_code: StatusCode::REQUEST_TIMEOUT,
453 message: error.message,
454 },
455 RateLimitError => Self::RateLimitExceeded {
456 provider,
457 retry_after: None,
458 },
459 ApiError => Self::ApiInternalServerError {
460 provider,
461 message: error.message,
462 },
463 OverloadedError => Self::ServerOverloaded {
464 provider,
465 retry_after: None,
466 },
467 }
468 }
469}
470
/// Why a model stopped generating. Serialized in snake_case to match
/// provider wire formats.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output-token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
479
/// Token accounting for a single request. Zero-valued fields are omitted
/// from the serialized form (and default to zero when absent).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
491
492impl TokenUsage {
493 pub fn total_tokens(&self) -> u64 {
494 self.input_tokens
495 + self.output_tokens
496 + self.cache_read_input_tokens
497 + self.cache_creation_input_tokens
498 }
499}
500
501impl Add<TokenUsage> for TokenUsage {
502 type Output = Self;
503
504 fn add(self, other: Self) -> Self {
505 Self {
506 input_tokens: self.input_tokens + other.input_tokens,
507 output_tokens: self.output_tokens + other.output_tokens,
508 cache_creation_input_tokens: self.cache_creation_input_tokens
509 + other.cache_creation_input_tokens,
510 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
511 }
512 }
513}
514
515impl Sub<TokenUsage> for TokenUsage {
516 type Output = Self;
517
518 fn sub(self, other: Self) -> Self {
519 Self {
520 input_tokens: self.input_tokens - other.input_tokens,
521 output_tokens: self.output_tokens - other.output_tokens,
522 cache_creation_input_tokens: self.cache_creation_input_tokens
523 - other.cache_creation_input_tokens,
524 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
525 }
526 }
527}
528
/// Identifier a provider assigns to a single tool invocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
531
532impl fmt::Display for LanguageModelToolUseId {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 write!(f, "{}", self.0)
535 }
536}
537
// Blanket conversion so ids can be built from `&str`, `String`, `Arc<str>`, etc.
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
546
/// A tool invocation requested by the model, possibly still streaming in.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    // Name of the tool the model wants to call.
    pub name: Arc<str>,
    // The raw (possibly partial) JSON text of the input as streamed so far.
    pub raw_input: String,
    // Parsed form of `raw_input`.
    pub input: serde_json::Value,
    // True once the input JSON is complete and safe to execute against.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
558
/// A text-only view over a completion: message id, a stream of text chunks,
/// and shared token-usage state updated as the stream progresses.
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
565
566impl Default for LanguageModelTextStream {
567 fn default() -> Self {
568 Self {
569 message_id: None,
570 stream: Box::pin(futures::stream::empty()),
571 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
572 }
573 }
574}
575
/// A selectable "effort" option for models that support thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // Human-readable label shown in the UI.
    pub name: SharedString,
    // Wire value sent to the provider.
    pub value: SharedString,
    // Whether this level is the model's default choice.
    pub is_default: bool,
}
582
/// Interface implemented by every language model, regardless of provider.
///
/// Required methods describe identity and capabilities; default methods
/// provide conservative capability answers and convenience adapters built
/// on top of `stream_completion`.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// For proxied models, the provider that ultimately serves the request.
    /// Defaults to this model's own provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        // First level flagged `is_default`, if any.
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls;
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// Schema dialect to use when describing tools to this model.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Size of the model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum output tokens, if the model imposes a limit.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion and returns the stream of raw completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Adapter over `stream_completion` that yields only text chunks,
    /// capturing the message id (when the provider sends one first) and
    /// recording token usage into `last_token_usage` as updates arrive.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message id, or may
            // already be text that must be re-emitted at the stream's front.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend any text consumed above, then keep only text chunks
            // and errors; usage updates are absorbed into `last_token_usage`.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Adapter over `stream_completion` that resolves to the first tool use
    /// whose input is complete, or an error if the stream ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration, if the model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
783
784impl std::fmt::Debug for dyn LanguageModel {
785 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
786 f.debug_struct("<dyn LanguageModel>")
787 .field("id", &self.id())
788 .field("name", &self.name())
789 .field("provider_id", &self.provider_id())
790 .field("provider_name", &self.provider_name())
791 .field("upstream_provider_name", &self.upstream_provider_name())
792 .field("upstream_provider_id", &self.upstream_provider_id())
793 .field("upstream_provider_id", &self.upstream_provider_id())
794 .field("supports_streaming_tools", &self.supports_streaming_tools())
795 .finish()
796 }
797}
798
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
809
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}

impl Default for IconOrSvg {
    /// Defaults to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
824
/// Interface implemented by each language-model provider integration
/// (model discovery, authentication, and configuration UI).
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is configured/available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default model, if the provider distinguishes one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
847
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    /// An external agent, identified by name.
    Other(SharedString),
}

/// Where a provider's terms-of-service prompt is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
864
/// Providers with observable internal state (e.g. authentication status)
/// expose it through this trait so UI can react to changes.
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    /// The entity to observe, or `None` if this provider has no observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Subscribes `callback` to changes of the observable entity.
    /// Returns `None` when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
881
// Newtype wrappers over `SharedString` keeping model/provider ids and
// display names from being mixed up at call sites.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
893
/// Pricing information for a model, either per-token or per-request.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1,000 input and output tokens
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
904
905impl LanguageModelCostInfo {
906 pub fn to_shared_string(&self) -> SharedString {
907 match self {
908 LanguageModelCostInfo::RequestCost { cost_per_request } => {
909 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
910 SharedString::from(cost_str)
911 }
912 LanguageModelCostInfo::TokenCost {
913 input_token_cost_per_1m,
914 output_token_cost_per_1m,
915 } => {
916 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
917 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
918 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
919 }
920 }
921 }
922
923 fn cost_value_to_string(cost: &f64) -> SharedString {
924 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
925 SharedString::from(format!("{:.0}", cost))
926 } else {
927 SharedString::from(format!("{:.2}", cost))
928 }
929 }
930}
931
impl LanguageModelProviderId {
    /// Const constructor from a static string (no allocation).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor from a static string (no allocation).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
943
944impl fmt::Display for LanguageModelProviderId {
945 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
946 write!(f, "{}", self.0)
947 }
948}
949
950impl fmt::Display for LanguageModelProviderName {
951 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
952 write!(f, "{}", self.0)
953 }
954}
955
956impl From<String> for LanguageModelId {
957 fn from(value: String) -> Self {
958 Self(SharedString::from(value))
959 }
960}
961
962impl From<String> for LanguageModelName {
963 fn from(value: String) -> Self {
964 Self(SharedString::from(value))
965 }
966}
967
968impl From<String> for LanguageModelProviderId {
969 fn from(value: String) -> Self {
970 Self(SharedString::from(value))
971 }
972}
973
974impl From<String> for LanguageModelProviderName {
975 fn from(value: String) -> Self {
976 Self(SharedString::from(value))
977 }
978}
979
980impl From<Arc<str>> for LanguageModelProviderId {
981 fn from(value: Arc<str>) -> Self {
982 Self(SharedString::from(value))
983 }
984}
985
986impl From<Arc<str>> for LanguageModelProviderName {
987 fn from(value: Arc<str>) -> Self {
988 Self(SharedString::from(value))
989 }
990}
991
992#[cfg(test)]
993mod tests {
994 use super::*;
995
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        // An `upstream_http_error` code with a JSON body should be unpacked so
        // the embedded `upstream_status` (503 here) drives the classification.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        // A 500 upstream status should map to ApiInternalServerError, with the
        // inner JSON `message` extracted.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }
1033
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        // The `upstream_http_<status>` code format should classify by the
        // status embedded in the code itself (503 → ServerOverloaded).
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }
1050
1051 #[test]
1052 fn test_upstream_http_error_connection_timeout() {
1053 let error = LanguageModelCompletionError::from_cloud_failure(
1054 String::from("anthropic").into(),
1055 "upstream_http_error".to_string(),
1056 r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
1057 None,
1058 );
1059
1060 match error {
1061 LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
1062 assert_eq!(provider.0, "anthropic");
1063 }
1064 _ => panic!(
1065 "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
1066 error
1067 ),
1068 }
1069
1070 let error = LanguageModelCompletionError::from_cloud_failure(
1071 String::from("anthropic").into(),
1072 "upstream_http_error".to_string(),
1073 r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
1074 None,
1075 );
1076
1077 match error {
1078 LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
1079 assert_eq!(provider.0, "anthropic");
1080 assert_eq!(
1081 message,
1082 "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
1083 );
1084 }
1085 _ => panic!(
1086 "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
1087 error
1088 ),
1089 }
1090 }
1091
1092 #[test]
1093 fn test_language_model_tool_use_serializes_with_signature() {
1094 use serde_json::json;
1095
1096 let tool_use = LanguageModelToolUse {
1097 id: LanguageModelToolUseId::from("test_id"),
1098 name: "test_tool".into(),
1099 raw_input: json!({"arg": "value"}).to_string(),
1100 input: json!({"arg": "value"}),
1101 is_input_complete: true,
1102 thought_signature: Some("test_signature".to_string()),
1103 };
1104
1105 let serialized = serde_json::to_value(&tool_use).unwrap();
1106
1107 assert_eq!(serialized["id"], "test_id");
1108 assert_eq!(serialized["name"], "test_tool");
1109 assert_eq!(serialized["thought_signature"], "test_signature");
1110 }
1111
1112 #[test]
1113 fn test_language_model_tool_use_deserializes_with_missing_signature() {
1114 use serde_json::json;
1115
1116 let json = json!({
1117 "id": "test_id",
1118 "name": "test_tool",
1119 "raw_input": "{\"arg\":\"value\"}",
1120 "input": {"arg": "value"},
1121 "is_input_complete": true
1122 });
1123
1124 let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
1125
1126 assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
1127 assert_eq!(tool_use.name.as_ref(), "test_tool");
1128 assert_eq!(tool_use.thought_signature, None);
1129 }
1130
1131 #[test]
1132 fn test_language_model_tool_use_round_trip_with_signature() {
1133 use serde_json::json;
1134
1135 let original = LanguageModelToolUse {
1136 id: LanguageModelToolUseId::from("round_trip_id"),
1137 name: "round_trip_tool".into(),
1138 raw_input: json!({"key": "value"}).to_string(),
1139 input: json!({"key": "value"}),
1140 is_input_complete: true,
1141 thought_signature: Some("round_trip_sig".to_string()),
1142 };
1143
1144 let serialized = serde_json::to_value(&original).unwrap();
1145 let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
1146
1147 assert_eq!(deserialized.id, original.id);
1148 assert_eq!(deserialized.name, original.name);
1149 assert_eq!(deserialized.thought_signature, original.thought_signature);
1150 }
1151
1152 #[test]
1153 fn test_language_model_tool_use_round_trip_without_signature() {
1154 use serde_json::json;
1155
1156 let original = LanguageModelToolUse {
1157 id: LanguageModelToolUseId::from("no_sig_id"),
1158 name: "no_sig_tool".into(),
1159 raw_input: json!({"arg": "value"}).to_string(),
1160 input: json!({"arg": "value"}),
1161 is_input_complete: true,
1162 thought_signature: None,
1163 };
1164
1165 let serialized = serde_json::to_value(&original).unwrap();
1166 let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
1167
1168 assert_eq!(deserialized.id, original.id);
1169 assert_eq!(deserialized.name, original.name);
1170 assert_eq!(deserialized.thought_signature, None);
1171 }
1172}