1use super::open_ai::count_open_ai_tokens;
2use crate::{
3 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
4 LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
5 LanguageModelProviderState, LanguageModelRequest,
6};
7use anyhow::Result;
8use client::Client;
9use collections::BTreeMap;
10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
11use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
12use settings::{Settings, SettingsStore};
13use std::sync::Arc;
14use strum::IntoEnumIterator;
15use ui::prelude::*;
16
17use crate::LanguageModelProvider;
18
19use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
20
/// Identifier reported by the zed.dev provider and its models via `provider_id()` / `id()`.
pub const PROVIDER_ID: &str = "zed.dev";
/// Name reported by the zed.dev provider and its models via `provider_name()` / `name()`.
pub const PROVIDER_NAME: &str = "zed.dev";
23
24#[derive(Default, Clone, Debug, PartialEq)]
25pub struct ZedDotDevSettings {
26 pub available_models: Vec<CloudModel>,
27}
28
29pub struct CloudLanguageModelProvider {
30 client: Arc<Client>,
31 state: gpui::Model<State>,
32 _maintain_client_status: Task<()>,
33}
34
35struct State {
36 client: Arc<Client>,
37 status: client::Status,
38 _subscription: Subscription,
39}
40
41impl State {
42 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
43 let client = self.client.clone();
44 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
45 }
46}
47
48impl CloudLanguageModelProvider {
49 pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
50 let mut status_rx = client.status();
51 let status = *status_rx.borrow();
52
53 let state = cx.new_model(|cx| State {
54 client: client.clone(),
55 status,
56 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
57 cx.notify();
58 }),
59 });
60
61 let state_ref = state.downgrade();
62 let maintain_client_status = cx.spawn(|mut cx| async move {
63 while let Some(status) = status_rx.next().await {
64 if let Some(this) = state_ref.upgrade() {
65 _ = this.update(&mut cx, |this, cx| {
66 this.status = status;
67 cx.notify();
68 });
69 } else {
70 break;
71 }
72 }
73 });
74
75 Self {
76 client,
77 state,
78 _maintain_client_status: maintain_client_status,
79 }
80 }
81}
82
83impl LanguageModelProviderState for CloudLanguageModelProvider {
84 fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
85 Some(cx.observe(&self.state, |_, _, cx| {
86 cx.notify();
87 }))
88 }
89}
90
91impl LanguageModelProvider for CloudLanguageModelProvider {
92 fn id(&self) -> LanguageModelProviderId {
93 LanguageModelProviderId(PROVIDER_ID.into())
94 }
95
96 fn name(&self) -> LanguageModelProviderName {
97 LanguageModelProviderName(PROVIDER_NAME.into())
98 }
99
100 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
101 let mut models = BTreeMap::default();
102
103 // Add base models from CloudModel::iter()
104 for model in CloudModel::iter() {
105 if !matches!(model, CloudModel::Custom { .. }) {
106 models.insert(model.id().to_string(), model);
107 }
108 }
109
110 // Override with available models from settings
111 for model in &AllLanguageModelSettings::get_global(cx)
112 .zed_dot_dev
113 .available_models
114 {
115 models.insert(model.id().to_string(), model.clone());
116 }
117
118 models
119 .into_values()
120 .map(|model| {
121 Arc::new(CloudLanguageModel {
122 id: LanguageModelId::from(model.id().to_string()),
123 model,
124 client: self.client.clone(),
125 }) as Arc<dyn LanguageModel>
126 })
127 .collect()
128 }
129
130 fn is_authenticated(&self, cx: &AppContext) -> bool {
131 self.state.read(cx).status.is_connected()
132 }
133
134 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
135 self.state.read(cx).authenticate(cx)
136 }
137
138 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
139 cx.new_view(|_cx| AuthenticationPrompt {
140 state: self.state.clone(),
141 })
142 .into()
143 }
144
145 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
146 Task::ready(Ok(()))
147 }
148}
149
150pub struct CloudLanguageModel {
151 id: LanguageModelId,
152 model: CloudModel,
153 client: Arc<Client>,
154}
155
156impl LanguageModel for CloudLanguageModel {
157 fn id(&self) -> LanguageModelId {
158 self.id.clone()
159 }
160
161 fn name(&self) -> LanguageModelName {
162 LanguageModelName::from(self.model.display_name().to_string())
163 }
164
165 fn provider_id(&self) -> LanguageModelProviderId {
166 LanguageModelProviderId(PROVIDER_ID.into())
167 }
168
169 fn provider_name(&self) -> LanguageModelProviderName {
170 LanguageModelProviderName(PROVIDER_NAME.into())
171 }
172
173 fn telemetry_id(&self) -> String {
174 format!("zed.dev/{}", self.model.id())
175 }
176
177 fn max_token_count(&self) -> usize {
178 self.model.max_token_count()
179 }
180
181 fn count_tokens(
182 &self,
183 request: LanguageModelRequest,
184 cx: &AppContext,
185 ) -> BoxFuture<'static, Result<usize>> {
186 match &self.model {
187 CloudModel::Gpt3Point5Turbo => {
188 count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
189 }
190 CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
191 CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
192 CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
193 CloudModel::Gpt4OmniMini => {
194 count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
195 }
196 CloudModel::Claude3_5Sonnet
197 | CloudModel::Claude3Opus
198 | CloudModel::Claude3Sonnet
199 | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
200 CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
201 count_anthropic_tokens(request, cx)
202 }
203 _ => {
204 let request = self.client.request(proto::CountTokensWithLanguageModel {
205 model: self.model.id().to_string(),
206 messages: request
207 .messages
208 .iter()
209 .map(|message| message.to_proto())
210 .collect(),
211 });
212 async move {
213 let response = request.await?;
214 Ok(response.token_count as usize)
215 }
216 .boxed()
217 }
218 }
219 }
220
221 fn stream_completion(
222 &self,
223 mut request: LanguageModelRequest,
224 _: &AsyncAppContext,
225 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
226 match &self.model {
227 CloudModel::Claude3Opus
228 | CloudModel::Claude3Sonnet
229 | CloudModel::Claude3Haiku
230 | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
231 CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
232 preprocess_anthropic_request(&mut request)
233 }
234 _ => {}
235 }
236
237 let request = proto::CompleteWithLanguageModel {
238 model: self.id.0.to_string(),
239 messages: request
240 .messages
241 .iter()
242 .map(|message| message.to_proto())
243 .collect(),
244 stop: request.stop,
245 temperature: request.temperature,
246 tools: Vec::new(),
247 tool_choice: None,
248 };
249
250 self.client
251 .request_stream(request)
252 .map_ok(|stream| {
253 stream
254 .filter_map(|response| async move {
255 match response {
256 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
257 Err(error) => Some(Err(error)),
258 }
259 })
260 .boxed()
261 })
262 .boxed()
263 }
264}
265
266struct AuthenticationPrompt {
267 state: gpui::Model<State>,
268}
269
270impl Render for AuthenticationPrompt {
271 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
272 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
273
274 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
275 v_flex()
276 .gap_2()
277 .child(
278 Button::new("sign_in", "Sign in")
279 .icon_color(Color::Muted)
280 .icon(IconName::Github)
281 .icon_position(IconPosition::Start)
282 .style(ButtonStyle::Filled)
283 .full_width()
284 .on_click(cx.listener(move |this, _, cx| {
285 this.state.update(cx, |provider, cx| {
286 provider.authenticate(cx).detach_and_log_err(cx);
287 cx.notify();
288 });
289 })),
290 )
291 .child(
292 div().flex().w_full().items_center().child(
293 Label::new("Sign in to enable collaboration.")
294 .color(Color::Muted)
295 .size(LabelSize::Small),
296 ),
297 ),
298 )
299 }
300}