1use super::open_ai::count_open_ai_tokens;
2use crate::{
3 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
4 LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
5 LanguageModelProviderState, LanguageModelRequest,
6};
7use anyhow::Result;
8use client::Client;
9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
10use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
11use settings::{Settings, SettingsStore};
12use std::{collections::BTreeMap, sync::Arc};
13use strum::IntoEnumIterator;
14use ui::prelude::*;
15
16use crate::LanguageModelProvider;
17
18use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
19
/// Value returned as the provider id by the zed.dev provider and its models.
pub const PROVIDER_ID: &str = "zed.dev";
/// Value returned as the provider name by the zed.dev provider and its models.
pub const PROVIDER_NAME: &str = "zed.dev";
22
23#[derive(Default, Clone, Debug, PartialEq)]
24pub struct ZedDotDevSettings {
25 pub available_models: Vec<CloudModel>,
26}
27
28pub struct CloudLanguageModelProvider {
29 client: Arc<Client>,
30 state: gpui::Model<State>,
31 _maintain_client_status: Task<()>,
32}
33
34struct State {
35 client: Arc<Client>,
36 status: client::Status,
37 _subscription: Subscription,
38}
39
40impl State {
41 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
42 let client = self.client.clone();
43 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
44 }
45}
46
47impl CloudLanguageModelProvider {
48 pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
49 let mut status_rx = client.status();
50 let status = *status_rx.borrow();
51
52 let state = cx.new_model(|cx| State {
53 client: client.clone(),
54 status,
55 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
56 cx.notify();
57 }),
58 });
59
60 let state_ref = state.downgrade();
61 let maintain_client_status = cx.spawn(|mut cx| async move {
62 while let Some(status) = status_rx.next().await {
63 if let Some(this) = state_ref.upgrade() {
64 _ = this.update(&mut cx, |this, cx| {
65 this.status = status;
66 cx.notify();
67 });
68 } else {
69 break;
70 }
71 }
72 });
73
74 Self {
75 client,
76 state,
77 _maintain_client_status: maintain_client_status,
78 }
79 }
80}
81
82impl LanguageModelProviderState for CloudLanguageModelProvider {
83 fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
84 Some(cx.observe(&self.state, |_, _, cx| {
85 cx.notify();
86 }))
87 }
88}
89
90impl LanguageModelProvider for CloudLanguageModelProvider {
91 fn id(&self) -> LanguageModelProviderId {
92 LanguageModelProviderId(PROVIDER_ID.into())
93 }
94
95 fn name(&self) -> LanguageModelProviderName {
96 LanguageModelProviderName(PROVIDER_NAME.into())
97 }
98
99 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
100 let mut models = BTreeMap::default();
101
102 // Add base models from CloudModel::iter()
103 for model in CloudModel::iter() {
104 if !matches!(model, CloudModel::Custom { .. }) {
105 models.insert(model.id().to_string(), model);
106 }
107 }
108
109 // Override with available models from settings
110 for model in &AllLanguageModelSettings::get_global(cx)
111 .zed_dot_dev
112 .available_models
113 {
114 models.insert(model.id().to_string(), model.clone());
115 }
116
117 models
118 .into_values()
119 .map(|model| {
120 Arc::new(CloudLanguageModel {
121 id: LanguageModelId::from(model.id().to_string()),
122 model,
123 client: self.client.clone(),
124 }) as Arc<dyn LanguageModel>
125 })
126 .collect()
127 }
128
129 fn is_authenticated(&self, cx: &AppContext) -> bool {
130 self.state.read(cx).status.is_connected()
131 }
132
133 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
134 self.state.read(cx).authenticate(cx)
135 }
136
137 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
138 cx.new_view(|_cx| AuthenticationPrompt {
139 state: self.state.clone(),
140 })
141 .into()
142 }
143
144 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
145 Task::ready(Ok(()))
146 }
147}
148
149pub struct CloudLanguageModel {
150 id: LanguageModelId,
151 model: CloudModel,
152 client: Arc<Client>,
153}
154
155impl LanguageModel for CloudLanguageModel {
156 fn id(&self) -> LanguageModelId {
157 self.id.clone()
158 }
159
160 fn name(&self) -> LanguageModelName {
161 LanguageModelName::from(self.model.display_name().to_string())
162 }
163
164 fn provider_id(&self) -> LanguageModelProviderId {
165 LanguageModelProviderId(PROVIDER_ID.into())
166 }
167
168 fn provider_name(&self) -> LanguageModelProviderName {
169 LanguageModelProviderName(PROVIDER_NAME.into())
170 }
171
172 fn telemetry_id(&self) -> String {
173 format!("zed.dev/{}", self.model.id())
174 }
175
176 fn max_token_count(&self) -> usize {
177 self.model.max_token_count()
178 }
179
180 fn count_tokens(
181 &self,
182 request: LanguageModelRequest,
183 cx: &AppContext,
184 ) -> BoxFuture<'static, Result<usize>> {
185 match &self.model {
186 CloudModel::Gpt3Point5Turbo => {
187 count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
188 }
189 CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
190 CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
191 CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
192 CloudModel::Gpt4OmniMini => {
193 count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
194 }
195 CloudModel::Claude3_5Sonnet
196 | CloudModel::Claude3Opus
197 | CloudModel::Claude3Sonnet
198 | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
199 CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
200 count_anthropic_tokens(request, cx)
201 }
202 _ => {
203 let request = self.client.request(proto::CountTokensWithLanguageModel {
204 model: self.model.id().to_string(),
205 messages: request
206 .messages
207 .iter()
208 .map(|message| message.to_proto())
209 .collect(),
210 });
211 async move {
212 let response = request.await?;
213 Ok(response.token_count as usize)
214 }
215 .boxed()
216 }
217 }
218 }
219
220 fn stream_completion(
221 &self,
222 mut request: LanguageModelRequest,
223 _: &AsyncAppContext,
224 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
225 match &self.model {
226 CloudModel::Claude3Opus
227 | CloudModel::Claude3Sonnet
228 | CloudModel::Claude3Haiku
229 | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
230 CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
231 preprocess_anthropic_request(&mut request)
232 }
233 _ => {}
234 }
235
236 let request = proto::CompleteWithLanguageModel {
237 model: self.id.0.to_string(),
238 messages: request
239 .messages
240 .iter()
241 .map(|message| message.to_proto())
242 .collect(),
243 stop: request.stop,
244 temperature: request.temperature,
245 tools: Vec::new(),
246 tool_choice: None,
247 };
248
249 self.client
250 .request_stream(request)
251 .map_ok(|stream| {
252 stream
253 .filter_map(|response| async move {
254 match response {
255 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
256 Err(error) => Some(Err(error)),
257 }
258 })
259 .boxed()
260 })
261 .boxed()
262 }
263}
264
265struct AuthenticationPrompt {
266 state: gpui::Model<State>,
267}
268
269impl Render for AuthenticationPrompt {
270 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
271 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
272
273 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
274 v_flex()
275 .gap_2()
276 .child(
277 Button::new("sign_in", "Sign in")
278 .icon_color(Color::Muted)
279 .icon(IconName::Github)
280 .icon_position(IconPosition::Start)
281 .style(ButtonStyle::Filled)
282 .full_width()
283 .on_click(cx.listener(move |this, _, cx| {
284 this.state.update(cx, |provider, cx| {
285 provider.authenticate(cx).detach_and_log_err(cx);
286 cx.notify();
287 });
288 })),
289 )
290 .child(
291 div().flex().w_full().items_center().child(
292 Label::new("Sign in to enable collaboration.")
293 .color(Color::Muted)
294 .size(LabelSize::Small),
295 ),
296 ),
297 )
298 }
299}