cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anthropic::AnthropicError;
  8use anyhow::{anyhow, bail, Context as _, Result};
  9use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 10use collections::BTreeMap;
 11use feature_flags::{FeatureFlagAppExt, ZedPro};
 12use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 13use gpui::{
 14    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
 15    Subscription, Task,
 16};
 17use http_client::{AsyncBody, HttpClient, Method, Response};
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use serde_json::value::RawValue;
 21use settings::{Settings, SettingsStore};
 22use smol::{
 23    io::{AsyncReadExt, BufReader},
 24    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 25};
 26use std::{future, sync::Arc};
 27use strum::IntoEnumIterator;
 28use ui::prelude::*;
 29
 30use crate::{LanguageModelAvailability, LanguageModelProvider};
 31
 32use super::anthropic::count_anthropic_tokens;
 33
 /// String wrapped in `LanguageModelProviderId` by this file's `id()` / `provider_id()` methods.
 34pub const PROVIDER_ID: &str = "zed.dev";
 /// String wrapped in `LanguageModelProviderName` by this file's `name()` / `provider_name()` methods.
 35pub const PROVIDER_NAME: &str = "Zed";
 36
 /// Settings for the zed.dev provider.
 ///
 /// NOTE(review): presumably the type behind the `zed_dot_dev` field that
 /// `provided_models` reads from `AllLanguageModelSettings` — confirm.
 37#[derive(Default, Clone, Debug, PartialEq)]
 38pub struct ZedDotDevSettings {
       /// Custom model entries; `provided_models` adds these for staff users,
       /// replacing any built-in model that has the same id.
 39    pub available_models: Vec<AvailableModel>,
 40}
 41
 // Upstream provider of a custom model entry. Serialized and deserialized in
 // lowercase because of `rename_all = "lowercase"`.
 // (Plain `//` comments are used on purpose: this type derives `JsonSchema`,
 // and doc comments could end up in the generated schema.)
 42#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 43#[serde(rename_all = "lowercase")]
 44pub enum AvailableProvider {
       // Mapped to `anthropic::Model::Custom` by `provided_models`.
 45    Anthropic,
       // Mapped to `open_ai::Model::Custom` by `provided_models`.
 46    OpenAi,
       // Mapped to `google_ai::Model::Custom` by `provided_models`.
 47    Google,
 48}
 49
 // A custom model entry from settings; `provided_models` converts each entry
 // into the matching provider's `Model::Custom` variant.
 // (Plain `//` comments are used on purpose: this type derives `JsonSchema`,
 // and doc comments could end up in the generated schema.)
 50#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 51pub struct AvailableModel {
       // Which provider's `Custom` model variant this entry becomes.
 52    provider: AvailableProvider,
       // Copied into the custom model's `name` field.
 53    name: String,
       // Copied into the custom model's `max_tokens` field.
 54    max_tokens: usize,
       // Only forwarded for Anthropic entries (as `tool_override`); the OpenAI
       // and Google conversions in `provided_models` do not use it.
 55    tool_override: Option<String>,
 56}
 57
 /// Language model provider backed by the zed.dev service.
 58pub struct CloudLanguageModelProvider {
       /// Client handle cloned into every `CloudLanguageModel` this provider creates.
 59    client: Arc<Client>,
       /// Token handle cloned into every `CloudLanguageModel` this provider creates.
 60    llm_api_token: LlmApiToken,
       /// Observable state (sign-in status, terms acceptance) exposed via `observable_entity`.
 61    state: gpui::Model<State>,
       /// Task spawned in `new` that copies client status changes into `state`;
       /// stored here but never read.
 62    _maintain_client_status: Task<()>,
 63}
 64
 /// Observable state of the cloud provider: sign-in status and terms-of-service acceptance.
 65pub struct State {
       /// Client used by `authenticate` to authenticate and connect.
 66    client: Arc<Client>,
       /// User store consulted for, and updated with, terms-of-service acceptance.
 67    user_store: Model<UserStore>,
       /// Last client status recorded by the provider's status-watching task.
 68    status: client::Status,
       /// `Some` while an accept-terms request is in flight; reset to `None` when it finishes.
 69    accept_terms: Option<Task<Result<()>>>,
       /// Subscription that calls `cx.notify()` whenever the global `SettingsStore` changes.
 70    _subscription: Subscription,
 71}
 72
 73impl State {
 74    fn is_signed_out(&self) -> bool {
 75        self.status.is_signed_out()
 76    }
 77
 78    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 79        let client = self.client.clone();
 80        cx.spawn(move |this, mut cx| async move {
 81            client.authenticate_and_connect(true, &cx).await?;
 82            this.update(&mut cx, |_, cx| cx.notify())
 83        })
 84    }
 85
 86    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
 87        self.user_store
 88            .read(cx)
 89            .current_user_has_accepted_terms()
 90            .unwrap_or(false)
 91    }
 92
 93    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
 94        let user_store = self.user_store.clone();
 95        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
 96            let _ = user_store
 97                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
 98                .await;
 99            this.update(&mut cx, |this, cx| {
100                this.accept_terms = None;
101                cx.notify()
102            })
103        }));
104    }
105}
106
107impl CloudLanguageModelProvider {
108    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
109        let mut status_rx = client.status();
110        let status = *status_rx.borrow();
111
112        let state = cx.new_model(|cx| State {
113            client: client.clone(),
114            user_store,
115            status,
116            accept_terms: None,
117            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
118                cx.notify();
119            }),
120        });
121
122        let state_ref = state.downgrade();
123        let maintain_client_status = cx.spawn(|mut cx| async move {
124            while let Some(status) = status_rx.next().await {
125                if let Some(this) = state_ref.upgrade() {
126                    _ = this.update(&mut cx, |this, cx| {
127                        if this.status != status {
128                            this.status = status;
129                            cx.notify();
130                        }
131                    });
132                } else {
133                    break;
134                }
135            }
136        });
137
138        Self {
139            client,
140            state,
141            llm_api_token: LlmApiToken::default(),
142            _maintain_client_status: maintain_client_status,
143        }
144    }
145}
146
147impl LanguageModelProviderState for CloudLanguageModelProvider {
148    type ObservableEntity = State;
149
150    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
151        Some(self.state.clone())
152    }
153}
154
155impl LanguageModelProvider for CloudLanguageModelProvider {
156    fn id(&self) -> LanguageModelProviderId {
157        LanguageModelProviderId(PROVIDER_ID.into())
158    }
159
160    fn name(&self) -> LanguageModelProviderName {
161        LanguageModelProviderName(PROVIDER_NAME.into())
162    }
163
164    fn icon(&self) -> IconName {
165        IconName::AiZed
166    }
167
168    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
169        let mut models = BTreeMap::default();
170
171        if cx.is_staff() {
172            for model in anthropic::Model::iter() {
173                if !matches!(model, anthropic::Model::Custom { .. }) {
174                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
175                }
176            }
177            for model in open_ai::Model::iter() {
178                if !matches!(model, open_ai::Model::Custom { .. }) {
179                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
180                }
181            }
182            for model in google_ai::Model::iter() {
183                if !matches!(model, google_ai::Model::Custom { .. }) {
184                    models.insert(model.id().to_string(), CloudModel::Google(model));
185                }
186            }
187            for model in ZedModel::iter() {
188                models.insert(model.id().to_string(), CloudModel::Zed(model));
189            }
190
191            // Override with available models from settings
192            for model in &AllLanguageModelSettings::get_global(cx)
193                .zed_dot_dev
194                .available_models
195            {
196                let model = match model.provider {
197                    AvailableProvider::Anthropic => {
198                        CloudModel::Anthropic(anthropic::Model::Custom {
199                            name: model.name.clone(),
200                            max_tokens: model.max_tokens,
201                            tool_override: model.tool_override.clone(),
202                        })
203                    }
204                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
205                        name: model.name.clone(),
206                        max_tokens: model.max_tokens,
207                    }),
208                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
209                        name: model.name.clone(),
210                        max_tokens: model.max_tokens,
211                    }),
212                };
213                models.insert(model.id().to_string(), model.clone());
214            }
215        } else {
216            models.insert(
217                anthropic::Model::Claude3_5Sonnet.id().to_string(),
218                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
219            );
220        }
221
222        models
223            .into_values()
224            .map(|model| {
225                Arc::new(CloudLanguageModel {
226                    id: LanguageModelId::from(model.id().to_string()),
227                    model,
228                    llm_api_token: self.llm_api_token.clone(),
229                    client: self.client.clone(),
230                    request_limiter: RateLimiter::new(4),
231                }) as Arc<dyn LanguageModel>
232            })
233            .collect()
234    }
235
236    fn is_authenticated(&self, cx: &AppContext) -> bool {
237        !self.state.read(cx).is_signed_out()
238    }
239
240    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
241        Task::ready(Ok(()))
242    }
243
244    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
245        cx.new_view(|_cx| ConfigurationView {
246            state: self.state.clone(),
247        })
248        .into()
249    }
250
251    fn must_accept_terms(&self, cx: &AppContext) -> bool {
252        !self.state.read(cx).has_accepted_terms_of_service(cx)
253    }
254
255    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
256        let state = self.state.read(cx);
257
258        let terms = [(
259            "anthropic_terms_of_service",
260            "Anthropic Terms of Service",
261            "https://www.anthropic.com/legal/consumer-terms",
262        )]
263        .map(|(id, label, url)| {
264            Button::new(id, label)
265                .style(ButtonStyle::Subtle)
266                .icon(IconName::ExternalLink)
267                .icon_size(IconSize::XSmall)
268                .icon_color(Color::Muted)
269                .on_click(move |_, cx| cx.open_url(url))
270        });
271
272        if state.has_accepted_terms_of_service(cx) {
273            None
274        } else {
275            let disabled = state.accept_terms.is_some();
276            Some(
277                v_flex()
278                    .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
279                    .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
280                    .child(v_flex().m_2().gap_1().children(terms))
281                    .child(
282                        h_flex().justify_end().mt_1().child(
283                            Button::new("accept_terms", "Accept")
284                                .disabled(disabled)
285                                .on_click({
286                                    let state = self.state.downgrade();
287                                    move |_, cx| {
288                                        state
289                                            .update(cx, |state, cx| {
290                                                state.accept_terms_of_service(cx)
291                                            })
292                                            .ok();
293                                    }
294                                }),
295                        ),
296                    )
297                    .into_any(),
298            )
299        }
300    }
301
302    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
303        Task::ready(Ok(()))
304    }
305}
306
/// A single model served through the zed.dev provider.
307pub struct CloudLanguageModel {
       /// Identifier built from the wrapped model's id string.
308    id: LanguageModelId,
       /// Underlying model; the trait methods in this file `match` on it to pick
       /// the request format.
309    model: CloudModel,
       /// Token handle cloned from the provider when the model is created.
310    llm_api_token: LlmApiToken,
       /// Client used for the `/completion` HTTP requests and the Google
       /// token-count request.
311    client: Arc<Client>,
       /// Limiter every completion request is run through; the provider builds
       /// it with `RateLimiter::new(4)`.
312    request_limiter: RateLimiter,
313}
314
/// Cloneable handle to an optional token string behind an async `RwLock`.
/// Clones share the same slot through the `Arc`; the default value holds no token.
/// `perform_llm_completion` calls `acquire` and `refresh` on it, which are
/// defined outside this section of the file.
315#[derive(Clone, Default)]
316struct LlmApiToken(Arc<RwLock<Option<String>>>);
317
318impl CloudLanguageModel {
319    async fn perform_llm_completion(
320        client: Arc<Client>,
321        llm_api_token: LlmApiToken,
322        body: PerformCompletionParams,
323    ) -> Result<Response<AsyncBody>> {
324        let http_client = &client.http_client();
325
326        let mut token = llm_api_token.acquire(&client).await?;
327        let mut did_retry = false;
328
329        let response = loop {
330            let request = http_client::Request::builder()
331                .method(Method::POST)
332                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
333                .header("Content-Type", "application/json")
334                .header("Authorization", format!("Bearer {token}"))
335                .body(serde_json::to_string(&body)?.into())?;
336            let mut response = http_client.send(request).await?;
337            if response.status().is_success() {
338                break response;
339            } else if !did_retry
340                && response
341                    .headers()
342                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
343                    .is_some()
344            {
345                did_retry = true;
346                token = llm_api_token.refresh(&client).await?;
347            } else {
348                let mut body = String::new();
349                response.body_mut().read_to_string(&mut body).await?;
350                break Err(anyhow!(
351                    "cloud language model completion failed with status {}: {body}",
352                    response.status()
353                ))?;
354            }
355        };
356
357        Ok(response)
358    }
359}
360
361impl LanguageModel for CloudLanguageModel {
362    fn id(&self) -> LanguageModelId {
363        self.id.clone()
364    }
365
366    fn name(&self) -> LanguageModelName {
367        LanguageModelName::from(self.model.display_name().to_string())
368    }
369
370    fn provider_id(&self) -> LanguageModelProviderId {
371        LanguageModelProviderId(PROVIDER_ID.into())
372    }
373
374    fn provider_name(&self) -> LanguageModelProviderName {
375        LanguageModelProviderName(PROVIDER_NAME.into())
376    }
377
378    fn telemetry_id(&self) -> String {
379        format!("zed.dev/{}", self.model.id())
380    }
381
382    fn availability(&self) -> LanguageModelAvailability {
383        self.model.availability()
384    }
385
386    fn max_token_count(&self) -> usize {
387        self.model.max_token_count()
388    }
389
390    fn count_tokens(
391        &self,
392        request: LanguageModelRequest,
393        cx: &AppContext,
394    ) -> BoxFuture<'static, Result<usize>> {
395        match self.model.clone() {
396            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
397            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
398            CloudModel::Google(model) => {
399                let client = self.client.clone();
400                let request = request.into_google(model.id().into());
401                let request = google_ai::CountTokensRequest {
402                    contents: request.contents,
403                };
404                async move {
405                    let request = serde_json::to_string(&request)?;
406                    let response = client
407                        .request(proto::CountLanguageModelTokens {
408                            provider: proto::LanguageModelProvider::Google as i32,
409                            request,
410                        })
411                        .await?;
412                    Ok(response.token_count as usize)
413                }
414                .boxed()
415            }
416            CloudModel::Zed(_) => {
417                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
418            }
419        }
420    }
421
422    fn stream_completion(
423        &self,
424        request: LanguageModelRequest,
425        _cx: &AsyncAppContext,
426    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
427        match &self.model {
428            CloudModel::Anthropic(model) => {
429                let request = request.into_anthropic(model.id().into());
430                let client = self.client.clone();
431                let llm_api_token = self.llm_api_token.clone();
432                let future = self.request_limiter.stream(async move {
433                    let response = Self::perform_llm_completion(
434                        client.clone(),
435                        llm_api_token,
436                        PerformCompletionParams {
437                            provider: client::LanguageModelProvider::Anthropic,
438                            model: request.model.clone(),
439                            provider_request: RawValue::from_string(serde_json::to_string(
440                                &request,
441                            )?)?,
442                        },
443                    )
444                    .await?;
445                    let body = BufReader::new(response.into_body());
446                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
447                        let mut buffer = String::new();
448                        match body.read_line(&mut buffer).await {
449                            Ok(0) => Ok(None),
450                            Ok(_) => {
451                                let event: anthropic::Event = serde_json::from_str(&buffer)
452                                    .context("failed to parse Anthropic event")?;
453                                Ok(Some((event, body)))
454                            }
455                            Err(err) => Err(AnthropicError::Other(err.into())),
456                        }
457                    });
458
459                    Ok(anthropic::extract_text_from_events(stream))
460                });
461                async move {
462                    Ok(future
463                        .await?
464                        .map(|result| result.map_err(|err| anyhow!(err)))
465                        .boxed())
466                }
467                .boxed()
468            }
469            CloudModel::OpenAi(model) => {
470                let client = self.client.clone();
471                let request = request.into_open_ai(model.id().into());
472                let llm_api_token = self.llm_api_token.clone();
473                let future = self.request_limiter.stream(async move {
474                    let response = Self::perform_llm_completion(
475                        client.clone(),
476                        llm_api_token,
477                        PerformCompletionParams {
478                            provider: client::LanguageModelProvider::OpenAi,
479                            model: request.model.clone(),
480                            provider_request: RawValue::from_string(serde_json::to_string(
481                                &request,
482                            )?)?,
483                        },
484                    )
485                    .await?;
486                    let body = BufReader::new(response.into_body());
487                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
488                        let mut buffer = String::new();
489                        match body.read_line(&mut buffer).await {
490                            Ok(0) => Ok(None),
491                            Ok(_) => {
492                                let event: open_ai::ResponseStreamEvent =
493                                    serde_json::from_str(&buffer)?;
494                                Ok(Some((event, body)))
495                            }
496                            Err(e) => Err(e.into()),
497                        }
498                    });
499
500                    Ok(open_ai::extract_text_from_events(stream))
501                });
502                async move { Ok(future.await?.boxed()) }.boxed()
503            }
504            CloudModel::Google(model) => {
505                let client = self.client.clone();
506                let request = request.into_google(model.id().into());
507                let llm_api_token = self.llm_api_token.clone();
508                let future = self.request_limiter.stream(async move {
509                    let response = Self::perform_llm_completion(
510                        client.clone(),
511                        llm_api_token,
512                        PerformCompletionParams {
513                            provider: client::LanguageModelProvider::Google,
514                            model: request.model.clone(),
515                            provider_request: RawValue::from_string(serde_json::to_string(
516                                &request,
517                            )?)?,
518                        },
519                    )
520                    .await?;
521                    let body = BufReader::new(response.into_body());
522                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
523                        let mut buffer = String::new();
524                        match body.read_line(&mut buffer).await {
525                            Ok(0) => Ok(None),
526                            Ok(_) => {
527                                let event: google_ai::GenerateContentResponse =
528                                    serde_json::from_str(&buffer)?;
529                                Ok(Some((event, body)))
530                            }
531                            Err(e) => Err(e.into()),
532                        }
533                    });
534
535                    Ok(google_ai::extract_text_from_events(stream))
536                });
537                async move { Ok(future.await?.boxed()) }.boxed()
538            }
539            CloudModel::Zed(model) => {
540                let client = self.client.clone();
541                let mut request = request.into_open_ai(model.id().into());
542                request.max_tokens = Some(4000);
543                let llm_api_token = self.llm_api_token.clone();
544                let future = self.request_limiter.stream(async move {
545                    let response = Self::perform_llm_completion(
546                        client.clone(),
547                        llm_api_token,
548                        PerformCompletionParams {
549                            provider: client::LanguageModelProvider::Zed,
550                            model: request.model.clone(),
551                            provider_request: RawValue::from_string(serde_json::to_string(
552                                &request,
553                            )?)?,
554                        },
555                    )
556                    .await?;
557                    let body = BufReader::new(response.into_body());
558                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
559                        let mut buffer = String::new();
560                        match body.read_line(&mut buffer).await {
561                            Ok(0) => Ok(None),
562                            Ok(_) => {
563                                let event: open_ai::ResponseStreamEvent =
564                                    serde_json::from_str(&buffer)?;
565                                Ok(Some((event, body)))
566                            }
567                            Err(e) => Err(e.into()),
568                        }
569                    });
570
571                    Ok(open_ai::extract_text_from_events(stream))
572                });
573                async move { Ok(future.await?.boxed()) }.boxed()
574            }
575        }
576    }
577
578    fn use_any_tool(
579        &self,
580        request: LanguageModelRequest,
581        tool_name: String,
582        tool_description: String,
583        input_schema: serde_json::Value,
584        _cx: &AsyncAppContext,
585    ) -> BoxFuture<'static, Result<serde_json::Value>> {
586        match &self.model {
587            CloudModel::Anthropic(model) => {
588                let client = self.client.clone();
589                let mut request = request.into_anthropic(model.tool_model_id().into());
590                request.tool_choice = Some(anthropic::ToolChoice::Tool {
591                    name: tool_name.clone(),
592                });
593                request.tools = vec![anthropic::Tool {
594                    name: tool_name.clone(),
595                    description: tool_description,
596                    input_schema,
597                }];
598
599                let llm_api_token = self.llm_api_token.clone();
600                self.request_limiter
601                    .run(async move {
602                        let response = Self::perform_llm_completion(
603                            client.clone(),
604                            llm_api_token,
605                            PerformCompletionParams {
606                                provider: client::LanguageModelProvider::Anthropic,
607                                model: request.model.clone(),
608                                provider_request: RawValue::from_string(serde_json::to_string(
609                                    &request,
610                                )?)?,
611                            },
612                        )
613                        .await?;
614
615                        let mut tool_use_index = None;
616                        let mut tool_input = String::new();
617                        let mut body = BufReader::new(response.into_body());
618                        let mut line = String::new();
619                        while body.read_line(&mut line).await? > 0 {
620                            let event: anthropic::Event = serde_json::from_str(&line)?;
621                            line.clear();
622
623                            match event {
624                                anthropic::Event::ContentBlockStart {
625                                    content_block,
626                                    index,
627                                } => {
628                                    if let anthropic::Content::ToolUse { name, .. } = content_block
629                                    {
630                                        if name == tool_name {
631                                            tool_use_index = Some(index);
632                                        }
633                                    }
634                                }
635                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
636                                {
637                                    anthropic::ContentDelta::TextDelta { .. } => {}
638                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
639                                        if Some(index) == tool_use_index {
640                                            tool_input.push_str(&partial_json);
641                                        }
642                                    }
643                                },
644                                anthropic::Event::ContentBlockStop { index } => {
645                                    if Some(index) == tool_use_index {
646                                        return Ok(serde_json::from_str(&tool_input)?);
647                                    }
648                                }
649                                _ => {}
650                            }
651                        }
652
653                        if tool_use_index.is_some() {
654                            Err(anyhow!("tool content incomplete"))
655                        } else {
656                            Err(anyhow!("tool not used"))
657                        }
658                    })
659                    .boxed()
660            }
661            CloudModel::OpenAi(model) => {
662                let mut request = request.into_open_ai(model.id().into());
663                let client = self.client.clone();
664                let mut function = open_ai::FunctionDefinition {
665                    name: tool_name.clone(),
666                    description: None,
667                    parameters: None,
668                };
669                let func = open_ai::ToolDefinition::Function {
670                    function: function.clone(),
671                };
672                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
673                // Fill in description and params separately, as they're not needed for tool_choice field.
674                function.description = Some(tool_description);
675                function.parameters = Some(input_schema);
676                request.tools = vec![open_ai::ToolDefinition::Function { function }];
677
678                let llm_api_token = self.llm_api_token.clone();
679                self.request_limiter
680                    .run(async move {
681                        let response = Self::perform_llm_completion(
682                            client.clone(),
683                            llm_api_token,
684                            PerformCompletionParams {
685                                provider: client::LanguageModelProvider::OpenAi,
686                                model: request.model.clone(),
687                                provider_request: RawValue::from_string(serde_json::to_string(
688                                    &request,
689                                )?)?,
690                            },
691                        )
692                        .await?;
693
694                        let mut body = BufReader::new(response.into_body());
695                        let mut line = String::new();
696                        let mut load_state = None;
697
698                        while body.read_line(&mut line).await? > 0 {
699                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
700                            line.clear();
701
702                            for choice in part.choices {
703                                let Some(tool_calls) = choice.delta.tool_calls else {
704                                    continue;
705                                };
706
707                                for call in tool_calls {
708                                    if let Some(func) = call.function {
709                                        if func.name.as_deref() == Some(tool_name.as_str()) {
710                                            load_state = Some((String::default(), call.index));
711                                        }
712                                        if let Some((arguments, (output, index))) =
713                                            func.arguments.zip(load_state.as_mut())
714                                        {
715                                            if call.index == *index {
716                                                output.push_str(&arguments);
717                                            }
718                                        }
719                                    }
720                                }
721                            }
722                        }
723
724                        if let Some((arguments, _)) = load_state {
725                            return Ok(serde_json::from_str(&arguments)?);
726                        } else {
727                            bail!("tool not used");
728                        }
729                    })
730                    .boxed()
731            }
732            CloudModel::Google(_) => {
733                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
734            }
735            CloudModel::Zed(model) => {
736                // All Zed models are OpenAI-based at the time of writing.
737                let mut request = request.into_open_ai(model.id().into());
738                let client = self.client.clone();
739                let mut function = open_ai::FunctionDefinition {
740                    name: tool_name.clone(),
741                    description: None,
742                    parameters: None,
743                };
744                let func = open_ai::ToolDefinition::Function {
745                    function: function.clone(),
746                };
747                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
748                // Fill in description and params separately, as they're not needed for tool_choice field.
749                function.description = Some(tool_description);
750                function.parameters = Some(input_schema);
751                request.tools = vec![open_ai::ToolDefinition::Function { function }];
752
753                let llm_api_token = self.llm_api_token.clone();
754                self.request_limiter
755                    .run(async move {
756                        let response = Self::perform_llm_completion(
757                            client.clone(),
758                            llm_api_token,
759                            PerformCompletionParams {
760                                provider: client::LanguageModelProvider::Zed,
761                                model: request.model.clone(),
762                                provider_request: RawValue::from_string(serde_json::to_string(
763                                    &request,
764                                )?)?,
765                            },
766                        )
767                        .await?;
768
769                        let mut body = BufReader::new(response.into_body());
770                        let mut line = String::new();
771                        let mut load_state = None;
772
773                        while body.read_line(&mut line).await? > 0 {
774                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
775                            line.clear();
776
777                            for choice in part.choices {
778                                let Some(tool_calls) = choice.delta.tool_calls else {
779                                    continue;
780                                };
781
782                                for call in tool_calls {
783                                    if let Some(func) = call.function {
784                                        if func.name.as_deref() == Some(tool_name.as_str()) {
785                                            load_state = Some((String::default(), call.index));
786                                        }
787                                        if let Some((arguments, (output, index))) =
788                                            func.arguments.zip(load_state.as_mut())
789                                        {
790                                            if call.index == *index {
791                                                output.push_str(&arguments);
792                                            }
793                                        }
794                                    }
795                                }
796                            }
797                        }
798                        if let Some((arguments, _)) = load_state {
799                            return Ok(serde_json::from_str(&arguments)?);
800                        } else {
801                            bail!("tool not used");
802                        }
803                    })
804                    .boxed()
805            }
806        }
807    }
808}
809
810impl LlmApiToken {
811    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
812        let lock = self.0.upgradable_read().await;
813        if let Some(token) = lock.as_ref() {
814            Ok(token.to_string())
815        } else {
816            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
817        }
818    }
819
820    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
821        Self::fetch(self.0.write().await, &client).await
822    }
823
824    async fn fetch<'a>(
825        mut lock: RwLockWriteGuard<'a, Option<String>>,
826        client: &Arc<Client>,
827    ) -> Result<String> {
828        let response = client.request(proto::GetLlmToken {}).await?;
829        *lock = Some(response.token.clone());
830        Ok(response.token.clone())
831    }
832}
833
834struct ConfigurationView {
835    state: gpui::Model<State>,
836}
837
838impl ConfigurationView {
839    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
840        self.state.update(cx, |state, cx| {
841            state.authenticate(cx).detach_and_log_err(cx);
842        });
843        cx.notify();
844    }
845}
846
847impl Render for ConfigurationView {
848    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
849        const ZED_AI_URL: &str = "https://zed.dev/ai";
850        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
851
852        let is_connected = !self.state.read(cx).is_signed_out();
853        let plan = self.state.read(cx).user_store.read(cx).current_plan();
854        let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
855
856        let is_pro = plan == Some(proto::Plan::ZedPro);
857
858        if is_connected {
859            v_flex()
860                .gap_3()
861                .max_w_4_5()
862                .when(must_accept_terms, |this| {
863                    this.child(Label::new(
864                        "You must accept the terms of service to use this provider.",
865                    ))
866                })
867                .child(Label::new(
868                    if is_pro {
869                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
870                    } else {
871                        "You have basic access to models from Anthropic through the Zed AI Free plan."
872                    }))
873                .children(if is_pro {
874                    Some(
875                        h_flex().child(
876                            Button::new("manage_settings", "Manage Subscription")
877                                .style(ButtonStyle::Filled)
878                                .on_click(
879                                    cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
880                                ),
881                        ),
882                    )
883                } else if cx.has_flag::<ZedPro>() {
884                    Some(
885                        h_flex()
886                            .gap_2()
887                            .child(
888                                Button::new("learn_more", "Learn more")
889                                    .style(ButtonStyle::Subtle)
890                                    .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
891                            )
892                            .child(
893                                Button::new("upgrade", "Upgrade")
894                                    .style(ButtonStyle::Subtle)
895                                    .color(Color::Accent)
896                                    .on_click(
897                                        cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
898                                    ),
899                            ),
900                    )
901                } else {
902                    None
903                })
904        } else {
905            v_flex()
906                .gap_6()
907                .child(Label::new("Use the zed.dev to access language models."))
908                .child(
909                    v_flex()
910                        .gap_2()
911                        .child(
912                            Button::new("sign_in", "Sign in")
913                                .icon_color(Color::Muted)
914                                .icon(IconName::Github)
915                                .icon_position(IconPosition::Start)
916                                .style(ButtonStyle::Filled)
917                                .full_width()
918                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
919                        )
920                        .child(
921                            div().flex().w_full().items_center().child(
922                                Label::new("Sign in to enable collaboration.")
923                                    .color(Color::Muted)
924                                    .size(LabelSize::Small),
925                            ),
926                        ),
927                )
928        }
929    }
930}