1use ai_onboarding::YoungAccountBanner;
2use anthropic::{
3 AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, ToolResultPart,
4 Usage,
5};
6use anyhow::{Context as _, Result, anyhow};
7use chrono::{DateTime, Utc};
8use client::{Client, ModelRequestUsage, UserStore, zed_urls};
9use cloud_llm_client::{
10 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
11 CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
12 CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
13 MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
14 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
15 TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
16};
17use futures::{
18 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
19};
20use google_ai::GoogleModelMode;
21use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
22use http_client::http::{HeaderMap, HeaderValue};
23use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
24use language_model::{
25 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
26 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
27 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
28 LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
29 LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
30 LanguageModelToolUseId, LlmApiToken, MessageContent, ModelRequestLimitReachedError,
31 PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, Role, StopReason,
32};
33use release_channel::AppVersion;
34use schemars::JsonSchema;
35use semver::Version;
36use serde::{Deserialize, Serialize, de::DeserializeOwned};
37use settings::SettingsStore;
38pub use settings::ZedDotDevAvailableModel as AvailableModel;
39pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
40use smol::io::{AsyncReadExt, BufReader};
41use std::pin::Pin;
42use std::str::FromStr as _;
43use std::sync::Arc;
44use std::time::Duration;
45use thiserror::Error;
46use ui::{TintColor, prelude::*};
47use util::{ResultExt as _, maybe};
48
49use crate::provider::google::{GoogleEventMapper, into_google};
50use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
51use crate::provider::x_ai::count_xai_tokens;
52
53const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
54const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
55
56#[derive(Default, Clone, Debug, PartialEq)]
57pub struct ZedDotDevSettings {
58 pub available_models: Vec<AvailableModel>,
59}
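
/// Whether a model should run in its default mode or with extended "thinking".
///
/// A rough sketch of how a thinking configuration serializes given the `serde`
/// attributes below (illustrative only):
///
/// ```json
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```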
60#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
61#[serde(tag = "type", rename_all = "lowercase")]
62pub enum ModelMode {
63 #[default]
64 Default,
65 Thinking {
66 /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
67 budget_tokens: Option<u32>,
68 },
69}
70
71impl From<ModelMode> for AnthropicModelMode {
72 fn from(value: ModelMode) -> Self {
73 match value {
74 ModelMode::Default => AnthropicModelMode::Default,
75 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
76 }
77 }
78}
79
80pub struct CloudLanguageModelProvider {
81 client: Arc<Client>,
82 state: Entity<State>,
83 _maintain_client_status: Task<()>,
84}
85
86pub struct State {
87 client: Arc<Client>,
88 llm_api_token: LlmApiToken,
89 user_store: Entity<UserStore>,
90 status: client::Status,
91 models: Vec<Arc<cloud_llm_client::LanguageModel>>,
92 default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
93 default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
94 recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
95 _fetch_models_task: Task<()>,
96 _settings_subscription: Subscription,
97 _llm_token_subscription: Subscription,
98}
99
100impl State {
101 fn new(
102 client: Arc<Client>,
103 user_store: Entity<UserStore>,
104 status: client::Status,
105 cx: &mut Context<Self>,
106 ) -> Self {
107 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
108 let mut current_user = user_store.read(cx).watch_current_user();
109 Self {
110 client: client.clone(),
111 llm_api_token: LlmApiToken::default(),
112 user_store,
113 status,
114 models: Vec::new(),
115 default_model: None,
116 default_fast_model: None,
117 recommended_models: Vec::new(),
118 _fetch_models_task: cx.spawn(async move |this, cx| {
119 maybe!(async move {
120 let (client, llm_api_token) = this
121 .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
122
123 while current_user.borrow().is_none() {
124 current_user.next().await;
125 }
126
127 let response =
128 Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
129 this.update(cx, |this, cx| this.update_models(response, cx))?;
130 anyhow::Ok(())
131 })
132 .await
133 .context("failed to fetch Zed models")
134 .log_err();
135 }),
136 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
137 cx.notify();
138 }),
139 _llm_token_subscription: cx.subscribe(
140 &refresh_llm_token_listener,
141 move |this, _listener, _event, cx| {
142 let client = this.client.clone();
143 let llm_api_token = this.llm_api_token.clone();
144 cx.spawn(async move |this, cx| {
145 llm_api_token.refresh(&client).await?;
146 let response = Self::fetch_models(client, llm_api_token).await?;
147 this.update(cx, |this, cx| {
148 this.update_models(response, cx);
149 })
150 })
151 .detach_and_log_err(cx);
152 },
153 ),
154 }
155 }
156
157 fn is_signed_out(&self, cx: &App) -> bool {
158 self.user_store.read(cx).current_user().is_none()
159 }
160
161 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
162 let client = self.client.clone();
163 cx.spawn(async move |state, cx| {
164 client.sign_in_with_optional_connect(true, cx).await?;
165 state.update(cx, |_, cx| cx.notify())
166 })
167 }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
169 let mut models = Vec::new();
170
171 for model in response.models {
172 models.push(Arc::new(model.clone()));
173
174 // Right now we represent thinking variants of models as separate models on the client,
175 // so we need to insert variants for any model that supports thinking.
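            // For example, a hypothetical model with id "some-model" would also be
            // exposed as "some-model-thinking" with display name "Some Model Thinking".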
176 if model.supports_thinking {
177 models.push(Arc::new(cloud_llm_client::LanguageModel {
178 id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
179 display_name: format!("{} Thinking", model.display_name),
180 ..model
181 }));
182 }
183 }
184
185 self.default_model = models
186 .iter()
187 .find(|model| {
188 response
189 .default_model
190 .as_ref()
191 .is_some_and(|default_model_id| &model.id == default_model_id)
192 })
193 .cloned();
194 self.default_fast_model = models
195 .iter()
196 .find(|model| {
197 response
198 .default_fast_model
199 .as_ref()
200 .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
201 })
202 .cloned();
203 self.recommended_models = response
204 .recommended_models
205 .iter()
206 .filter_map(|id| models.iter().find(|model| &model.id == id))
207 .cloned()
208 .collect();
209 self.models = models;
210 cx.notify();
211 }
212
213 async fn fetch_models(
214 client: Arc<Client>,
215 llm_api_token: LlmApiToken,
216 ) -> Result<ListModelsResponse> {
217 let http_client = &client.http_client();
218 let token = llm_api_token.acquire(&client).await?;
219
220 let request = http_client::Request::builder()
221 .method(Method::GET)
222 .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
223 .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
224 .header("Authorization", format!("Bearer {token}"))
225 .body(AsyncBody::empty())?;
226 let mut response = http_client
227 .send(request)
228 .await
229 .context("failed to send list models request")?;
230
231 if response.status().is_success() {
232 let mut body = String::new();
233 response.body_mut().read_to_string(&mut body).await?;
234 Ok(serde_json::from_str(&body)?)
235 } else {
236 let mut body = String::new();
237 response.body_mut().read_to_string(&mut body).await?;
238 anyhow::bail!(
239 "error listing models.\nStatus: {:?}\nBody: {body}",
240 response.status(),
241 );
242 }
243 }
244}
245
246impl CloudLanguageModelProvider {
247 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
248 let mut status_rx = client.status();
249 let status = *status_rx.borrow();
250
251 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
252
253 let state_ref = state.downgrade();
254 let maintain_client_status = cx.spawn(async move |cx| {
255 while let Some(status) = status_rx.next().await {
256 if let Some(this) = state_ref.upgrade() {
257 _ = this.update(cx, |this, cx| {
258 if this.status != status {
259 this.status = status;
260 cx.notify();
261 }
262 });
263 } else {
264 break;
265 }
266 }
267 });
268
269 Self {
270 client,
271 state,
272 _maintain_client_status: maintain_client_status,
273 }
274 }
275
276 fn create_language_model(
277 &self,
278 model: Arc<cloud_llm_client::LanguageModel>,
279 llm_api_token: LlmApiToken,
280 ) -> Arc<dyn LanguageModel> {
281 Arc::new(CloudLanguageModel {
282 id: LanguageModelId(SharedString::from(model.id.0.clone())),
283 model,
284 llm_api_token,
285 client: self.client.clone(),
286 request_limiter: RateLimiter::new(4),
287 })
288 }
289}
290
291impl LanguageModelProviderState for CloudLanguageModelProvider {
292 type ObservableEntity = State;
293
294 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
295 Some(self.state.clone())
296 }
297}
298
299impl LanguageModelProvider for CloudLanguageModelProvider {
300 fn id(&self) -> LanguageModelProviderId {
301 PROVIDER_ID
302 }
303
304 fn name(&self) -> LanguageModelProviderName {
305 PROVIDER_NAME
306 }
307
308 fn icon(&self) -> IconName {
309 IconName::AiZed
310 }
311
312 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
313 let default_model = self.state.read(cx).default_model.clone()?;
314 let llm_api_token = self.state.read(cx).llm_api_token.clone();
315 Some(self.create_language_model(default_model, llm_api_token))
316 }
317
318 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
319 let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
320 let llm_api_token = self.state.read(cx).llm_api_token.clone();
321 Some(self.create_language_model(default_fast_model, llm_api_token))
322 }
323
324 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
325 let llm_api_token = self.state.read(cx).llm_api_token.clone();
326 self.state
327 .read(cx)
328 .recommended_models
329 .iter()
330 .cloned()
331 .map(|model| self.create_language_model(model, llm_api_token.clone()))
332 .collect()
333 }
334
335 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
336 let llm_api_token = self.state.read(cx).llm_api_token.clone();
337 self.state
338 .read(cx)
339 .models
340 .iter()
341 .cloned()
342 .map(|model| self.create_language_model(model, llm_api_token.clone()))
343 .collect()
344 }
345
346 fn is_authenticated(&self, cx: &App) -> bool {
347 let state = self.state.read(cx);
348 !state.is_signed_out(cx)
349 }
350
351 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
352 Task::ready(Ok(()))
353 }
354
355 fn configuration_view(
356 &self,
357 _target_agent: language_model::ConfigurationViewTargetAgent,
358 _: &mut Window,
359 cx: &mut App,
360 ) -> AnyView {
361 cx.new(|_| ConfigurationView::new(self.state.clone()))
362 .into()
363 }
364
365 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
366 Task::ready(Ok(()))
367 }
368}
369
370pub struct CloudLanguageModel {
371 id: LanguageModelId,
372 model: Arc<cloud_llm_client::LanguageModel>,
373 llm_api_token: LlmApiToken,
374 client: Arc<Client>,
375 request_limiter: RateLimiter,
376}
377
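/// A successful response from the cloud completions endpoint, together with
/// metadata parsed from its response headers.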
378struct PerformLlmCompletionResponse {
379 response: Response<AsyncBody>,
380 usage: Option<ModelRequestUsage>,
381 tool_use_limit_reached: bool,
382 includes_status_messages: bool,
383}
384
385impl CloudLanguageModel {
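    /// Sends a completion request to the cloud LLM endpoint, refreshing the LLM
    /// token once if the server reports it as expired, and converting plan-limit
    /// and payment failures into their dedicated error types.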
386 async fn perform_llm_completion(
387 client: Arc<Client>,
388 llm_api_token: LlmApiToken,
389 app_version: Option<Version>,
390 body: CompletionBody,
391 ) -> Result<PerformLlmCompletionResponse> {
392 let http_client = &client.http_client();
393
394 let mut token = llm_api_token.acquire(&client).await?;
395 let mut refreshed_token = false;
396
397 loop {
398 let request = http_client::Request::builder()
399 .method(Method::POST)
400 .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
401 .when_some(app_version.as_ref(), |builder, app_version| {
402 builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
403 })
404 .header("Content-Type", "application/json")
405 .header("Authorization", format!("Bearer {token}"))
406 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
407 .body(serde_json::to_string(&body)?.into())?;
408
409 let mut response = http_client.send(request).await?;
410 let status = response.status();
411 if status.is_success() {
412 let includes_status_messages = response
413 .headers()
414 .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
415 .is_some();
416
417 let tool_use_limit_reached = response
418 .headers()
419 .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
420 .is_some();
421
422 let usage = if includes_status_messages {
423 None
424 } else {
425 ModelRequestUsage::from_headers(response.headers()).ok()
426 };
427
428 return Ok(PerformLlmCompletionResponse {
429 response,
430 usage,
431 includes_status_messages,
432 tool_use_limit_reached,
433 });
434 }
435
436 if !refreshed_token
437 && response
438 .headers()
439 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
440 .is_some()
441 {
442 token = llm_api_token.refresh(&client).await?;
443 refreshed_token = true;
444 continue;
445 }
446
447 if status == StatusCode::FORBIDDEN
448 && response
449 .headers()
450 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
451 .is_some()
452 {
453 if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
454 .headers()
455 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
456 .and_then(|resource| resource.to_str().ok())
457 && let Some(plan) = response
458 .headers()
459 .get(CURRENT_PLAN_HEADER_NAME)
460 .and_then(|plan| plan.to_str().ok())
461 .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
462 .map(Plan::V1)
463 {
464 return Err(anyhow!(ModelRequestLimitReachedError { plan }));
465 }
466 } else if status == StatusCode::PAYMENT_REQUIRED {
467 return Err(anyhow!(PaymentRequiredError));
468 }
469
470 let mut body = String::new();
471 let headers = response.headers().clone();
472 response.body_mut().read_to_string(&mut body).await?;
473 return Err(anyhow!(ApiError {
474 status,
475 body,
476 headers
477 }));
478 }
479 }
480}
481
482#[derive(Debug, Error)]
483#[error("cloud language model request failed with status {status}: {body}")]
484struct ApiError {
485 status: StatusCode,
486 body: String,
487 headers: HeaderMap<HeaderValue>,
488}
489
490/// Represents error responses from Zed's cloud API.
491///
492/// Example JSON for an upstream HTTP error:
493/// ```json
494/// {
495/// "code": "upstream_http_error",
496/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
497/// "upstream_status": 503
498/// }
499/// ```
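///
/// The upstream status may instead be embedded in the error code itself, e.g.
/// (illustrative, mirroring the test case below):
/// ```json
/// {
///     "code": "upstream_http_429",
///     "message": "Upstream Anthropic rate limit exceeded.",
///     "retry_after": 30.5
/// }
/// ```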
500#[derive(Debug, serde::Deserialize)]
501struct CloudApiError {
502 code: String,
503 message: String,
504 #[serde(default)]
505 #[serde(deserialize_with = "deserialize_optional_status_code")]
506 upstream_status: Option<StatusCode>,
507 #[serde(default)]
508 retry_after: Option<f64>,
509}
510
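/// Deserializes an optional numeric HTTP status code, yielding `None` (rather
/// than an error) when the value is absent or not a valid status code.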
511fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
512where
513 D: serde::Deserializer<'de>,
514{
515 let opt: Option<u16> = Option::deserialize(deserializer)?;
516 Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
517}
518
519impl From<ApiError> for LanguageModelCompletionError {
520 fn from(error: ApiError) -> Self {
521 if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
522 if cloud_error.code.starts_with("upstream_http_") {
523 let status = if let Some(status) = cloud_error.upstream_status {
524 status
525 } else if cloud_error.code.ends_with("_error") {
526 error.status
527 } else {
                    // If the code string itself embeds a status code (e.g. "upstream_http_429"),
                    // use that; otherwise fall back to the response's own status.
530 cloud_error
531 .code
532 .strip_prefix("upstream_http_")
533 .and_then(|code_str| code_str.parse::<u16>().ok())
534 .and_then(|code| StatusCode::from_u16(code).ok())
535 .unwrap_or(error.status)
536 };
537
538 return LanguageModelCompletionError::UpstreamProviderError {
539 message: cloud_error.message,
540 status,
541 retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
542 };
543 }
544
545 return LanguageModelCompletionError::from_http_status(
546 PROVIDER_NAME,
547 error.status,
548 cloud_error.message,
549 None,
550 );
551 }
552
553 let retry_after = None;
554 LanguageModelCompletionError::from_http_status(
555 PROVIDER_NAME,
556 error.status,
557 error.body,
558 retry_after,
559 )
560 }
561}
562
563impl LanguageModel for CloudLanguageModel {
564 fn id(&self) -> LanguageModelId {
565 self.id.clone()
566 }
567
568 fn name(&self) -> LanguageModelName {
569 LanguageModelName::from(self.model.display_name.clone())
570 }
571
572 fn provider_id(&self) -> LanguageModelProviderId {
573 PROVIDER_ID
574 }
575
576 fn provider_name(&self) -> LanguageModelProviderName {
577 PROVIDER_NAME
578 }
579
580 fn upstream_provider_id(&self) -> LanguageModelProviderId {
581 use cloud_llm_client::LanguageModelProvider::*;
582 match self.model.provider {
583 Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
584 OpenAi => language_model::OPEN_AI_PROVIDER_ID,
585 Google => language_model::GOOGLE_PROVIDER_ID,
586 XAi => language_model::X_AI_PROVIDER_ID,
587 }
588 }
589
590 fn upstream_provider_name(&self) -> LanguageModelProviderName {
591 use cloud_llm_client::LanguageModelProvider::*;
592 match self.model.provider {
593 Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
594 OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
595 Google => language_model::GOOGLE_PROVIDER_NAME,
596 XAi => language_model::X_AI_PROVIDER_NAME,
597 }
598 }
599
600 fn supports_tools(&self) -> bool {
601 self.model.supports_tools
602 }
603
604 fn supports_images(&self) -> bool {
605 self.model.supports_images
606 }
607
608 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
609 match choice {
610 LanguageModelToolChoice::Auto
611 | LanguageModelToolChoice::Any
612 | LanguageModelToolChoice::None => true,
613 }
614 }
615
616 fn supports_burn_mode(&self) -> bool {
617 self.model.supports_max_mode
618 }
619
620 fn telemetry_id(&self) -> String {
621 format!("zed.dev/{}", self.model.id)
622 }
623
624 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
625 match self.model.provider {
626 cloud_llm_client::LanguageModelProvider::Anthropic
627 | cloud_llm_client::LanguageModelProvider::OpenAi
628 | cloud_llm_client::LanguageModelProvider::XAi => {
629 LanguageModelToolSchemaFormat::JsonSchema
630 }
631 cloud_llm_client::LanguageModelProvider::Google => {
632 LanguageModelToolSchemaFormat::JsonSchemaSubset
633 }
634 }
635 }
636
637 fn max_token_count(&self) -> u64 {
638 self.model.max_token_count as u64
639 }
640
641 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
642 self.model
643 .max_token_count_in_max_mode
644 .filter(|_| self.model.supports_max_mode)
645 .map(|max_token_count| max_token_count as u64)
646 }
647
648 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
649 match &self.model.provider {
650 cloud_llm_client::LanguageModelProvider::Anthropic => {
651 Some(LanguageModelCacheConfiguration {
652 min_total_token: 2_048,
653 should_speculate: true,
654 max_cache_anchors: 4,
655 })
656 }
657 cloud_llm_client::LanguageModelProvider::OpenAi
658 | cloud_llm_client::LanguageModelProvider::XAi
659 | cloud_llm_client::LanguageModelProvider::Google => None,
660 }
661 }
662
663 fn count_tokens(
664 &self,
665 request: LanguageModelRequest,
666 cx: &App,
667 ) -> BoxFuture<'static, Result<u64>> {
668 match self.model.provider {
669 cloud_llm_client::LanguageModelProvider::Anthropic => {
670 count_anthropic_tokens(request, cx)
671 }
672 cloud_llm_client::LanguageModelProvider::OpenAi => {
673 let model = match open_ai::Model::from_id(&self.model.id.0) {
674 Ok(model) => model,
675 Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
676 };
677 count_open_ai_tokens(request, model, cx)
678 }
679 cloud_llm_client::LanguageModelProvider::XAi => {
680 let model = match x_ai::Model::from_id(&self.model.id.0) {
681 Ok(model) => model,
682 Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
683 };
684 count_xai_tokens(request, model, cx)
685 }
686 cloud_llm_client::LanguageModelProvider::Google => {
687 let client = self.client.clone();
688 let llm_api_token = self.llm_api_token.clone();
689 let model_id = self.model.id.to_string();
690 let generate_content_request =
691 into_google(request, model_id.clone(), GoogleModelMode::Default);
692 async move {
693 let http_client = &client.http_client();
694 let token = llm_api_token.acquire(&client).await?;
695
696 let request_body = CountTokensBody {
697 provider: cloud_llm_client::LanguageModelProvider::Google,
698 model: model_id,
699 provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
700 generate_content_request,
701 })?,
702 };
703 let request = http_client::Request::builder()
704 .method(Method::POST)
705 .uri(
706 http_client
707 .build_zed_llm_url("/count_tokens", &[])?
708 .as_ref(),
709 )
710 .header("Content-Type", "application/json")
711 .header("Authorization", format!("Bearer {token}"))
712 .body(serde_json::to_string(&request_body)?.into())?;
713 let mut response = http_client.send(request).await?;
714 let status = response.status();
715 let headers = response.headers().clone();
716 let mut response_body = String::new();
717 response
718 .body_mut()
719 .read_to_string(&mut response_body)
720 .await?;
721
722 if status.is_success() {
723 let response_body: CountTokensResponse =
724 serde_json::from_str(&response_body)?;
725
726 Ok(response_body.tokens as u64)
727 } else {
728 Err(anyhow!(ApiError {
729 status,
730 body: response_body,
731 headers
732 }))
733 }
734 }
735 .boxed()
736 }
737 }
738 }
739
740 fn stream_completion(
741 &self,
742 request: LanguageModelRequest,
743 cx: &AsyncApp,
744 ) -> BoxFuture<
745 'static,
746 Result<
747 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
748 LanguageModelCompletionError,
749 >,
750 > {
751 let thread_id = request.thread_id.clone();
752 let prompt_id = request.prompt_id.clone();
753 let intent = request.intent;
754 let mode = request.mode;
755 let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
756 let thinking_allowed = request.thinking_allowed;
757 let provider_name = provider_name(&self.model.provider);
758 match self.model.provider {
759 cloud_llm_client::LanguageModelProvider::Anthropic => {
760 let request = into_anthropic(
761 request,
762 self.model.id.to_string(),
763 1.0,
764 self.model.max_output_tokens as u64,
765 if thinking_allowed && self.model.id.0.ends_with("-thinking") {
766 AnthropicModelMode::Thinking {
767 budget_tokens: Some(4_096),
768 }
769 } else {
770 AnthropicModelMode::Default
771 },
772 );
773 let client = self.client.clone();
774 let llm_api_token = self.llm_api_token.clone();
775 let future = self.request_limiter.stream(async move {
776 let PerformLlmCompletionResponse {
777 response,
778 usage,
779 includes_status_messages,
780 tool_use_limit_reached,
781 } = Self::perform_llm_completion(
782 client.clone(),
783 llm_api_token,
784 app_version,
785 CompletionBody {
786 thread_id,
787 prompt_id,
788 intent,
789 mode,
790 provider: cloud_llm_client::LanguageModelProvider::Anthropic,
791 model: request.model.clone(),
792 provider_request: serde_json::to_value(&request)
793 .map_err(|e| anyhow!(e))?,
794 },
795 )
796 .await
797 .map_err(|err| match err.downcast::<ApiError>() {
798 Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
799 Err(err) => anyhow!(err),
800 })?;
801
802 let mut mapper = AnthropicEventMapper::new();
803 Ok(map_cloud_completion_events(
804 Box::pin(
805 response_lines(response, includes_status_messages)
806 .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
808 ),
809 &provider_name,
810 move |event| mapper.map_event(event),
811 ))
812 });
813 async move { Ok(future.await?.boxed()) }.boxed()
814 }
815 cloud_llm_client::LanguageModelProvider::OpenAi => {
816 let client = self.client.clone();
817 let request = into_open_ai(
818 request,
819 &self.model.id.0,
820 self.model.supports_parallel_tool_calls,
821 true,
822 None,
823 None,
824 );
825 let llm_api_token = self.llm_api_token.clone();
826 let future = self.request_limiter.stream(async move {
827 let PerformLlmCompletionResponse {
828 response,
829 usage,
830 includes_status_messages,
831 tool_use_limit_reached,
832 } = Self::perform_llm_completion(
833 client.clone(),
834 llm_api_token,
835 app_version,
836 CompletionBody {
837 thread_id,
838 prompt_id,
839 intent,
840 mode,
841 provider: cloud_llm_client::LanguageModelProvider::OpenAi,
842 model: request.model.clone(),
843 provider_request: serde_json::to_value(&request)
844 .map_err(|e| anyhow!(e))?,
845 },
846 )
847 .await?;
848
849 let mut mapper = OpenAiEventMapper::new();
850 Ok(map_cloud_completion_events(
851 Box::pin(
852 response_lines(response, includes_status_messages)
853 .chain(usage_updated_event(usage))
854 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
855 ),
856 &provider_name,
857 move |event| mapper.map_event(event),
858 ))
859 });
860 async move { Ok(future.await?.boxed()) }.boxed()
861 }
862 cloud_llm_client::LanguageModelProvider::XAi => {
863 let client = self.client.clone();
864 let request = into_open_ai(
865 request,
866 &self.model.id.0,
867 self.model.supports_parallel_tool_calls,
868 false,
869 None,
870 None,
871 );
872 let llm_api_token = self.llm_api_token.clone();
873 let future = self.request_limiter.stream(async move {
874 let PerformLlmCompletionResponse {
875 response,
876 usage,
877 includes_status_messages,
878 tool_use_limit_reached,
879 } = Self::perform_llm_completion(
880 client.clone(),
881 llm_api_token,
882 app_version,
883 CompletionBody {
884 thread_id,
885 prompt_id,
886 intent,
887 mode,
888 provider: cloud_llm_client::LanguageModelProvider::XAi,
889 model: request.model.clone(),
890 provider_request: serde_json::to_value(&request)
891 .map_err(|e| anyhow!(e))?,
892 },
893 )
894 .await?;
895
896 let mut mapper = OpenAiEventMapper::new();
897 Ok(map_cloud_completion_events(
898 Box::pin(
899 response_lines(response, includes_status_messages)
900 .chain(usage_updated_event(usage))
901 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
902 ),
903 &provider_name,
904 move |event| mapper.map_event(event),
905 ))
906 });
907 async move { Ok(future.await?.boxed()) }.boxed()
908 }
909 cloud_llm_client::LanguageModelProvider::Google => {
910 let client = self.client.clone();
911 let request =
912 into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
913 let llm_api_token = self.llm_api_token.clone();
914 let future = self.request_limiter.stream(async move {
915 let PerformLlmCompletionResponse {
916 response,
917 usage,
918 includes_status_messages,
919 tool_use_limit_reached,
920 } = Self::perform_llm_completion(
921 client.clone(),
922 llm_api_token,
923 app_version,
924 CompletionBody {
925 thread_id,
926 prompt_id,
927 intent,
928 mode,
929 provider: cloud_llm_client::LanguageModelProvider::Google,
930 model: request.model.model_id.clone(),
931 provider_request: serde_json::to_value(&request)
932 .map_err(|e| anyhow!(e))?,
933 },
934 )
935 .await?;
936
937 let mut mapper = GoogleEventMapper::new();
938 Ok(map_cloud_completion_events(
939 Box::pin(
940 response_lines(response, includes_status_messages)
941 .chain(usage_updated_event(usage))
942 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
943 ),
944 &provider_name,
945 move |event| mapper.map_event(event),
946 ))
947 });
948 async move { Ok(future.await?.boxed()) }.boxed()
949 }
950 }
951 }
952}
953
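/// Adapts a raw cloud completion stream into `LanguageModelCompletionEvent`s:
/// status events are handled generically, while provider-specific events are
/// forwarded to `map_callback` (e.g. an `AnthropicEventMapper`).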
954fn map_cloud_completion_events<T, F>(
955 stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
956 provider: &LanguageModelProviderName,
957 mut map_callback: F,
958) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
959where
960 T: DeserializeOwned + 'static,
961 F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
962 + Send
963 + 'static,
964{
965 let provider = provider.clone();
966 stream
967 .flat_map(move |event| {
968 futures::stream::iter(match event {
969 Err(error) => {
970 vec![Err(LanguageModelCompletionError::from(error))]
971 }
972 Ok(CompletionEvent::Status(event)) => {
973 vec![
974 LanguageModelCompletionEvent::from_completion_request_status(
975 event,
976 provider.clone(),
977 ),
978 ]
979 }
980 Ok(CompletionEvent::Event(event)) => map_callback(event),
981 })
982 })
983 .boxed()
984}
985
986fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
987 match provider {
988 cloud_llm_client::LanguageModelProvider::Anthropic => {
989 language_model::ANTHROPIC_PROVIDER_NAME
990 }
991 cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
992 cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
993 cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
994 }
995}
996
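/// Emits a single `UsageUpdated` status event when usage information was parsed
/// from the response headers, and nothing otherwise.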
997fn usage_updated_event<T>(
998 usage: Option<ModelRequestUsage>,
999) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1000 futures::stream::iter(usage.map(|usage| {
1001 Ok(CompletionEvent::Status(
1002 CompletionRequestStatus::UsageUpdated {
1003 amount: usage.amount as usize,
1004 limit: usage.limit,
1005 },
1006 ))
1007 }))
1008}
1009
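/// Emits a single `ToolUseLimitReached` status event when the corresponding
/// response header was present, and nothing otherwise.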
1010fn tool_use_limit_reached_event<T>(
1011 tool_use_limit_reached: bool,
1012) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1013 futures::stream::iter(tool_use_limit_reached.then(|| {
1014 Ok(CompletionEvent::Status(
1015 CompletionRequestStatus::ToolUseLimitReached,
1016 ))
1017 }))
1018}
1019
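/// Reads the streaming response body as newline-delimited JSON, yielding one
/// `CompletionEvent` per line. If the server did not advertise status-message
/// support, each line is parsed as a bare provider event instead.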
1020fn response_lines<T: DeserializeOwned>(
1021 response: Response<AsyncBody>,
1022 includes_status_messages: bool,
1023) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1024 futures::stream::try_unfold(
1025 (String::new(), BufReader::new(response.into_body())),
1026 move |(mut line, mut body)| async move {
1027 match body.read_line(&mut line).await {
1028 Ok(0) => Ok(None),
1029 Ok(_) => {
1030 let event = if includes_status_messages {
1031 serde_json::from_str::<CompletionEvent<T>>(&line)?
1032 } else {
1033 CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1034 };
1035
1036 line.clear();
1037 Ok(Some((event, (line, body))))
1038 }
1039 Err(e) => Err(e.into()),
1040 }
1041 },
1042 )
1043}
1044
1045#[derive(IntoElement, RegisterComponent)]
1046struct ZedAiConfiguration {
1047 is_connected: bool,
1048 plan: Option<Plan>,
1049 subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
1050 eligible_for_trial: bool,
1051 account_too_young: bool,
1052 sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
1053}
1054
1055impl RenderOnce for ZedAiConfiguration {
1056 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1057 let is_pro = self.plan.is_some_and(|plan| {
1058 matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
1059 });
1060 let subscription_text = match (self.plan, self.subscription_period) {
1061 (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
1062 "You have access to Zed's hosted models through your Pro subscription."
1063 }
1064 (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
1065 "You have access to Zed's hosted models through your Pro trial."
1066 }
1067 (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
1068 "You have basic access to Zed's hosted models through the Free plan."
1069 }
1070 (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
1071 if self.eligible_for_trial {
1072 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1073 } else {
1074 "Subscribe for access to Zed's hosted models."
1075 }
1076 }
1077 _ => {
1078 if self.eligible_for_trial {
1079 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1080 } else {
1081 "Subscribe for access to Zed's hosted models."
1082 }
1083 }
1084 };
1085
1086 let manage_subscription_buttons = if is_pro {
1087 Button::new("manage_settings", "Manage Subscription")
1088 .full_width()
1089 .style(ButtonStyle::Tinted(TintColor::Accent))
1090 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1091 .into_any_element()
1092 } else if self.plan.is_none() || self.eligible_for_trial {
1093 Button::new("start_trial", "Start 14-day Free Pro Trial")
1094 .full_width()
1095 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1096 .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1097 .into_any_element()
1098 } else {
1099 Button::new("upgrade", "Upgrade to Pro")
1100 .full_width()
1101 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1102 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1103 .into_any_element()
1104 };
1105
1106 if !self.is_connected {
1107 return v_flex()
1108 .gap_2()
1109 .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1110 .child(
1111 Button::new("sign_in", "Sign In to use Zed AI")
1112 .icon_color(Color::Muted)
1113 .icon(IconName::Github)
1114 .icon_size(IconSize::Small)
1115 .icon_position(IconPosition::Start)
1116 .full_width()
1117 .on_click({
1118 let callback = self.sign_in_callback.clone();
1119 move |_, window, cx| (callback)(window, cx)
1120 }),
1121 );
1122 }
1123
1124 v_flex().gap_2().w_full().map(|this| {
1125 if self.account_too_young {
1126 this.child(YoungAccountBanner).child(
1127 Button::new("upgrade", "Upgrade to Pro")
1128 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1129 .full_width()
1130 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1131 )
1132 } else {
1133 this.text_sm()
1134 .child(subscription_text)
1135 .child(manage_subscription_buttons)
1136 }
1137 })
1138 }
1139}
1140
1141struct ConfigurationView {
1142 state: Entity<State>,
1143 sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
1144}
1145
1146impl ConfigurationView {
1147 fn new(state: Entity<State>) -> Self {
1148 let sign_in_callback = Arc::new({
1149 let state = state.clone();
1150 move |_window: &mut Window, cx: &mut App| {
1151 state.update(cx, |state, cx| {
1152 state.authenticate(cx).detach_and_log_err(cx);
1153 });
1154 }
1155 });
1156
1157 Self {
1158 state,
1159 sign_in_callback,
1160 }
1161 }
1162}
1163
1164impl Render for ConfigurationView {
1165 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1166 let state = self.state.read(cx);
1167 let user_store = state.user_store.read(cx);
1168
1169 ZedAiConfiguration {
1170 is_connected: !state.is_signed_out(cx),
1171 plan: user_store.plan(),
1172 subscription_period: user_store.subscription_period(),
1173 eligible_for_trial: user_store.trial_started_at().is_none(),
1174 account_too_young: user_store.account_too_young(),
1175 sign_in_callback: self.sign_in_callback.clone(),
1176 }
1177 }
1178}
1179
1180impl Component for ZedAiConfiguration {
1181 fn name() -> &'static str {
1182 "AI Configuration Content"
1183 }
1184
1185 fn sort_name() -> &'static str {
1186 "AI Configuration Content"
1187 }
1188
1189 fn scope() -> ComponentScope {
1190 ComponentScope::Onboarding
1191 }
1192
1193 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1194 fn configuration(
1195 is_connected: bool,
1196 plan: Option<Plan>,
1197 eligible_for_trial: bool,
1198 account_too_young: bool,
1199 ) -> AnyElement {
1200 ZedAiConfiguration {
1201 is_connected,
1202 plan,
1203 subscription_period: plan
1204 .is_some()
1205 .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1206 eligible_for_trial,
1207 account_too_young,
1208 sign_in_callback: Arc::new(|_, _| {}),
1209 }
1210 .into_any_element()
1211 }
1212
1213 Some(
1214 v_flex()
1215 .p_4()
1216 .gap_4()
1217 .children(vec![
1218 single_example("Not connected", configuration(false, None, false, false)),
1223 single_example(
1224 "No Plan - Not eligible for trial",
1225 configuration(true, None, false, false),
1226 ),
1227 single_example(
1228 "No Plan - Eligible for trial",
1229 configuration(true, None, true, false),
1230 ),
1231 single_example(
1232 "Free Plan",
1233 configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
1234 ),
1235 single_example(
1236 "Zed Pro Trial Plan",
1237 configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
1238 ),
1239 single_example(
1240 "Zed Pro Plan",
1241 configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
1242 ),
1243 ])
1244 .into_any_element(),
1245 )
1246 }
1247}
1248
1249#[cfg(test)]
1250mod tests {
1251 use super::*;
1252 use http_client::http::{HeaderMap, StatusCode};
1253 use language_model::LanguageModelCompletionError;
1254
1255 #[test]
1256 fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should surface as UpstreamProviderError
1258 let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
1259
1260 let api_error = ApiError {
1261 status: StatusCode::INTERNAL_SERVER_ERROR,
1262 body: error_body.to_string(),
1263 headers: HeaderMap::new(),
1264 };
1265
1266 let completion_error: LanguageModelCompletionError = api_error.into();
1267
1268 match completion_error {
1269 LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
1270 assert_eq!(
1271 message,
1272 "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
1273 );
1274 }
1275 _ => panic!(
1276 "Expected UpstreamProviderError for upstream 503, got: {:?}",
1277 completion_error
1278 ),
1279 }
1280
        // upstream_http_error with an upstream 500 should also surface as UpstreamProviderError
1282 let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
1283
1284 let api_error = ApiError {
1285 status: StatusCode::INTERNAL_SERVER_ERROR,
1286 body: error_body.to_string(),
1287 headers: HeaderMap::new(),
1288 };
1289
1290 let completion_error: LanguageModelCompletionError = api_error.into();
1291
1292 match completion_error {
1293 LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
1294 assert_eq!(
1295 message,
1296 "Received an error from the OpenAI API: internal server error"
1297 );
1298 }
1299 _ => panic!(
1300 "Expected UpstreamProviderError for upstream 500, got: {:?}",
1301 completion_error
1302 ),
1303 }
1304
        // upstream_http_error with an upstream 429 should also surface as UpstreamProviderError
1306 let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
1307
1308 let api_error = ApiError {
1309 status: StatusCode::INTERNAL_SERVER_ERROR,
1310 body: error_body.to_string(),
1311 headers: HeaderMap::new(),
1312 };
1313
1314 let completion_error: LanguageModelCompletionError = api_error.into();
1315
1316 match completion_error {
1317 LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
1318 assert_eq!(
1319 message,
1320 "Received an error from the Google API: rate limit exceeded"
1321 );
1322 }
1323 _ => panic!(
1324 "Expected UpstreamProviderError for upstream 429, got: {:?}",
1325 completion_error
1326 ),
1327 }
1328
1329 // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
1330 let error_body = "Regular internal server error";
1331
1332 let api_error = ApiError {
1333 status: StatusCode::INTERNAL_SERVER_ERROR,
1334 body: error_body.to_string(),
1335 headers: HeaderMap::new(),
1336 };
1337
1338 let completion_error: LanguageModelCompletionError = api_error.into();
1339
1340 match completion_error {
1341 LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
1342 assert_eq!(provider, PROVIDER_NAME);
1343 assert_eq!(message, "Regular internal server error");
1344 }
1345 _ => panic!(
1346 "Expected ApiInternalServerError for regular 500, got: {:?}",
1347 completion_error
1348 ),
1349 }
1350
1351 // upstream_http_429 format should be converted to UpstreamProviderError
1352 let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
1353
1354 let api_error = ApiError {
1355 status: StatusCode::INTERNAL_SERVER_ERROR,
1356 body: error_body.to_string(),
1357 headers: HeaderMap::new(),
1358 };
1359
1360 let completion_error: LanguageModelCompletionError = api_error.into();
1361
1362 match completion_error {
1363 LanguageModelCompletionError::UpstreamProviderError {
1364 message,
1365 status,
1366 retry_after,
1367 } => {
1368 assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
1369 assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
1370 assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
1371 }
1372 _ => panic!(
1373 "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
1374 completion_error
1375 ),
1376 }
1377
1378 // Invalid JSON in error body should fall back to regular error handling
1379 let error_body = "Not JSON at all";
1380
1381 let api_error = ApiError {
1382 status: StatusCode::INTERNAL_SERVER_ERROR,
1383 body: error_body.to_string(),
1384 headers: HeaderMap::new(),
1385 };
1386
1387 let completion_error: LanguageModelCompletionError = api_error.into();
1388
1389 match completion_error {
1390 LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
1391 assert_eq!(provider, PROVIDER_NAME);
1392 }
1393 _ => panic!(
1394 "Expected ApiInternalServerError for invalid JSON, got: {:?}",
1395 completion_error
1396 ),
1397 }
1398 }
1399}
1400
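/// Estimates token usage for an Anthropic request. This is an approximation:
/// it reuses the OpenAI `gpt-4` tokenizer from `tiktoken_rs` for text and adds
/// a rough per-image estimate, rather than calling an Anthropic tokenizer.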
1401fn count_anthropic_tokens(
1402 request: LanguageModelRequest,
1403 cx: &App,
1404) -> BoxFuture<'static, Result<u64>> {
1405 use gpui::AppContext as _;
1406 cx.background_spawn(async move {
1407 let messages = request.messages;
1408 let mut tokens_from_images = 0;
1409 let mut string_messages = Vec::with_capacity(messages.len());
1410
1411 for message in messages {
1412 let mut string_contents = String::new();
1413
1414 for content in message.content {
1415 match content {
1416 MessageContent::Text(text) => {
1417 string_contents.push_str(&text);
1418 }
1419 MessageContent::Thinking { .. } => {}
1420 MessageContent::RedactedThinking(_) => {}
1421 MessageContent::Image(image) => {
1422 tokens_from_images += image.estimate_tokens();
1423 }
1424 MessageContent::ToolUse(_tool_use) => {}
1425 MessageContent::ToolResult(tool_result) => match &tool_result.content {
1426 LanguageModelToolResultContent::Text(text) => {
1427 string_contents.push_str(text);
1428 }
1429 LanguageModelToolResultContent::Image(image) => {
1430 tokens_from_images += image.estimate_tokens();
1431 }
1432 },
1433 }
1434 }
1435
1436 if !string_contents.is_empty() {
1437 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
1438 role: match message.role {
1439 Role::User => "user".into(),
1440 Role::Assistant => "assistant".into(),
1441 Role::System => "system".into(),
1442 },
1443 content: Some(string_contents),
1444 name: None,
1445 function_call: None,
1446 });
1447 }
1448 }
1449
1450 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
1451 .map(|tokens| (tokens + tokens_from_images) as u64)
1452 })
1453 .boxed()
1454}
1455
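/// Converts a `LanguageModelRequest` into an `anthropic::Request`. System
/// messages are concatenated into the `system` field, consecutive messages with
/// the same role are merged, and a cache-control marker is attached to the last
/// cacheable content block of any message flagged for caching.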
1456fn into_anthropic(
1457 request: LanguageModelRequest,
1458 model: String,
1459 default_temperature: f32,
1460 max_output_tokens: u64,
1461 mode: AnthropicModelMode,
1462) -> anthropic::Request {
1463 let mut new_messages: Vec<anthropic::Message> = Vec::new();
1464 let mut system_message = String::new();
1465
1466 for message in request.messages {
1467 if message.contents_empty() {
1468 continue;
1469 }
1470
1471 match message.role {
1472 Role::User | Role::Assistant => {
1473 let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
1474 .content
1475 .into_iter()
1476 .filter_map(|content| match content {
1477 MessageContent::Text(text) => {
1478 let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
1479 text.trim_end().to_string()
1480 } else {
1481 text
1482 };
1483 if !text.is_empty() {
1484 Some(anthropic::RequestContent::Text {
1485 text,
1486 cache_control: None,
1487 })
1488 } else {
1489 None
1490 }
1491 }
1492 MessageContent::Thinking {
1493 text: thinking,
1494 signature,
1495 } => {
1496 if !thinking.is_empty() {
1497 Some(anthropic::RequestContent::Thinking {
1498 thinking,
1499 signature: signature.unwrap_or_default(),
1500 cache_control: None,
1501 })
1502 } else {
1503 None
1504 }
1505 }
1506 MessageContent::RedactedThinking(data) => {
1507 if !data.is_empty() {
1508 Some(anthropic::RequestContent::RedactedThinking { data })
1509 } else {
1510 None
1511 }
1512 }
1513 MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
1514 source: anthropic::ImageSource {
1515 source_type: "base64".to_string(),
1516 media_type: "image/png".to_string(),
1517 data: image.source.to_string(),
1518 },
1519 cache_control: None,
1520 }),
1521 MessageContent::ToolUse(tool_use) => {
1522 Some(anthropic::RequestContent::ToolUse {
1523 id: tool_use.id.to_string(),
1524 name: tool_use.name.to_string(),
1525 input: tool_use.input,
1526 cache_control: None,
1527 })
1528 }
1529 MessageContent::ToolResult(tool_result) => {
1530 Some(anthropic::RequestContent::ToolResult {
1531 tool_use_id: tool_result.tool_use_id.to_string(),
1532 is_error: tool_result.is_error,
1533 content: match tool_result.content {
1534 LanguageModelToolResultContent::Text(text) => {
1535 ToolResultContent::Plain(text.to_string())
1536 }
1537 LanguageModelToolResultContent::Image(image) => {
1538 ToolResultContent::Multipart(vec![ToolResultPart::Image {
1539 source: anthropic::ImageSource {
1540 source_type: "base64".to_string(),
1541 media_type: "image/png".to_string(),
1542 data: image.source.to_string(),
1543 },
1544 }])
1545 }
1546 },
1547 cache_control: None,
1548 })
1549 }
1550 })
1551 .collect();
1552 let anthropic_role = match message.role {
1553 Role::User => anthropic::Role::User,
1554 Role::Assistant => anthropic::Role::Assistant,
1555 Role::System => unreachable!("System role should never occur here"),
1556 };
1557 if let Some(last_message) = new_messages.last_mut()
1558 && last_message.role == anthropic_role
1559 {
1560 last_message.content.extend(anthropic_message_content);
1561 continue;
1562 }
1563
1564 if message.cache {
1565 let cache_control_value = Some(anthropic::CacheControl {
1566 cache_type: anthropic::CacheControlType::Ephemeral,
1567 });
1568 for message_content in anthropic_message_content.iter_mut().rev() {
1569 match message_content {
1570 anthropic::RequestContent::RedactedThinking { .. } => {}
1571 anthropic::RequestContent::Text { cache_control, .. }
1572 | anthropic::RequestContent::Thinking { cache_control, .. }
1573 | anthropic::RequestContent::Image { cache_control, .. }
1574 | anthropic::RequestContent::ToolUse { cache_control, .. }
1575 | anthropic::RequestContent::ToolResult { cache_control, .. } => {
1576 *cache_control = cache_control_value;
1577 break;
1578 }
1579 }
1580 }
1581 }
1582
1583 new_messages.push(anthropic::Message {
1584 role: anthropic_role,
1585 content: anthropic_message_content,
1586 });
1587 }
1588 Role::System => {
1589 if !system_message.is_empty() {
1590 system_message.push_str("\n\n");
1591 }
1592 system_message.push_str(&message.string_contents());
1593 }
1594 }
1595 }
1596
1597 anthropic::Request {
1598 model,
1599 messages: new_messages,
1600 max_tokens: max_output_tokens,
1601 system: if system_message.is_empty() {
1602 None
1603 } else {
1604 Some(anthropic::StringOrContents::String(system_message))
1605 },
1606 thinking: if request.thinking_allowed
1607 && let AnthropicModelMode::Thinking { budget_tokens } = mode
1608 {
1609 Some(anthropic::Thinking::Enabled { budget_tokens })
1610 } else {
1611 None
1612 },
1613 tools: request
1614 .tools
1615 .into_iter()
1616 .map(|tool| anthropic::Tool {
1617 name: tool.name,
1618 description: tool.description,
1619 input_schema: tool.input_schema,
1620 })
1621 .collect(),
1622 tool_choice: request.tool_choice.map(|choice| match choice {
1623 LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
1624 LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
1625 LanguageModelToolChoice::None => anthropic::ToolChoice::None,
1626 }),
1627 metadata: None,
1628 stop_sequences: Vec::new(),
1629 temperature: request.temperature.or(Some(default_temperature)),
1630 top_k: None,
1631 top_p: None,
1632 }
1633}
1634
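/// Accumulates Anthropic streaming events into `LanguageModelCompletionEvent`s,
/// buffering partial tool-use JSON until it parses cleanly and tracking token
/// usage and the final stop reason across the stream.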
1635struct AnthropicEventMapper {
1636 tool_uses_by_index: collections::HashMap<usize, RawToolUse>,
1637 usage: Usage,
1638 stop_reason: StopReason,
1639}
1640
1641impl AnthropicEventMapper {
1642 fn new() -> Self {
1643 Self {
1644 tool_uses_by_index: collections::HashMap::default(),
1645 usage: Usage::default(),
1646 stop_reason: StopReason::EndTurn,
1647 }
1648 }
1649
1650 fn map_event(
1651 &mut self,
1652 event: Event,
1653 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
1654 match event {
1655 Event::ContentBlockStart {
1656 index,
1657 content_block,
1658 } => match content_block {
1659 ResponseContent::Text { text } => {
1660 vec![Ok(LanguageModelCompletionEvent::Text(text))]
1661 }
1662 ResponseContent::Thinking { thinking } => {
1663 vec![Ok(LanguageModelCompletionEvent::Thinking {
1664 text: thinking,
1665 signature: None,
1666 })]
1667 }
1668 ResponseContent::RedactedThinking { data } => {
1669 vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
1670 }
1671 ResponseContent::ToolUse { id, name, .. } => {
1672 self.tool_uses_by_index.insert(
1673 index,
1674 RawToolUse {
1675 id,
1676 name,
1677 input_json: String::new(),
1678 },
1679 );
1680 Vec::new()
1681 }
1682 },
1683 Event::ContentBlockDelta { index, delta } => match delta {
1684 ContentDelta::TextDelta { text } => {
1685 vec![Ok(LanguageModelCompletionEvent::Text(text))]
1686 }
1687 ContentDelta::ThinkingDelta { thinking } => {
1688 vec![Ok(LanguageModelCompletionEvent::Thinking {
1689 text: thinking,
1690 signature: None,
1691 })]
1692 }
1693 ContentDelta::SignatureDelta { signature } => {
1694 vec![Ok(LanguageModelCompletionEvent::Thinking {
1695 text: "".to_string(),
1696 signature: Some(signature),
1697 })]
1698 }
1699 ContentDelta::InputJsonDelta { partial_json } => {
1700 if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
1701 tool_use.input_json.push_str(&partial_json);
1702
1703 let event = serde_json::from_str::<serde_json::Value>(&tool_use.input_json)
1704 .ok()
1705 .and_then(|input| {
1706 let input_json_roundtripped = serde_json::to_string(&input).ok()?;
1707
1708 if !tool_use.input_json.starts_with(&input_json_roundtripped) {
1709 return None;
1710 }
1711
1712 Some(LanguageModelCompletionEvent::ToolUse(
1713 LanguageModelToolUse {
1714 id: LanguageModelToolUseId::from(tool_use.id.clone()),
1715 name: tool_use.name.clone().into(),
1716 raw_input: tool_use.input_json.clone(),
1717 input,
1718 is_input_complete: false,
1719 thought_signature: None,
1720 },
1721 ))
1722 });
1723
1724 if let Some(event) = event {
1725 vec![Ok(event)]
1726 } else {
1727 Vec::new()
1728 }
1729 } else {
1730 Vec::new()
1731 }
1732 }
1733 },
1734 Event::ContentBlockStop { index } => {
1735 if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
1736 let event_result = match serde_json::from_str(&tool_use.input_json) {
1737 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
1738 LanguageModelToolUse {
1739 id: LanguageModelToolUseId::from(tool_use.id),
1740 name: tool_use.name.into(),
1741 raw_input: tool_use.input_json,
1742 input,
1743 is_input_complete: true,
1744 thought_signature: None,
1745 },
1746 )),
1747 Err(json_parse_err) => {
1748 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
1749 id: LanguageModelToolUseId::from(tool_use.id),
1750 tool_name: tool_use.name.into(),
1751 raw_input: tool_use.input_json.into(),
1752 json_parse_error: json_parse_err.to_string(),
1753 })
1754 }
1755 };
1756
1757 vec![event_result]
1758 } else {
1759 Vec::new()
1760 }
1761 }
1762 Event::MessageStart { message } => {
1763 update_anthropic_usage(&mut self.usage, &message.usage);
1764 vec![
1765 Ok(LanguageModelCompletionEvent::UsageUpdate(
1766 convert_anthropic_usage(&self.usage),
1767 )),
1768 Ok(LanguageModelCompletionEvent::StartMessage {
1769 message_id: message.id,
1770 }),
1771 ]
1772 }
1773 Event::MessageDelta { delta, usage } => {
1774 update_anthropic_usage(&mut self.usage, &usage);
1775 if let Some(stop_reason) = delta.stop_reason.as_deref() {
1776 self.stop_reason = match stop_reason {
1777 "end_turn" => StopReason::EndTurn,
1778 "max_tokens" => StopReason::MaxTokens,
1779 "tool_use" => StopReason::ToolUse,
1780 "refusal" => StopReason::Refusal,
1781 _ => {
1782 log::error!("Unexpected anthropic stop_reason: {stop_reason}");
1783 StopReason::EndTurn
1784 }
1785 };
1786 }
1787 vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
1788 convert_anthropic_usage(&self.usage),
1789 ))]
1790 }
1791 Event::MessageStop => {
1792 vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
1793 }
1794 Event::Error { error } => {
1795 vec![Err(error.into())]
1796 }
1797 _ => Vec::new(),
1798 }
1799 }
1800}
1801
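/// A tool-use block being assembled from streamed `input_json` fragments.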
1802struct RawToolUse {
1803 id: String,
1804 name: String,
1805 input_json: String,
1806}
1807
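/// Merges newly reported token counts into the running usage, keeping the
/// previous value for any counter the new event omits.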
1808fn update_anthropic_usage(usage: &mut Usage, new: &Usage) {
1809 if let Some(input_tokens) = new.input_tokens {
1810 usage.input_tokens = Some(input_tokens);
1811 }
1812 if let Some(output_tokens) = new.output_tokens {
1813 usage.output_tokens = Some(output_tokens);
1814 }
1815 if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
1816 usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
1817 }
1818 if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
1819 usage.cache_read_input_tokens = Some(cache_read_input_tokens);
1820 }
1821}
1822
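/// Converts Anthropic's optional usage counters into Zed's `TokenUsage`,
/// treating missing counts as zero.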
1823fn convert_anthropic_usage(usage: &Usage) -> language_model::TokenUsage {
1824 language_model::TokenUsage {
1825 input_tokens: usage.input_tokens.unwrap_or(0),
1826 output_tokens: usage.output_tokens.unwrap_or(0),
1827 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
1828 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
1829 }
1830}