cloud.rs

use super::anthropic::count_anthropic_tokens;
use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelAvailability,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter,
    ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use client::{Client, UserStore};
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{
    AnyView, AppContext, AsyncAppContext, FocusHandle, Model, ModelContext, Subscription, Task,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

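/// Settings for the zed.dev language model provider, including any custom models configured by the user.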
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

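/// The upstream provider that serves an [`AvailableModel`].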
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

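/// A custom model exposed through the zed.dev provider via the `available_models` setting.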
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
    tool_override: Option<String>,
}

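/// A language model provider backed by Zed's hosted models at zed.dev.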
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

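/// Connection state shared by the provider and its configuration view.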
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    _subscription: Subscription,
}

impl State {
    fn is_connected(&self) -> bool {
        self.status.is_connected()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }
}

impl CloudLanguageModelProvider {
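    /// Creates the provider and spawns a task that keeps [`State::status`] in sync with the
    /// client's connection status.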
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

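        // Add every built-in model for each provider, skipping Custom variants (those come from settings).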
        for model in anthropic::Model::iter() {
            if !matches!(model, anthropic::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
            }
        }
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
            }
        }
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::Google(model));
            }
        }
        for model in ZedModel::iter() {
            models.insert(model.id().to_string(), CloudModel::Zed(model));
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).status.is_connected()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
        let view = cx
            .new_view(|_cx| ConfigurationView {
                state: self.state.clone(),
            })
            .into();
        (view, None)
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

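/// A single hosted model exposed by [`CloudLanguageModelProvider`].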
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

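    /// Estimates the token count for a request. Anthropic, OpenAI, and Zed models delegate to the
    /// shared token-counting helpers, while Google models are counted server-side via the
    /// `CountLanguageModelTokens` request.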
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

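    /// Streams a completion by serializing the provider-specific request and forwarding it over
    /// the `StreamCompleteWithLanguageModel` RPC, subject to the per-model rate limiter.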
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Anthropic as i32,
                            request,
                        })
                        .await?;
                    Ok(anthropic::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::OpenAi as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(google_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Zed as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

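    /// Forces the model to call the named tool and returns that tool's input. Currently only
    /// implemented for Anthropic models; other providers return an error.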
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(_) => {
                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
            }
        }
    }
}

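/// View returned by [`CloudLanguageModelProvider::configuration_view`], showing connection status
/// and sign-in controls for the zed.dev provider.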
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = self.state.read(cx).is_connected();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let is_pro = plan == Some(proto::Plan::ZedPro);

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .child(Label::new(if is_pro {
                    "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
                } else {
                    "You have basic access to models from Anthropic, OpenAI, Google, and more through the Zed AI Free plan."
                }))
                .child(if is_pro {
                    h_flex().child(
                        Button::new("manage_settings", "Manage Subscription")
                            .style(ButtonStyle::Filled)
                            .on_click(
                                cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
                            ),
                    )
                } else {
                    h_flex()
                        .gap_2()
                        .child(
                            Button::new("learn_more", "Learn more")
                                .style(ButtonStyle::Subtle)
                                .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                        )
                        .child(
                            Button::new("upgrade", "Upgrade")
                                .style(ButtonStyle::Subtle)
                                .color(Color::Accent)
                                .on_click(
                                    cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
                                ),
                        )
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}