cloud.rs

   1use super::open_ai::count_open_ai_tokens;
   2use crate::{
   3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
   4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
   5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
   6};
   7use anyhow::{anyhow, bail, Context as _, Result};
   8use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
   9use collections::BTreeMap;
  10use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
  11use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
  12use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
  13use http_client::{AsyncBody, HttpClient, Method, Response};
  14use schemars::JsonSchema;
  15use serde::{Deserialize, Serialize};
  16use serde_json::value::RawValue;
  17use settings::{Settings, SettingsStore};
  18use smol::{
  19    io::BufReader,
  20    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
  21};
  22use std::{future, sync::Arc};
  23use strum::IntoEnumIterator;
  24use ui::prelude::*;
  25
  26use crate::{LanguageModelAvailability, LanguageModelProvider};
  27
  28use super::anthropic::count_anthropic_tokens;
  29
  30pub const PROVIDER_ID: &str = "zed.dev";
  31pub const PROVIDER_NAME: &str = "Zed";
  32
/// Settings for the cloud (`zed.dev`) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models declared in settings. `provided_models` registers each of
    /// these as a `Custom` model of the matching upstream provider.
    pub available_models: Vec<AvailableModel>,
}
  37
/// Upstream provider that a settings-declared [`AvailableModel`] belongs to.
///
/// Variant names are (de)serialized in lowercase: `anthropic`, `openai`, `google`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
  45
/// A model declared in settings, turned into a provider-specific `Custom`
/// model by `provided_models`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Which upstream provider's `Custom` model variant to build.
    provider: AvailableProvider,
    /// Passed through as the custom model's `name`.
    name: String,
    /// Passed through as the custom model's `max_tokens`.
    max_tokens: usize,
    /// Passed through only for Anthropic custom models; the OpenAI and Google
    /// branches of `provided_models` do not read it.
    tool_override: Option<String>,
}
  53
/// Language model provider identified by `PROVIDER_ID` / `PROVIDER_NAME`.
pub struct CloudLanguageModelProvider {
    /// Client handle cloned into every model this provider creates.
    client: Arc<Client>,
    /// Token handle cloned into every model this provider creates, so all of
    /// them share the same underlying token slot.
    llm_api_token: LlmApiToken,
    /// Observable state (client status, user store) exposed via
    /// `LanguageModelProviderState::observable_entity`.
    state: gpui::Model<State>,
    /// Background task that copies client status changes into `state`.
    /// NOTE(review): presumably held only so the task stays alive for the
    /// provider's lifetime — confirm against gpui's `Task` drop semantics.
    _maintain_client_status: Task<()>,
}
  60
/// Observable state backing [`CloudLanguageModelProvider`].
pub struct State {
    /// Client used by `authenticate` to sign in and connect.
    client: Arc<Client>,
    /// User store handed to `CloudLanguageModelProvider::new`; not read by
    /// any code visible in this part of the file.
    user_store: Model<UserStore>,
    /// Most recently observed client status.
    status: client::Status,
    /// Subscription created in `CloudLanguageModelProvider::new` that notifies
    /// observers of this state whenever the global `SettingsStore` changes.
    _subscription: Subscription,
}
  67
  68impl State {
  69    fn is_signed_out(&self) -> bool {
  70        self.status.is_signed_out()
  71    }
  72
  73    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
  74        let client = self.client.clone();
  75        cx.spawn(move |this, mut cx| async move {
  76            client.authenticate_and_connect(true, &cx).await?;
  77            this.update(&mut cx, |_, cx| cx.notify())
  78        })
  79    }
  80}
  81
  82impl CloudLanguageModelProvider {
  83    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
  84        let mut status_rx = client.status();
  85        let status = *status_rx.borrow();
  86
  87        let state = cx.new_model(|cx| State {
  88            client: client.clone(),
  89            user_store,
  90            status,
  91            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
  92                cx.notify();
  93            }),
  94        });
  95
  96        let state_ref = state.downgrade();
  97        let maintain_client_status = cx.spawn(|mut cx| async move {
  98            while let Some(status) = status_rx.next().await {
  99                if let Some(this) = state_ref.upgrade() {
 100                    _ = this.update(&mut cx, |this, cx| {
 101                        if this.status != status {
 102                            this.status = status;
 103                            cx.notify();
 104                        }
 105                    });
 106                } else {
 107                    break;
 108                }
 109            }
 110        });
 111
 112        Self {
 113            client,
 114            state,
 115            llm_api_token: LlmApiToken::default(),
 116            _maintain_client_status: maintain_client_status,
 117        }
 118    }
 119}
 120
 121impl LanguageModelProviderState for CloudLanguageModelProvider {
 122    type ObservableEntity = State;
 123
 124    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 125        Some(self.state.clone())
 126    }
 127}
 128
 129impl LanguageModelProvider for CloudLanguageModelProvider {
 130    fn id(&self) -> LanguageModelProviderId {
 131        LanguageModelProviderId(PROVIDER_ID.into())
 132    }
 133
 134    fn name(&self) -> LanguageModelProviderName {
 135        LanguageModelProviderName(PROVIDER_NAME.into())
 136    }
 137
 138    fn icon(&self) -> IconName {
 139        IconName::AiZed
 140    }
 141
 142    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 143        let mut models = BTreeMap::default();
 144
 145        let is_user = !cx.has_flag::<LanguageModels>();
 146        if is_user {
 147            models.insert(
 148                anthropic::Model::Claude3_5Sonnet.id().to_string(),
 149                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
 150            );
 151        } else {
 152            for model in anthropic::Model::iter() {
 153                if !matches!(model, anthropic::Model::Custom { .. }) {
 154                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
 155                }
 156            }
 157            for model in open_ai::Model::iter() {
 158                if !matches!(model, open_ai::Model::Custom { .. }) {
 159                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
 160                }
 161            }
 162            for model in google_ai::Model::iter() {
 163                if !matches!(model, google_ai::Model::Custom { .. }) {
 164                    models.insert(model.id().to_string(), CloudModel::Google(model));
 165                }
 166            }
 167            for model in ZedModel::iter() {
 168                models.insert(model.id().to_string(), CloudModel::Zed(model));
 169            }
 170
 171            // Override with available models from settings
 172            for model in &AllLanguageModelSettings::get_global(cx)
 173                .zed_dot_dev
 174                .available_models
 175            {
 176                let model = match model.provider {
 177                    AvailableProvider::Anthropic => {
 178                        CloudModel::Anthropic(anthropic::Model::Custom {
 179                            name: model.name.clone(),
 180                            max_tokens: model.max_tokens,
 181                            tool_override: model.tool_override.clone(),
 182                        })
 183                    }
 184                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
 185                        name: model.name.clone(),
 186                        max_tokens: model.max_tokens,
 187                    }),
 188                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
 189                        name: model.name.clone(),
 190                        max_tokens: model.max_tokens,
 191                    }),
 192                };
 193                models.insert(model.id().to_string(), model.clone());
 194            }
 195        }
 196
 197        models
 198            .into_values()
 199            .map(|model| {
 200                Arc::new(CloudLanguageModel {
 201                    id: LanguageModelId::from(model.id().to_string()),
 202                    model,
 203                    llm_api_token: self.llm_api_token.clone(),
 204                    client: self.client.clone(),
 205                    request_limiter: RateLimiter::new(4),
 206                }) as Arc<dyn LanguageModel>
 207            })
 208            .collect()
 209    }
 210
 211    fn is_authenticated(&self, cx: &AppContext) -> bool {
 212        !self.state.read(cx).is_signed_out()
 213    }
 214
 215    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
 216        Task::ready(Ok(()))
 217    }
 218
 219    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
 220        cx.new_view(|_cx| ConfigurationView {
 221            state: self.state.clone(),
 222        })
 223        .into()
 224    }
 225
 226    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
 227        Task::ready(Ok(()))
 228    }
 229}
 230
/// Feature flag that, when enabled, makes `CloudLanguageModel` send requests
/// through `perform_llm_completion` instead of the `proto` RPC messages.
struct LlmServiceFeatureFlag;
 232
impl FeatureFlag for LlmServiceFeatureFlag {
    /// Name of the flag.
    const NAME: &'static str = "llm-service";

    /// Always returns `false`.
    // NOTE(review): presumably this opts staff out of having the flag enabled
    // automatically — confirm against the `FeatureFlag` trait.
    fn enabled_for_staff() -> bool {
        false
    }
}
 240
/// A single model offered by the cloud provider.
pub struct CloudLanguageModel {
    /// Id derived from the underlying model's id when the model is created.
    id: LanguageModelId,
    /// The underlying provider-specific model.
    model: CloudModel,
    /// Token handle passed to `perform_llm_completion`.
    llm_api_token: LlmApiToken,
    /// Client used for both RPC requests and HTTP completions.
    client: Arc<Client>,
    /// Limiter that every completion request of this model is routed through.
    request_limiter: RateLimiter,
}
 248
/// Cloneable handle to an optional token string behind an async `RwLock`.
/// Clones share the same slot, and the default value holds no token.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
 251
 252impl CloudLanguageModel {
 253    async fn perform_llm_completion(
 254        client: Arc<Client>,
 255        llm_api_token: LlmApiToken,
 256        body: PerformCompletionParams,
 257    ) -> Result<Response<AsyncBody>> {
 258        let http_client = &client.http_client();
 259
 260        let mut token = llm_api_token.acquire(&client).await?;
 261        let mut did_retry = false;
 262
 263        let response = loop {
 264            let request = http_client::Request::builder()
 265                .method(Method::POST)
 266                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
 267                .header("Content-Type", "application/json")
 268                .header("Authorization", format!("Bearer {token}"))
 269                .body(serde_json::to_string(&body)?.into())?;
 270            let response = http_client.send(request).await?;
 271            if response.status().is_success() {
 272                break response;
 273            } else if !did_retry
 274                && response
 275                    .headers()
 276                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 277                    .is_some()
 278            {
 279                did_retry = true;
 280                token = llm_api_token.refresh(&client).await?;
 281            } else {
 282                break Err(anyhow!(
 283                    "cloud language model completion failed with status {}",
 284                    response.status()
 285                ))?;
 286            }
 287        };
 288
 289        Ok(response)
 290    }
 291}
 292
 293impl LanguageModel for CloudLanguageModel {
    /// Returns a clone of this model's id.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
 297
 298    fn name(&self) -> LanguageModelName {
 299        LanguageModelName::from(self.model.display_name().to_string())
 300    }
 301
    /// Returns the id of the owning provider, built from `PROVIDER_ID`.
    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }
 305
    /// Returns the name of the owning provider, built from `PROVIDER_NAME`.
    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }
 309
 310    fn telemetry_id(&self) -> String {
 311        format!("zed.dev/{}", self.model.id())
 312    }
 313
    /// Returns the availability reported by the underlying model.
    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }
 317
    /// Returns the maximum token count reported by the underlying model.
    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }
 321
 322    fn count_tokens(
 323        &self,
 324        request: LanguageModelRequest,
 325        cx: &AppContext,
 326    ) -> BoxFuture<'static, Result<usize>> {
 327        match self.model.clone() {
 328            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
 329            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
 330            CloudModel::Google(model) => {
 331                let client = self.client.clone();
 332                let request = request.into_google(model.id().into());
 333                let request = google_ai::CountTokensRequest {
 334                    contents: request.contents,
 335                };
 336                async move {
 337                    let request = serde_json::to_string(&request)?;
 338                    let response = client
 339                        .request(proto::CountLanguageModelTokens {
 340                            provider: proto::LanguageModelProvider::Google as i32,
 341                            request,
 342                        })
 343                        .await?;
 344                    Ok(response.token_count as usize)
 345                }
 346                .boxed()
 347            }
 348            CloudModel::Zed(_) => {
 349                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
 350            }
 351        }
 352    }
 353
 354    fn stream_completion(
 355        &self,
 356        request: LanguageModelRequest,
 357        cx: &AsyncAppContext,
 358    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 359        match &self.model {
 360            CloudModel::Anthropic(model) => {
 361                let request = request.into_anthropic(model.id().into());
 362                let client = self.client.clone();
 363
 364                if cx
 365                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 366                    .unwrap_or(false)
 367                {
 368                    let llm_api_token = self.llm_api_token.clone();
 369                    let future = self.request_limiter.stream(async move {
 370                        let response = Self::perform_llm_completion(
 371                            client.clone(),
 372                            llm_api_token,
 373                            PerformCompletionParams {
 374                                provider: client::LanguageModelProvider::Anthropic,
 375                                model: request.model.clone(),
 376                                provider_request: RawValue::from_string(serde_json::to_string(
 377                                    &request,
 378                                )?)?,
 379                            },
 380                        )
 381                        .await?;
 382                        let body = BufReader::new(response.into_body());
 383                        let stream =
 384                            futures::stream::try_unfold(body, move |mut body| async move {
 385                                let mut buffer = String::new();
 386                                match body.read_line(&mut buffer).await {
 387                                    Ok(0) => Ok(None),
 388                                    Ok(_) => {
 389                                        let event: anthropic::Event =
 390                                            serde_json::from_str(&buffer)?;
 391                                        Ok(Some((event, body)))
 392                                    }
 393                                    Err(e) => Err(e.into()),
 394                                }
 395                            });
 396
 397                        Ok(anthropic::extract_text_from_events(stream))
 398                    });
 399                    async move { Ok(future.await?.boxed()) }.boxed()
 400                } else {
 401                    let future = self.request_limiter.stream(async move {
 402                        let request = serde_json::to_string(&request)?;
 403                        let stream = client
 404                            .request_stream(proto::StreamCompleteWithLanguageModel {
 405                                provider: proto::LanguageModelProvider::Anthropic as i32,
 406                                request,
 407                            })
 408                            .await?
 409                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
 410                        Ok(anthropic::extract_text_from_events(stream))
 411                    });
 412                    async move { Ok(future.await?.boxed()) }.boxed()
 413                }
 414            }
 415            CloudModel::OpenAi(model) => {
 416                let client = self.client.clone();
 417                let request = request.into_open_ai(model.id().into());
 418
 419                if cx
 420                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 421                    .unwrap_or(false)
 422                {
 423                    let llm_api_token = self.llm_api_token.clone();
 424                    let future = self.request_limiter.stream(async move {
 425                        let response = Self::perform_llm_completion(
 426                            client.clone(),
 427                            llm_api_token,
 428                            PerformCompletionParams {
 429                                provider: client::LanguageModelProvider::OpenAi,
 430                                model: request.model.clone(),
 431                                provider_request: RawValue::from_string(serde_json::to_string(
 432                                    &request,
 433                                )?)?,
 434                            },
 435                        )
 436                        .await?;
 437                        let body = BufReader::new(response.into_body());
 438                        let stream =
 439                            futures::stream::try_unfold(body, move |mut body| async move {
 440                                let mut buffer = String::new();
 441                                match body.read_line(&mut buffer).await {
 442                                    Ok(0) => Ok(None),
 443                                    Ok(_) => {
 444                                        let event: open_ai::ResponseStreamEvent =
 445                                            serde_json::from_str(&buffer)?;
 446                                        Ok(Some((event, body)))
 447                                    }
 448                                    Err(e) => Err(e.into()),
 449                                }
 450                            });
 451
 452                        Ok(open_ai::extract_text_from_events(stream))
 453                    });
 454                    async move { Ok(future.await?.boxed()) }.boxed()
 455                } else {
 456                    let future = self.request_limiter.stream(async move {
 457                        let request = serde_json::to_string(&request)?;
 458                        let stream = client
 459                            .request_stream(proto::StreamCompleteWithLanguageModel {
 460                                provider: proto::LanguageModelProvider::OpenAi as i32,
 461                                request,
 462                            })
 463                            .await?;
 464                        Ok(open_ai::extract_text_from_events(
 465                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
 466                        ))
 467                    });
 468                    async move { Ok(future.await?.boxed()) }.boxed()
 469                }
 470            }
 471            CloudModel::Google(model) => {
 472                let client = self.client.clone();
 473                let request = request.into_google(model.id().into());
 474
 475                if cx
 476                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 477                    .unwrap_or(false)
 478                {
 479                    let llm_api_token = self.llm_api_token.clone();
 480                    let future = self.request_limiter.stream(async move {
 481                        let response = Self::perform_llm_completion(
 482                            client.clone(),
 483                            llm_api_token,
 484                            PerformCompletionParams {
 485                                provider: client::LanguageModelProvider::Google,
 486                                model: request.model.clone(),
 487                                provider_request: RawValue::from_string(serde_json::to_string(
 488                                    &request,
 489                                )?)?,
 490                            },
 491                        )
 492                        .await?;
 493                        let body = BufReader::new(response.into_body());
 494                        let stream =
 495                            futures::stream::try_unfold(body, move |mut body| async move {
 496                                let mut buffer = String::new();
 497                                match body.read_line(&mut buffer).await {
 498                                    Ok(0) => Ok(None),
 499                                    Ok(_) => {
 500                                        let event: google_ai::GenerateContentResponse =
 501                                            serde_json::from_str(&buffer)?;
 502                                        Ok(Some((event, body)))
 503                                    }
 504                                    Err(e) => Err(e.into()),
 505                                }
 506                            });
 507
 508                        Ok(google_ai::extract_text_from_events(stream))
 509                    });
 510                    async move { Ok(future.await?.boxed()) }.boxed()
 511                } else {
 512                    let future = self.request_limiter.stream(async move {
 513                        let request = serde_json::to_string(&request)?;
 514                        let stream = client
 515                            .request_stream(proto::StreamCompleteWithLanguageModel {
 516                                provider: proto::LanguageModelProvider::Google as i32,
 517                                request,
 518                            })
 519                            .await?;
 520                        Ok(google_ai::extract_text_from_events(
 521                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
 522                        ))
 523                    });
 524                    async move { Ok(future.await?.boxed()) }.boxed()
 525                }
 526            }
 527            CloudModel::Zed(model) => {
 528                let client = self.client.clone();
 529                let mut request = request.into_open_ai(model.id().into());
 530                request.max_tokens = Some(4000);
 531
 532                if cx
 533                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 534                    .unwrap_or(false)
 535                {
 536                    let llm_api_token = self.llm_api_token.clone();
 537                    let future = self.request_limiter.stream(async move {
 538                        let response = Self::perform_llm_completion(
 539                            client.clone(),
 540                            llm_api_token,
 541                            PerformCompletionParams {
 542                                provider: client::LanguageModelProvider::Zed,
 543                                model: request.model.clone(),
 544                                provider_request: RawValue::from_string(serde_json::to_string(
 545                                    &request,
 546                                )?)?,
 547                            },
 548                        )
 549                        .await?;
 550                        let body = BufReader::new(response.into_body());
 551                        let stream =
 552                            futures::stream::try_unfold(body, move |mut body| async move {
 553                                let mut buffer = String::new();
 554                                match body.read_line(&mut buffer).await {
 555                                    Ok(0) => Ok(None),
 556                                    Ok(_) => {
 557                                        let event: open_ai::ResponseStreamEvent =
 558                                            serde_json::from_str(&buffer)?;
 559                                        Ok(Some((event, body)))
 560                                    }
 561                                    Err(e) => Err(e.into()),
 562                                }
 563                            });
 564
 565                        Ok(open_ai::extract_text_from_events(stream))
 566                    });
 567                    async move { Ok(future.await?.boxed()) }.boxed()
 568                } else {
 569                    let future = self.request_limiter.stream(async move {
 570                        let request = serde_json::to_string(&request)?;
 571                        let stream = client
 572                            .request_stream(proto::StreamCompleteWithLanguageModel {
 573                                provider: proto::LanguageModelProvider::Zed as i32,
 574                                request,
 575                            })
 576                            .await?;
 577                        Ok(open_ai::extract_text_from_events(
 578                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
 579                        ))
 580                    });
 581                    async move { Ok(future.await?.boxed()) }.boxed()
 582                }
 583            }
 584        }
 585    }
 586
 587    fn use_any_tool(
 588        &self,
 589        request: LanguageModelRequest,
 590        tool_name: String,
 591        tool_description: String,
 592        input_schema: serde_json::Value,
 593        cx: &AsyncAppContext,
 594    ) -> BoxFuture<'static, Result<serde_json::Value>> {
 595        match &self.model {
 596            CloudModel::Anthropic(model) => {
 597                let client = self.client.clone();
 598                let mut request = request.into_anthropic(model.tool_model_id().into());
 599                request.tool_choice = Some(anthropic::ToolChoice::Tool {
 600                    name: tool_name.clone(),
 601                });
 602                request.tools = vec![anthropic::Tool {
 603                    name: tool_name.clone(),
 604                    description: tool_description,
 605                    input_schema,
 606                }];
 607
 608                if cx
 609                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 610                    .unwrap_or(false)
 611                {
 612                    let llm_api_token = self.llm_api_token.clone();
 613                    self.request_limiter
 614                        .run(async move {
 615                            let response = Self::perform_llm_completion(
 616                                client.clone(),
 617                                llm_api_token,
 618                                PerformCompletionParams {
 619                                    provider: client::LanguageModelProvider::Anthropic,
 620                                    model: request.model.clone(),
 621                                    provider_request: RawValue::from_string(
 622                                        serde_json::to_string(&request)?,
 623                                    )?,
 624                                },
 625                            )
 626                            .await?;
 627
 628                            let mut tool_use_index = None;
 629                            let mut tool_input = String::new();
 630                            let mut body = BufReader::new(response.into_body());
 631                            let mut line = String::new();
 632                            while body.read_line(&mut line).await? > 0 {
 633                                let event: anthropic::Event = serde_json::from_str(&line)?;
 634                                line.clear();
 635
 636                                match event {
 637                                    anthropic::Event::ContentBlockStart {
 638                                        content_block,
 639                                        index,
 640                                    } => {
 641                                        if let anthropic::Content::ToolUse { name, .. } =
 642                                            content_block
 643                                        {
 644                                            if name == tool_name {
 645                                                tool_use_index = Some(index);
 646                                            }
 647                                        }
 648                                    }
 649                                    anthropic::Event::ContentBlockDelta { index, delta } => {
 650                                        match delta {
 651                                            anthropic::ContentDelta::TextDelta { .. } => {}
 652                                            anthropic::ContentDelta::InputJsonDelta {
 653                                                partial_json,
 654                                            } => {
 655                                                if Some(index) == tool_use_index {
 656                                                    tool_input.push_str(&partial_json);
 657                                                }
 658                                            }
 659                                        }
 660                                    }
 661                                    anthropic::Event::ContentBlockStop { index } => {
 662                                        if Some(index) == tool_use_index {
 663                                            return Ok(serde_json::from_str(&tool_input)?);
 664                                        }
 665                                    }
 666                                    _ => {}
 667                                }
 668                            }
 669
 670                            if tool_use_index.is_some() {
 671                                Err(anyhow!("tool content incomplete"))
 672                            } else {
 673                                Err(anyhow!("tool not used"))
 674                            }
 675                        })
 676                        .boxed()
 677                } else {
 678                    self.request_limiter
 679                        .run(async move {
 680                            let request = serde_json::to_string(&request)?;
 681                            let response = client
 682                                .request(proto::CompleteWithLanguageModel {
 683                                    provider: proto::LanguageModelProvider::Anthropic as i32,
 684                                    request,
 685                                })
 686                                .await?;
 687                            let response: anthropic::Response =
 688                                serde_json::from_str(&response.completion)?;
 689                            response
 690                                .content
 691                                .into_iter()
 692                                .find_map(|content| {
 693                                    if let anthropic::Content::ToolUse { name, input, .. } = content
 694                                    {
 695                                        if name == tool_name {
 696                                            Some(input)
 697                                        } else {
 698                                            None
 699                                        }
 700                                    } else {
 701                                        None
 702                                    }
 703                                })
 704                                .context("tool not used")
 705                        })
 706                        .boxed()
 707                }
 708            }
 709            CloudModel::OpenAi(model) => {
 710                let mut request = request.into_open_ai(model.id().into());
 711                let client = self.client.clone();
 712                let mut function = open_ai::FunctionDefinition {
 713                    name: tool_name.clone(),
 714                    description: None,
 715                    parameters: None,
 716                };
 717                let func = open_ai::ToolDefinition::Function {
 718                    function: function.clone(),
 719                };
 720                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
 721                // Fill in description and params separately, as they're not needed for tool_choice field.
 722                function.description = Some(tool_description);
 723                function.parameters = Some(input_schema);
 724                request.tools = vec![open_ai::ToolDefinition::Function { function }];
 725
 726                if cx
 727                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 728                    .unwrap_or(false)
 729                {
 730                    let llm_api_token = self.llm_api_token.clone();
 731                    self.request_limiter
 732                        .run(async move {
 733                            let response = Self::perform_llm_completion(
 734                                client.clone(),
 735                                llm_api_token,
 736                                PerformCompletionParams {
 737                                    provider: client::LanguageModelProvider::OpenAi,
 738                                    model: request.model.clone(),
 739                                    provider_request: RawValue::from_string(
 740                                        serde_json::to_string(&request)?,
 741                                    )?,
 742                                },
 743                            )
 744                            .await?;
 745
 746                            let mut body = BufReader::new(response.into_body());
 747                            let mut line = String::new();
 748                            let mut load_state = None;
 749
 750                            while body.read_line(&mut line).await? > 0 {
 751                                let part: open_ai::ResponseStreamEvent =
 752                                    serde_json::from_str(&line)?;
 753                                line.clear();
 754
 755                                for choice in part.choices {
 756                                    let Some(tool_calls) = choice.delta.tool_calls else {
 757                                        continue;
 758                                    };
 759
 760                                    for call in tool_calls {
 761                                        if let Some(func) = call.function {
 762                                            if func.name.as_deref() == Some(tool_name.as_str()) {
 763                                                load_state = Some((String::default(), call.index));
 764                                            }
 765                                            if let Some((arguments, (output, index))) =
 766                                                func.arguments.zip(load_state.as_mut())
 767                                            {
 768                                                if call.index == *index {
 769                                                    output.push_str(&arguments);
 770                                                }
 771                                            }
 772                                        }
 773                                    }
 774                                }
 775                            }
 776
 777                            if let Some((arguments, _)) = load_state {
 778                                return Ok(serde_json::from_str(&arguments)?);
 779                            } else {
 780                                bail!("tool not used");
 781                            }
 782                        })
 783                        .boxed()
 784                } else {
 785                    self.request_limiter
 786                        .run(async move {
 787                            let request = serde_json::to_string(&request)?;
 788                            let response = client
 789                                .request_stream(proto::StreamCompleteWithLanguageModel {
 790                                    provider: proto::LanguageModelProvider::OpenAi as i32,
 791                                    request,
 792                                })
 793                                .await?;
 794                            let mut load_state = None;
 795                            let mut response = response.map(
 796                                |item: Result<
 797                                    proto::StreamCompleteWithLanguageModelResponse,
 798                                    anyhow::Error,
 799                                >| {
 800                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
 801                                        serde_json::from_str(&item?.event)?,
 802                                    )
 803                                },
 804                            );
 805                            while let Some(Ok(part)) = response.next().await {
 806                                for choice in part.choices {
 807                                    let Some(tool_calls) = choice.delta.tool_calls else {
 808                                        continue;
 809                                    };
 810
 811                                    for call in tool_calls {
 812                                        if let Some(func) = call.function {
 813                                            if func.name.as_deref() == Some(tool_name.as_str()) {
 814                                                load_state = Some((String::default(), call.index));
 815                                            }
 816                                            if let Some((arguments, (output, index))) =
 817                                                func.arguments.zip(load_state.as_mut())
 818                                            {
 819                                                if call.index == *index {
 820                                                    output.push_str(&arguments);
 821                                                }
 822                                            }
 823                                        }
 824                                    }
 825                                }
 826                            }
 827                            if let Some((arguments, _)) = load_state {
 828                                return Ok(serde_json::from_str(&arguments)?);
 829                            } else {
 830                                bail!("tool not used");
 831                            }
 832                        })
 833                        .boxed()
 834                }
 835            }
 836            CloudModel::Google(_) => {
 837                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
 838            }
 839            CloudModel::Zed(model) => {
 840                // All Zed models are OpenAI-based at the time of writing.
 841                let mut request = request.into_open_ai(model.id().into());
 842                let client = self.client.clone();
 843                let mut function = open_ai::FunctionDefinition {
 844                    name: tool_name.clone(),
 845                    description: None,
 846                    parameters: None,
 847                };
 848                let func = open_ai::ToolDefinition::Function {
 849                    function: function.clone(),
 850                };
 851                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
 852                // Fill in description and params separately, as they're not needed for tool_choice field.
 853                function.description = Some(tool_description);
 854                function.parameters = Some(input_schema);
 855                request.tools = vec![open_ai::ToolDefinition::Function { function }];
 856
 857                if cx
 858                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
 859                    .unwrap_or(false)
 860                {
 861                    let llm_api_token = self.llm_api_token.clone();
 862                    self.request_limiter
 863                        .run(async move {
 864                            let response = Self::perform_llm_completion(
 865                                client.clone(),
 866                                llm_api_token,
 867                                PerformCompletionParams {
 868                                    provider: client::LanguageModelProvider::Zed,
 869                                    model: request.model.clone(),
 870                                    provider_request: RawValue::from_string(
 871                                        serde_json::to_string(&request)?,
 872                                    )?,
 873                                },
 874                            )
 875                            .await?;
 876
 877                            let mut body = BufReader::new(response.into_body());
 878                            let mut line = String::new();
 879                            let mut load_state = None;
 880
 881                            while body.read_line(&mut line).await? > 0 {
 882                                let part: open_ai::ResponseStreamEvent =
 883                                    serde_json::from_str(&line)?;
 884                                line.clear();
 885
 886                                for choice in part.choices {
 887                                    let Some(tool_calls) = choice.delta.tool_calls else {
 888                                        continue;
 889                                    };
 890
 891                                    for call in tool_calls {
 892                                        if let Some(func) = call.function {
 893                                            if func.name.as_deref() == Some(tool_name.as_str()) {
 894                                                load_state = Some((String::default(), call.index));
 895                                            }
 896                                            if let Some((arguments, (output, index))) =
 897                                                func.arguments.zip(load_state.as_mut())
 898                                            {
 899                                                if call.index == *index {
 900                                                    output.push_str(&arguments);
 901                                                }
 902                                            }
 903                                        }
 904                                    }
 905                                }
 906                            }
 907                            if let Some((arguments, _)) = load_state {
 908                                return Ok(serde_json::from_str(&arguments)?);
 909                            } else {
 910                                bail!("tool not used");
 911                            }
 912                        })
 913                        .boxed()
 914                } else {
 915                    self.request_limiter
 916                        .run(async move {
 917                            let request = serde_json::to_string(&request)?;
 918                            let response = client
 919                                .request_stream(proto::StreamCompleteWithLanguageModel {
 920                                    provider: proto::LanguageModelProvider::OpenAi as i32,
 921                                    request,
 922                                })
 923                                .await?;
 924                            let mut load_state = None;
 925                            let mut response = response.map(
 926                                |item: Result<
 927                                    proto::StreamCompleteWithLanguageModelResponse,
 928                                    anyhow::Error,
 929                                >| {
 930                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
 931                                        serde_json::from_str(&item?.event)?,
 932                                    )
 933                                },
 934                            );
 935                            while let Some(Ok(part)) = response.next().await {
 936                                for choice in part.choices {
 937                                    let Some(tool_calls) = choice.delta.tool_calls else {
 938                                        continue;
 939                                    };
 940
 941                                    for call in tool_calls {
 942                                        if let Some(func) = call.function {
 943                                            if func.name.as_deref() == Some(tool_name.as_str()) {
 944                                                load_state = Some((String::default(), call.index));
 945                                            }
 946                                            if let Some((arguments, (output, index))) =
 947                                                func.arguments.zip(load_state.as_mut())
 948                                            {
 949                                                if call.index == *index {
 950                                                    output.push_str(&arguments);
 951                                                }
 952                                            }
 953                                        }
 954                                    }
 955                                }
 956                            }
 957                            if let Some((arguments, _)) = load_state {
 958                                return Ok(serde_json::from_str(&arguments)?);
 959                            } else {
 960                                bail!("tool not used");
 961                            }
 962                        })
 963                        .boxed()
 964                }
 965            }
 966        }
 967    }
 968}
 969
 970impl LlmApiToken {
 971    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 972        let lock = self.0.upgradable_read().await;
 973        if let Some(token) = lock.as_ref() {
 974            Ok(token.to_string())
 975        } else {
 976            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
 977        }
 978    }
 979
 980    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 981        Self::fetch(self.0.write().await, &client).await
 982    }
 983
 984    async fn fetch<'a>(
 985        mut lock: RwLockWriteGuard<'a, Option<String>>,
 986        client: &Arc<Client>,
 987    ) -> Result<String> {
 988        let response = client.request(proto::GetLlmToken {}).await?;
 989        *lock = Some(response.token.clone());
 990        Ok(response.token.clone())
 991    }
 992}
 993
/// View that holds a handle to the provider's `State` model.
struct ConfigurationView {
    /// Handle to the `State` model this view is built around.
    state: gpui::Model<State>,
}
 997
 998impl ConfigurationView {
 999    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
1000        self.state.update(cx, |state, cx| {
1001            state.authenticate(cx).detach_and_log_err(cx);
1002        });
1003        cx.notify();
1004    }
1005}
1006
1007impl Render for ConfigurationView {
1008    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1009        const ZED_AI_URL: &str = "https://zed.dev/ai";
1010        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
1011
1012        let is_connected = !self.state.read(cx).is_signed_out();
1013        let plan = self.state.read(cx).user_store.read(cx).current_plan();
1014
1015        let is_pro = plan == Some(proto::Plan::ZedPro);
1016
1017        if is_connected {
1018            v_flex()
1019                .gap_3()
1020                .max_w_4_5()
1021                .child(Label::new(
1022                    if is_pro {
1023                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
1024                    } else {
1025                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
1026                    }))
1027                .child(
1028                    if is_pro {
1029                        h_flex().child(
1030                        Button::new("manage_settings", "Manage Subscription")
1031                            .style(ButtonStyle::Filled)
1032                            .on_click(cx.listener(|_, _, cx| {
1033                                cx.open_url(ACCOUNT_SETTINGS_URL)
1034                            })))
1035                    } else {
1036                        h_flex()
1037                            .gap_2()
1038                            .child(
1039                        Button::new("learn_more", "Learn more")
1040                            .style(ButtonStyle::Subtle)
1041                            .on_click(cx.listener(|_, _, cx| {
1042                                cx.open_url(ZED_AI_URL)
1043                            })))
1044                            .child(
1045                        Button::new("upgrade", "Upgrade")
1046                            .style(ButtonStyle::Subtle)
1047                            .color(Color::Accent)
1048                            .on_click(cx.listener(|_, _, cx| {
1049                                cx.open_url(ACCOUNT_SETTINGS_URL)
1050                            })))
1051                    },
1052                )
1053        } else {
1054            v_flex()
1055                .gap_6()
1056                .child(Label::new("Use the zed.dev to access language models."))
1057                .child(
1058                    v_flex()
1059                        .gap_2()
1060                        .child(
1061                            Button::new("sign_in", "Sign in")
1062                                .icon_color(Color::Muted)
1063                                .icon(IconName::Github)
1064                                .icon_position(IconPosition::Start)
1065                                .style(ButtonStyle::Filled)
1066                                .full_width()
1067                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
1068                        )
1069                        .child(
1070                            div().flex().w_full().items_center().child(
1071                                Label::new("Sign in to enable collaboration.")
1072                                    .color(Color::Muted)
1073                                    .size(LabelSize::Small),
1074                            ),
1075                        ),
1076                )
1077        }
1078    }
1079}