cloud.rs

use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, bail, Context as _, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::BufReader,
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;

use crate::{LanguageModelAvailability, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

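/// Settings for the zed.dev provider, as configured in the user's settings file.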
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

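/// A custom model served through zed.dev, declared under the provider's
/// `available_models` setting. As an illustrative sketch (the model name here
/// is made up), an entry like the following deserializes into this struct:
///
/// ```json
/// {
///     "provider": "anthropic",
///     "name": "some-custom-model",
///     "max_tokens": 200000,
///     "tool_override": null
/// }
/// ```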
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
    tool_override: Option<String>,
}

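/// The language model provider backed by zed.dev's hosted models.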
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

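/// Observable state for the provider: tracks the zed.dev client's connection
/// status and holds the user store used to read the current plan.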
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

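    /// Signs in through the zed.dev client, notifying observers when the
    /// attempt completes.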
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

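        // Keep `State.status` in sync with the client's connection status; the
        // loop ends once the provider (and with it the `State` model) is dropped.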
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

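        // Without the `LanguageModels` feature flag, only Claude 3.5 Sonnet is
        // offered; with it, every provider's built-in model list is enumerated
        // and can be extended from the settings.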
        let is_user = !cx.has_flag::<LanguageModels>();
        if is_user {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        } else {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }

            // Override with available models from settings
            for model in &AllLanguageModelSettings::get_global(cx)
                .zed_dot_dev
                .available_models
            {
                let model = match model.provider {
                    AvailableProvider::Anthropic => {
                        CloudModel::Anthropic(anthropic::Model::Custom {
                            name: model.name.clone(),
                            max_tokens: model.max_tokens,
                            tool_override: model.tool_override.clone(),
                        })
                    }
                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                };
                models.insert(model.id().to_string(), model.clone());
            }
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

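/// Gates the direct zed.dev LLM service. When this flag is off, completions
/// are proxied over RPC instead (see `stream_completion` below).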
struct LlmServiceFeatureFlag;

impl FeatureFlag for LlmServiceFeatureFlag {
    const NAME: &'static str = "llm-service";

    fn enabled_for_staff() -> bool {
        false
    }
}

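/// A single zed.dev-hosted model, authenticated with the shared LLM API token
/// and throttled by a per-model rate limiter.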
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

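/// A lazily fetched LLM API token, shared across models and refreshed when the
/// server reports that it has expired.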
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
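    /// POSTs a completion request to the zed.dev LLM service, refreshing the
    /// API token and retrying once if the server reports it as expired.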
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                bail!(
                    "cloud language model completion failed with status {}",
                    response.status()
                );
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into());
                let client = self.client.clone();

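                // With the LLM service flag enabled, stream newline-delimited
                // JSON events directly from the completion endpoint; otherwise,
                // proxy the request over RPC. The other arms below follow the
                // same split.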
                if cx
                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                    .unwrap_or(false)
                {
                    let llm_api_token = self.llm_api_token.clone();
                    let future = self.request_limiter.stream(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;
                        let body = BufReader::new(response.into_body());
                        let stream =
                            futures::stream::try_unfold(body, move |mut body| async move {
                                let mut buffer = String::new();
                                match body.read_line(&mut buffer).await {
                                    Ok(0) => Ok(None),
                                    Ok(_) => {
                                        let event: anthropic::Event =
                                            serde_json::from_str(&buffer)?;
                                        Ok(Some((event, body)))
                                    }
                                    Err(e) => Err(e.into()),
                                }
                            });

                        Ok(anthropic::extract_text_from_events(stream))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let future = self.request_limiter.stream(async move {
                        let request = serde_json::to_string(&request)?;
                        let stream = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?
                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
                        Ok(anthropic::extract_text_from_events(stream))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());

                if cx
                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                    .unwrap_or(false)
                {
                    let llm_api_token = self.llm_api_token.clone();
                    let future = self.request_limiter.stream(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;
                        let body = BufReader::new(response.into_body());
                        let stream =
                            futures::stream::try_unfold(body, move |mut body| async move {
                                let mut buffer = String::new();
                                match body.read_line(&mut buffer).await {
                                    Ok(0) => Ok(None),
                                    Ok(_) => {
                                        let event: open_ai::ResponseStreamEvent =
                                            serde_json::from_str(&buffer)?;
                                        Ok(Some((event, body)))
                                    }
                                    Err(e) => Err(e.into()),
                                }
                            });

                        Ok(open_ai::extract_text_from_events(stream))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let future = self.request_limiter.stream(async move {
                        let request = serde_json::to_string(&request)?;
                        let stream = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::OpenAi as i32,
                                request,
                            })
                            .await?;
                        Ok(open_ai::extract_text_from_events(
                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());

                if cx
                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                    .unwrap_or(false)
                {
                    let llm_api_token = self.llm_api_token.clone();
                    let future = self.request_limiter.stream(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Google,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;
                        let body = BufReader::new(response.into_body());
                        let stream =
                            futures::stream::try_unfold(body, move |mut body| async move {
                                let mut buffer = String::new();
                                match body.read_line(&mut buffer).await {
                                    Ok(0) => Ok(None),
                                    Ok(_) => {
                                        let event: google_ai::GenerateContentResponse =
                                            serde_json::from_str(&buffer)?;
                                        Ok(Some((event, body)))
                                    }
                                    Err(e) => Err(e.into()),
                                }
                            });

                        Ok(google_ai::extract_text_from_events(stream))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let future = self.request_limiter.stream(async move {
                        let request = serde_json::to_string(&request)?;
                        let stream = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Google as i32,
                                request,
                            })
                            .await?;
                        Ok(google_ai::extract_text_from_events(
                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);

                if cx
                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                    .unwrap_or(false)
                {
                    let llm_api_token = self.llm_api_token.clone();
                    let future = self.request_limiter.stream(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;
                        let body = BufReader::new(response.into_body());
                        let stream =
                            futures::stream::try_unfold(body, move |mut body| async move {
                                let mut buffer = String::new();
                                match body.read_line(&mut buffer).await {
                                    Ok(0) => Ok(None),
                                    Ok(_) => {
                                        let event: open_ai::ResponseStreamEvent =
                                            serde_json::from_str(&buffer)?;
                                        Ok(Some((event, body)))
                                    }
                                    Err(e) => Err(e.into()),
                                }
                            });

                        Ok(open_ai::extract_text_from_events(stream))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let future = self.request_limiter.stream(async move {
                        let request = serde_json::to_string(&request)?;
                        let stream = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Zed as i32,
                                request,
                            })
                            .await?;
                        Ok(open_ai::extract_text_from_events(
                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
        }
    }

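    /// Asks the model to invoke the named tool, returning the tool's JSON
    /// input once the call has fully arrived.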
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in the description and parameters separately, as they're not
                // needed for the `tool_choice` field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];
                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::OpenAi as i32,
                                request,
                            })
                            .await?;
                        // The tool call's arguments arrive streamed across multiple chunks.
                        let mut load_state = None;
                        let mut response = response.map(
                            |item: Result<
                                proto::StreamCompleteWithLanguageModelResponse,
                                anyhow::Error,
                            >| {
                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
                                    serde_json::from_str(&item?.event)?,
                                )
                            },
                        );
                        while let Some(Ok(part)) = response.next().await {
                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        if let Some((arguments, _)) = load_state {
                            Ok(serde_json::from_str(&arguments)?)
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in the description and parameters separately, as they're not
                // needed for the `tool_choice` field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];
                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request_stream(proto::StreamCompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::OpenAi as i32,
                                request,
                            })
                            .await?;
                        // The tool call's arguments arrive streamed across multiple chunks.
                        let mut load_state = None;
                        let mut response = response.map(
                            |item: Result<
                                proto::StreamCompleteWithLanguageModelResponse,
                                anyhow::Error,
                            >| {
                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
                                    serde_json::from_str(&item?.event)?,
                                )
                            },
                        );
                        while let Some(Ok(part)) = response.next().await {
                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        if let Some((arguments, _)) = load_state {
                            Ok(serde_json::from_str(&arguments)?)
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
        }
    }
}

impl LlmApiToken {
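    /// Returns the cached token, or fetches a fresh one under the write lock
    /// if none has been cached yet.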
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

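/// The sign-in and plan-status UI shown when configuring this provider.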
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();

        let is_pro = plan == Some(proto::Plan::ZedPro);

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .child(Label::new(
                    if is_pro {
                        "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
                    } else {
                        "You have basic access to models from Anthropic, OpenAI, Google, and more through the Zed AI Free plan."
                    }))
                .child(
                    if is_pro {
                        h_flex().child(
                            Button::new("manage_settings", "Manage Subscription")
                                .style(ButtonStyle::Filled)
                                .on_click(
                                    cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
                                ),
                        )
                    } else {
                        h_flex()
                            .gap_2()
                            .child(
                                Button::new("learn_more", "Learn more")
                                    .style(ButtonStyle::Subtle)
                                    .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                            )
                            .child(
                                Button::new("upgrade", "Upgrade")
                                    .style(ButtonStyle::Subtle)
                                    .color(Color::Accent)
                                    .on_click(
                                        cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
                                    ),
                            )
                    },
                )
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Sign in to zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}