cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest,
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use client::Client;
  9use collections::BTreeMap;
 10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 11use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::{Settings, SettingsStore};
 15use std::{future, sync::Arc};
 16use strum::IntoEnumIterator;
 17use ui::prelude::*;
 18
 19use crate::LanguageModelProvider;
 20
 21use super::anthropic::count_anthropic_tokens;
 22
 23pub const PROVIDER_ID: &str = "zed.dev";
 24pub const PROVIDER_NAME: &str = "zed.dev";
 25
 26#[derive(Default, Clone, Debug, PartialEq)]
 27pub struct ZedDotDevSettings {
 28    pub available_models: Vec<AvailableModel>,
 29}
 30
 31#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 32#[serde(rename_all = "lowercase")]
 33pub enum AvailableProvider {
 34    Anthropic,
 35    OpenAi,
 36    Google,
 37}
 38
 39#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 40pub struct AvailableModel {
 41    provider: AvailableProvider,
 42    name: String,
 43    max_tokens: usize,
 44}
 45
 46pub struct CloudLanguageModelProvider {
 47    client: Arc<Client>,
 48    state: gpui::Model<State>,
 49    _maintain_client_status: Task<()>,
 50}
 51
 52struct State {
 53    client: Arc<Client>,
 54    status: client::Status,
 55    _subscription: Subscription,
 56}
 57
 58impl State {
 59    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 60        let client = self.client.clone();
 61        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 62    }
 63}
 64
 65impl CloudLanguageModelProvider {
 66    pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
 67        let mut status_rx = client.status();
 68        let status = *status_rx.borrow();
 69
 70        let state = cx.new_model(|cx| State {
 71            client: client.clone(),
 72            status,
 73            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 74                cx.notify();
 75            }),
 76        });
 77
 78        let state_ref = state.downgrade();
 79        let maintain_client_status = cx.spawn(|mut cx| async move {
 80            while let Some(status) = status_rx.next().await {
 81                if let Some(this) = state_ref.upgrade() {
 82                    _ = this.update(&mut cx, |this, cx| {
 83                        this.status = status;
 84                        cx.notify();
 85                    });
 86                } else {
 87                    break;
 88                }
 89            }
 90        });
 91
 92        Self {
 93            client,
 94            state,
 95            _maintain_client_status: maintain_client_status,
 96        }
 97    }
 98}
 99
100impl LanguageModelProviderState for CloudLanguageModelProvider {
101    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
102        Some(cx.observe(&self.state, |_, _, cx| {
103            cx.notify();
104        }))
105    }
106}
107
108impl LanguageModelProvider for CloudLanguageModelProvider {
109    fn id(&self) -> LanguageModelProviderId {
110        LanguageModelProviderId(PROVIDER_ID.into())
111    }
112
113    fn name(&self) -> LanguageModelProviderName {
114        LanguageModelProviderName(PROVIDER_NAME.into())
115    }
116
117    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
118        let mut models = BTreeMap::default();
119
120        for model in anthropic::Model::iter() {
121            if !matches!(model, anthropic::Model::Custom { .. }) {
122                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
123            }
124        }
125        for model in open_ai::Model::iter() {
126            if !matches!(model, open_ai::Model::Custom { .. }) {
127                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
128            }
129        }
130        for model in google_ai::Model::iter() {
131            if !matches!(model, google_ai::Model::Custom { .. }) {
132                models.insert(model.id().to_string(), CloudModel::Google(model));
133            }
134        }
135
136        // Override with available models from settings
137        for model in &AllLanguageModelSettings::get_global(cx)
138            .zed_dot_dev
139            .available_models
140        {
141            let model = match model.provider {
142                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
143                    name: model.name.clone(),
144                    max_tokens: model.max_tokens,
145                }),
146                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
147                    name: model.name.clone(),
148                    max_tokens: model.max_tokens,
149                }),
150                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
151                    name: model.name.clone(),
152                    max_tokens: model.max_tokens,
153                }),
154            };
155            models.insert(model.id().to_string(), model.clone());
156        }
157
158        models
159            .into_values()
160            .map(|model| {
161                Arc::new(CloudLanguageModel {
162                    id: LanguageModelId::from(model.id().to_string()),
163                    model,
164                    client: self.client.clone(),
165                }) as Arc<dyn LanguageModel>
166            })
167            .collect()
168    }
169
170    fn is_authenticated(&self, cx: &AppContext) -> bool {
171        self.state.read(cx).status.is_connected()
172    }
173
174    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
175        self.state.read(cx).authenticate(cx)
176    }
177
178    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
179        cx.new_view(|_cx| AuthenticationPrompt {
180            state: self.state.clone(),
181        })
182        .into()
183    }
184
185    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
186        Task::ready(Ok(()))
187    }
188}
189
190pub struct CloudLanguageModel {
191    id: LanguageModelId,
192    model: CloudModel,
193    client: Arc<Client>,
194}
195
196impl LanguageModel for CloudLanguageModel {
197    fn id(&self) -> LanguageModelId {
198        self.id.clone()
199    }
200
201    fn name(&self) -> LanguageModelName {
202        LanguageModelName::from(self.model.display_name().to_string())
203    }
204
205    fn provider_id(&self) -> LanguageModelProviderId {
206        LanguageModelProviderId(PROVIDER_ID.into())
207    }
208
209    fn provider_name(&self) -> LanguageModelProviderName {
210        LanguageModelProviderName(PROVIDER_NAME.into())
211    }
212
213    fn telemetry_id(&self) -> String {
214        format!("zed.dev/{}", self.model.id())
215    }
216
217    fn max_token_count(&self) -> usize {
218        self.model.max_token_count()
219    }
220
221    fn count_tokens(
222        &self,
223        request: LanguageModelRequest,
224        cx: &AppContext,
225    ) -> BoxFuture<'static, Result<usize>> {
226        match self.model.clone() {
227            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
228            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
229            CloudModel::Google(model) => {
230                let client = self.client.clone();
231                let request = request.into_google(model.id().into());
232                let request = google_ai::CountTokensRequest {
233                    contents: request.contents,
234                };
235                async move {
236                    let request = serde_json::to_string(&request)?;
237                    let response = client
238                        .request(proto::CountLanguageModelTokens {
239                            provider: proto::LanguageModelProvider::Google as i32,
240                            request,
241                        })
242                        .await?;
243                    Ok(response.token_count as usize)
244                }
245                .boxed()
246            }
247        }
248    }
249
250    fn stream_completion(
251        &self,
252        request: LanguageModelRequest,
253        _: &AsyncAppContext,
254    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
255        match &self.model {
256            CloudModel::Anthropic(model) => {
257                let client = self.client.clone();
258                let request = request.into_anthropic(model.id().into());
259                async move {
260                    let request = serde_json::to_string(&request)?;
261                    let stream = client
262                        .request_stream(proto::StreamCompleteWithLanguageModel {
263                            provider: proto::LanguageModelProvider::Anthropic as i32,
264                            request,
265                        })
266                        .await?;
267                    Ok(anthropic::extract_text_from_events(
268                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
269                    )
270                    .boxed())
271                }
272                .boxed()
273            }
274            CloudModel::OpenAi(model) => {
275                let client = self.client.clone();
276                let request = request.into_open_ai(model.id().into());
277                async move {
278                    let request = serde_json::to_string(&request)?;
279                    let stream = client
280                        .request_stream(proto::StreamCompleteWithLanguageModel {
281                            provider: proto::LanguageModelProvider::OpenAi as i32,
282                            request,
283                        })
284                        .await?;
285                    Ok(open_ai::extract_text_from_events(
286                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
287                    )
288                    .boxed())
289                }
290                .boxed()
291            }
292            CloudModel::Google(model) => {
293                let client = self.client.clone();
294                let request = request.into_google(model.id().into());
295                async move {
296                    let request = serde_json::to_string(&request)?;
297                    let stream = client
298                        .request_stream(proto::StreamCompleteWithLanguageModel {
299                            provider: proto::LanguageModelProvider::Google as i32,
300                            request,
301                        })
302                        .await?;
303                    Ok(google_ai::extract_text_from_events(
304                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
305                    )
306                    .boxed())
307                }
308                .boxed()
309            }
310        }
311    }
312
313    fn use_tool(
314        &self,
315        request: LanguageModelRequest,
316        tool_name: String,
317        tool_description: String,
318        input_schema: serde_json::Value,
319        _cx: &AsyncAppContext,
320    ) -> BoxFuture<'static, Result<serde_json::Value>> {
321        match &self.model {
322            CloudModel::Anthropic(model) => {
323                let client = self.client.clone();
324                let mut request = request.into_anthropic(model.id().into());
325                request.tool_choice = Some(anthropic::ToolChoice::Tool {
326                    name: tool_name.clone(),
327                });
328                request.tools = vec![anthropic::Tool {
329                    name: tool_name.clone(),
330                    description: tool_description,
331                    input_schema,
332                }];
333
334                async move {
335                    let request = serde_json::to_string(&request)?;
336                    let response = client
337                        .request(proto::CompleteWithLanguageModel {
338                            provider: proto::LanguageModelProvider::Anthropic as i32,
339                            request,
340                        })
341                        .await?;
342                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
343                    response
344                        .content
345                        .into_iter()
346                        .find_map(|content| {
347                            if let anthropic::Content::ToolUse { name, input, .. } = content {
348                                if name == tool_name {
349                                    Some(input)
350                                } else {
351                                    None
352                                }
353                            } else {
354                                None
355                            }
356                        })
357                        .context("tool not used")
358                }
359                .boxed()
360            }
361            CloudModel::OpenAi(_) => {
362                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
363            }
364            CloudModel::Google(_) => {
365                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
366            }
367        }
368    }
369}
370
371struct AuthenticationPrompt {
372    state: gpui::Model<State>,
373}
374
375impl Render for AuthenticationPrompt {
376    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
377        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
378
379        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
380            v_flex()
381                .gap_2()
382                .child(
383                    Button::new("sign_in", "Sign in")
384                        .icon_color(Color::Muted)
385                        .icon(IconName::Github)
386                        .icon_position(IconPosition::Start)
387                        .style(ButtonStyle::Filled)
388                        .full_width()
389                        .on_click(cx.listener(move |this, _, cx| {
390                            this.state.update(cx, |provider, cx| {
391                                provider.authenticate(cx).detach_and_log_err(cx);
392                                cx.notify();
393                            });
394                        })),
395                )
396                .child(
397                    div().flex().w_full().items_center().child(
398                        Label::new("Sign in to enable collaboration.")
399                            .color(Color::Muted)
400                            .size(LabelSize::Small),
401                    ),
402                ),
403        )
404    }
405}