1mod api_key;
2mod model;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7mod telemetry;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anthropic::{AnthropicError, parse_prompt_too_long};
14use anyhow::{Result, anyhow};
15use client::Client;
16use cloud_llm_client::CompletionRequestStatus;
17use futures::FutureExt;
18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
19use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
20use http_client::{StatusCode, http};
21use icons::IconName;
22use open_router::OpenRouterError;
23use parking_lot::Mutex;
24use serde::{Deserialize, Serialize};
25pub use settings::LanguageModelCacheConfiguration;
26use std::ops::{Add, Sub};
27use std::str::FromStr;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{fmt, io};
31use thiserror::Error;
32use util::serde::is_default;
33
34pub use crate::api_key::{ApiKey, ApiKeyState};
35pub use crate::model::*;
36pub use crate::rate_limiter::*;
37pub use crate::registry::*;
38pub use crate::request::*;
39pub use crate::role::*;
40pub use crate::telemetry::*;
41pub use crate::tool_schema::LanguageModelToolSchemaFormat;
42pub use zed_env_vars::{EnvVar, env_var};
43
/// Stable identifier and display name for the Anthropic provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Stable identifier and display name for the Google AI provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Stable identifier and display name for the OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Stable identifier and display name for the xAI provider.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Stable identifier and display name for Zed's own cloud provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
63
/// Initializes the language-model crate: registers settings and the
/// listener that refreshes LLM tokens for the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

/// Registers this crate's settings (the model registry) with the app.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
72
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue; `position` is its place in line.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking, with an optional signature that some
    /// providers require to be echoed back in later requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// An opaque redacted-thinking payload from the provider.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool call from the model.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began; carries the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through verbatim.
    ReasoningDetails(serde_json::Value),
    /// Updated token-usage counts for the request.
    UsageUpdate(TokenUsage),
}
102
impl LanguageModelCompletionEvent {
    /// Translates a cloud `CompletionRequestStatus` into a completion event.
    ///
    /// Returns `Ok(None)` for statuses that carry no event for the caller,
    /// an error for statuses that are unexpected in this stream, and a
    /// structured completion error for reported failures.
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Option<Self>, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
            }
            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
            // These statuses are handled outside of this conversion, so
            // encountering them here indicates a protocol mismatch.
            CompletionRequestStatus::UsageUpdated { .. }
            | CompletionRequestStatus::ToolUseLimitReached => Err(
                LanguageModelCompletionError::Other(anyhow!("Unexpected status: {status:?}")),
            ),
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}
132
/// An error produced while running a language model completion.
///
/// Variants distinguish provider-side failures (rate limits, overload,
/// internal errors) from client-side ones (bad requests, authentication)
/// and from local serialization/I-O problems.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// offending count when the provider reported one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed verbatim from an upstream provider.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses without a dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
223
224impl LanguageModelCompletionError {
225 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
226 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
227 let upstream_status = error_json
228 .get("upstream_status")
229 .and_then(|v| v.as_u64())
230 .and_then(|status| u16::try_from(status).ok())
231 .and_then(|status| StatusCode::from_u16(status).ok())?;
232 let inner_message = error_json
233 .get("message")
234 .and_then(|v| v.as_str())
235 .unwrap_or(message)
236 .to_string();
237 Some((upstream_status, inner_message))
238 }
239
240 pub fn from_cloud_failure(
241 upstream_provider: LanguageModelProviderName,
242 code: String,
243 message: String,
244 retry_after: Option<Duration>,
245 ) -> Self {
246 if let Some(tokens) = parse_prompt_too_long(&message) {
247 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
248 // to be reported. This is a temporary workaround to handle this in the case where the
249 // token limit has been exceeded.
250 Self::PromptTooLarge {
251 tokens: Some(tokens),
252 }
253 } else if code == "upstream_http_error" {
254 if let Some((upstream_status, inner_message)) =
255 Self::parse_upstream_error_json(&message)
256 {
257 return Self::from_http_status(
258 upstream_provider,
259 upstream_status,
260 inner_message,
261 retry_after,
262 );
263 }
264 anyhow!("completion request failed, code: {code}, message: {message}").into()
265 } else if let Some(status_code) = code
266 .strip_prefix("upstream_http_")
267 .and_then(|code| StatusCode::from_str(code).ok())
268 {
269 Self::from_http_status(upstream_provider, status_code, message, retry_after)
270 } else if let Some(status_code) = code
271 .strip_prefix("http_")
272 .and_then(|code| StatusCode::from_str(code).ok())
273 {
274 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
275 } else {
276 anyhow!("completion request failed, code: {code}, message: {message}").into()
277 }
278 }
279
280 pub fn from_http_status(
281 provider: LanguageModelProviderName,
282 status_code: StatusCode,
283 message: String,
284 retry_after: Option<Duration>,
285 ) -> Self {
286 match status_code {
287 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
288 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
289 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
290 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
291 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
292 tokens: parse_prompt_too_long(&message),
293 },
294 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
295 provider,
296 retry_after,
297 },
298 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
299 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
300 provider,
301 retry_after,
302 },
303 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
304 provider,
305 retry_after,
306 },
307 _ => Self::HttpResponseError {
308 provider,
309 status_code,
310 message,
311 },
312 }
313 }
314}
315
impl From<AnthropicError> for LanguageModelCompletionError {
    /// Maps Anthropic client/transport errors onto the shared
    /// completion-error type, tagging each with the Anthropic provider name.
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own code; delegate to that mapping.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
347
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    /// Maps Anthropic API error codes onto the shared completion-error type.
    /// Errors without a recognized code fall through to `Other`.
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    // The token count, when present, is embedded in the message text.
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}
387
388impl From<open_ai::RequestError> for LanguageModelCompletionError {
389 fn from(error: open_ai::RequestError) -> Self {
390 match error {
391 open_ai::RequestError::HttpResponseError {
392 provider,
393 status_code,
394 body,
395 headers,
396 } => {
397 let retry_after = headers
398 .get(http::header::RETRY_AFTER)
399 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
400 .map(Duration::from_secs);
401
402 Self::from_http_status(provider.into(), status_code, body, retry_after)
403 }
404 open_ai::RequestError::Other(e) => Self::Other(e),
405 }
406 }
407}
408
impl From<OpenRouterError> for LanguageModelCompletionError {
    /// Maps OpenRouter client/transport errors onto the shared
    /// completion-error type, tagging each with the OpenRouter provider name.
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own code; delegate to that mapping.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
432
impl From<open_router::ApiError> for LanguageModelCompletionError {
    /// Maps OpenRouter API error codes onto the shared completion-error type.
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // There is no dedicated billing variant, so payment problems
            // surface as authentication errors with an explanatory prefix.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
474
/// The reason a model stopped generating a completion.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}

/// Token counts reported by a provider for a single request.
/// All fields default to zero and are omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Output (completion) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Tokens read from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
495
496impl TokenUsage {
497 pub fn total_tokens(&self) -> u64 {
498 self.input_tokens
499 + self.output_tokens
500 + self.cache_read_input_tokens
501 + self.cache_creation_input_tokens
502 }
503}
504
505impl Add<TokenUsage> for TokenUsage {
506 type Output = Self;
507
508 fn add(self, other: Self) -> Self {
509 Self {
510 input_tokens: self.input_tokens + other.input_tokens,
511 output_tokens: self.output_tokens + other.output_tokens,
512 cache_creation_input_tokens: self.cache_creation_input_tokens
513 + other.cache_creation_input_tokens,
514 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
515 }
516 }
517}
518
519impl Sub<TokenUsage> for TokenUsage {
520 type Output = Self;
521
522 fn sub(self, other: Self) -> Self {
523 Self {
524 input_tokens: self.input_tokens - other.input_tokens,
525 output_tokens: self.output_tokens - other.output_tokens,
526 cache_creation_input_tokens: self.cache_creation_input_tokens
527 - other.cache_creation_input_tokens,
528 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
529 }
530 }
531}
532
/// Unique identifier for a single tool invocation by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Allow construction from anything convertible to `Arc<str>` (&str, String, ...).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
550
/// A tool invocation requested by a model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier correlating this tool use with its eventual result.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The tool input as raw text from the model.
    pub raw_input: String,
    /// The tool input parsed as JSON.
    pub input: serde_json::Value,
    /// Whether the input has finished arriving (false while streaming).
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

/// A text-only view over a completion stream.
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    /// Stream of text chunks (non-text events are filtered out).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
569
570impl Default for LanguageModelTextStream {
571 fn default() -> Self {
572 Self {
573 message_id: None,
574 stream: Box::pin(futures::stream::empty()),
575 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
576 }
577 }
578}
579
/// A thinking-effort option a model exposes.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Human-readable label for this effort level.
    pub name: SharedString,
    /// The underlying value identifying this effort level.
    pub value: SharedString,
    /// Whether this is the model's default effort level.
    pub is_default: bool,
}
586
/// A handle to a specific language model exposed by a provider.
pub trait LanguageModel: Send + Sync {
    /// Stable identifier for this model within its provider.
    fn id(&self) -> LanguageModelId;
    /// Human-readable model name.
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Display name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Identifier of the upstream provider when requests are proxied;
    /// defaults to the provider itself.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Display name of the upstream provider; defaults to the provider itself.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Identifier used when reporting telemetry about this model.
    fn telemetry_id(&self) -> String;

    /// The API key this model would use, if any is configured.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls;
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// The schema dialect tool definitions should use for this model.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum size of this model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, when one is known.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Streams a completion as plain text, filtering out non-text events.
    ///
    /// `UsageUpdate` events are folded into `last_token_usage`, so it holds
    /// the final counts once the stream finishes.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event so the message id (when sent first)
            // can be surfaced on the returned stream.
            // NOTE(review): a first event that is an `Err` or `UsageUpdate`
            // is silently discarded here — confirm this is intended.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend any text consumed while peeking, then keep only
            // text chunks and errors; record usage updates as a side effect.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Runs a completion and returns the first complete tool use the model
    /// produces, or an error if the stream ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration for this model, when it supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Test-only downcast to the fake model implementation.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
782
783impl std::fmt::Debug for dyn LanguageModel {
784 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
785 f.debug_struct("<dyn LanguageModel>")
786 .field("id", &self.id())
787 .field("name", &self.name())
788 .field("provider_id", &self.provider_id())
789 .field("provider_name", &self.provider_name())
790 .field("upstream_provider_name", &self.upstream_provider_name())
791 .field("upstream_provider_id", &self.upstream_provider_id())
792 .field("upstream_provider_id", &self.upstream_provider_id())
793 .field("supports_streaming_tools", &self.supports_streaming_tools())
794 .finish()
795 }
796}
797
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}

impl Default for IconOrSvg {
    /// Defaults to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
823
/// A source of language models (e.g. a vendor API or local runtime).
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Display name for this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default fast (lower-latency) model, if available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models the provider recommends surfacing prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether valid credentials are currently available.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for the provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
846
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// A named external agent.
    Other(SharedString),
}

/// Where a provider's terms-of-service notice is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Shown in the text-thread popup.
    TextThreadPopup,
    /// Shown in the provider configuration UI.
    Configuration,
}
863
/// Observable state backing a provider, for UI subscriptions.
pub trait LanguageModelProviderState: 'static {
    /// The gpui entity type that holds this provider's observable state.
    type ObservableEntity;

    /// The entity to observe, if this provider exposes one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the provider's observable entity changes.
    /// Returns `None` when the provider has no observable entity.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
880
/// Identifier for a specific language model, unique within a provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable identifier of a language model provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Display name of a language model provider (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
892
impl LanguageModelProviderId {
    /// Creates a provider id from a static string, usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
898
899impl LanguageModelProviderName {
900 pub const fn new(id: &'static str) -> Self {
901 Self(SharedString::new_static(id))
902 }
903}
904
// `Display` delegates to the inner `SharedString` for both provider newtypes.
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Convenience conversions from owned strings for all id/name newtypes.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

// Conversions from `Arc<str>` for the provider newtypes.
impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
952
#[cfg(test)]
mod tests {
    // Tests covering cloud-failure error mapping and `LanguageModelToolUse`
    // serde round-trips (in particular the optional `thought_signature`).
    use super::*;

    // JSON-encoded `upstream_http_error` payloads map by their embedded
    // `upstream_status` (503 -> ServerOverloaded, 500 -> ApiInternalServerError).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Codes of the form `upstream_http_<status>` carry the status directly.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout payloads keep the inner message and map by status.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // A payload without `thought_signature` must still deserialize (None).
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}