use super::{anthropic::count_anthropic_tokens, open_ai::count_open_ai_tokens};
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, FocusHandle, ModelContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed AI";

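/// User-configurable settings for the zed.dev provider.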
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

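/// A custom model made available through settings, in addition to the built-in model lists.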
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
    tool_override: Option<String>,
}

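/// A language model provider backed by zed.dev, which proxies requests to
/// Anthropic, OpenAI, and Google models on behalf of the signed-in user.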
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

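/// Observable connection state shared between the provider and its configuration view.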
pub struct State {
    client: Arc<Client>,
    status: client::Status,
    _subscription: Subscription,
}

impl State {
    fn is_connected(&self) -> bool {
        self.status.is_connected()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }
}

impl CloudLanguageModelProvider {
    pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            status,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

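    // Collect the built-in Anthropic, OpenAI, and Google models, then let custom
    // models declared in settings override or extend them (keyed by model id).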
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        for model in anthropic::Model::iter() {
            if !matches!(model, anthropic::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
            }
        }
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
            }
        }
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), CloudModel::Google(model));
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).status.is_connected()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
        let view = cx
            .new_view(|_cx| ConfigurationView {
                state: self.state.clone(),
            })
            .into();
        (view, None)
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

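/// A single model served via zed.dev; requests are throttled by a per-model rate limiter.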
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

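    // Anthropic and OpenAI token counts are delegated to the helpers in the sibling
    // provider modules; Google counts are requested from the zed.dev server via
    // `CountLanguageModelTokens`.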
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

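    // Completions are streamed through the zed.dev server via
    // `StreamCompleteWithLanguageModel`; the rate limiter bounds the number of
    // concurrent requests per model.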
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Anthropic as i32,
                            request,
                        })
                        .await?;
                    Ok(anthropic::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::OpenAi as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(google_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

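    // Forced tool use is only supported for Anthropic models: the request pins
    // `tool_choice` to the named tool and extracts that tool's input from the response.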
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(_) => {
                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
        }
    }
}

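/// The provider's configuration UI, which prompts the user to sign in to zed.dev
/// when disconnected.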
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/settings";

        let is_connected = self.state.read(cx).is_connected();

        let is_pro = false;

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .child(Label::new(
                    if is_pro {
                        "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google through Zed Pro."
                    } else {
                        "You have basic access to models from Anthropic, OpenAI, Google, and more through the Zed AI Free plan."
                    }))
                .child(
                    if is_pro {
                        h_flex().child(
                            Button::new("manage_settings", "Manage Subscription")
                                .style(ButtonStyle::Filled)
                                .on_click(cx.listener(|_, _, cx| {
                                    cx.open_url(ACCOUNT_SETTINGS_URL)
                                })))
                    } else {
                        h_flex()
                            .gap_2()
                            .child(
                                Button::new("learn_more", "Learn more")
                                    .style(ButtonStyle::Subtle)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ZED_AI_URL)
                                    })))
                            .child(
                                Button::new("upgrade", "Upgrade")
                                    .style(ButtonStyle::Subtle)
                                    .color(Color::Accent)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ACCOUNT_SETTINGS_URL)
                                    })))
                    },
                )
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}