use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, ListModelsResponse,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
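
/// Mode in which a hosted model runs. With `#[serde(tag = "type",
/// rename_all = "lowercase")]`, `Default` serializes as `{"type":"default"}`
/// and `Thinking` as `{"type":"thinking","budget_tokens":4096}` (the budget
/// value here is illustrative).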
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let llm_api_token = LlmApiToken::global(cx);
        Self {
            client: client.clone(),
            llm_api_token,
            user_store: user_store.clone(),
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _user_store_subscription: cx.subscribe(
                &user_store,
                move |this, _user_store, event, cx| match event {
                    client::user::Event::PrivateUserInfoUpdated => {
                        let status = *client.status().borrow();
                        if status.is_signed_out() {
                            return;
                        }

                        let client = this.client.clone();
                        let llm_api_token = this.llm_api_token.clone();
                        let organization_id = this
                            .user_store
                            .read(cx)
                            .current_organization()
                            .map(|organization| organization.id.clone());
                        cx.spawn(async move |this, cx| {
                            let response =
                                Self::fetch_models(client, llm_api_token, organization_id).await?;
                            this.update(cx, |this, cx| this.update_models(response, cx))
                        })
                        .detach_and_log_err(cx);
                    }
                    _ => {}
                },
            ),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    let organization_id = this
                        .user_store
                        .read(cx)
                        .current_organization()
                        .map(|organization| organization.id.clone());
                    cx.spawn(async move |this, cx| {
                        let response =
                            Self::fetch_models(client, llm_api_token, organization_id).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

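    /// Fetches the list of available models from the cloud LLM endpoint,
    /// advertising xAI support via a capability header and authenticating
    /// with a freshly acquired LLM API token.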
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client, organization_id).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
        user_store: Entity<UserStore>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            user_store,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let default_model = state.default_model.clone()?;
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        Some(self.create_language_model(default_model, llm_api_token, user_store))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let default_fast_model = state.default_fast_model.clone()?;
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        state
            .recommended_models
            .iter()
            .cloned()
            .map(|model| {
                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
            })
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let state = self.state.read(cx);
        let llm_api_token = state.llm_api_token.clone();
        let user_store = state.user_store.clone();
        state
            .models
            .iter()
            .cloned()
            .map(|model| {
                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
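    /// Sends a completion request to the cloud LLM endpoint.
    ///
    /// If the server signals that the LLM token is stale, the token is
    /// refreshed once and the request retried; other non-success statuses
    /// surface as `PaymentRequiredError` or `ApiError`.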
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token
            .acquire(&client, organization_id.clone())
            .await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && response.needs_llm_token_refresh() {
                token = llm_api_token
                    .refresh(&client, organization_id.clone())
                    .await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

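/// Deserializes an optional numeric status code, mapping values that are not
/// valid HTTP status codes (as well as `null` or missing fields) to `None`
/// instead of failing deserialization.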
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The status may be embedded in the code string itself
                    // (e.g. "upstream_http_429"); parse it out, falling back
                    // to the response status if it doesn't parse.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|organization| organization.id.clone());
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client, organization_id).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let user_store = self.user_store.clone();
        let organization_id = cx
            .update(|cx| {
                user_store
                    .read(cx)
                    .current_organization()
                    .map(|organization| organization.id.clone())
            })
            .ok()
            .flatten();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

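/// Adapts a stream of cloud `CompletionEvent`s into language-model completion
/// events, delegating provider-specific events to `map_callback`.
///
/// Status messages are translated directly; if the underlying stream finishes
/// without a `StreamEnded` status, a `StreamEndedUnexpectedly` error is
/// emitted so a truncated response is not mistaken for a clean completion.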
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

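/// Parses the response body as newline-delimited JSON: when the server
/// advertised status-message support, each line is a full `CompletionEvent`
/// envelope; otherwise each line is a bare provider event and is wrapped in
/// `CompletionEvent::Event`.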
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let (subscription_text, has_paid_plan) = match self.plan {
            Some(Plan::ZedPro) => (
                "You have access to Zed's hosted models through your Pro subscription.",
                true,
            ),
            Some(Plan::ZedProTrial) => (
                "You have access to Zed's hosted models through your Pro trial.",
                false,
            ),
            Some(Plan::ZedStudent) => (
                "You have access to Zed's hosted models through your Student subscription.",
                true,
            ),
            Some(Plan::ZedBusiness) => (
                "You have access to Zed's hosted models through your Organization.",
                true,
            ),
            Some(Plan::ZedFree) | None => (
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                },
                false,
            ),
        };

        let manage_subscription_buttons = if has_paid_plan {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .label_size(LabelSize::Small)
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 error without an upstream_http_error body should remain
        // an ApiInternalServerError attributed to Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should also be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
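
    // A hedged sketch: exercises `deserialize_optional_status_code` through
    // `CloudApiError` parsing. The JSON payloads are illustrative, not
    // captured from the live API.
    #[test]
    fn test_cloud_api_error_status_code_deserialization() {
        // A valid upstream status deserializes to the matching `StatusCode`.
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"boom","upstream_status":503}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, Some(StatusCode::SERVICE_UNAVAILABLE));

        // Values that aren't valid HTTP status codes deserialize to `None`
        // rather than failing deserialization.
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"boom","upstream_status":9999}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, None);

        // A missing field is also `None`, via `#[serde(default)]`.
        let error: CloudApiError =
            serde_json::from_str(r#"{"code":"upstream_http_error","message":"boom"}"#).unwrap();
        assert_eq!(error.upstream_status, None);
    }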
}