use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
    LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use ui::prelude::*;

use crate::LanguageModelProvider;

use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};

pub const PROVIDER_NAME: &str = "zed.dev";

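/// User-configurable settings for the zed.dev provider; currently just the
/// list of additional models to expose alongside the built-in ones.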
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<CloudModel>,
}

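/// Language model provider backed by zed.dev: requests are issued through the
/// Zed `Client` RPC connection rather than against model APIs directly.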
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

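/// Shared provider state: the RPC client, its current connection status, and
/// the most recently loaded `ZedDotDevSettings`.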
struct State {
    client: Arc<Client>,
    status: client::Status,
    settings: ZedDotDevSettings,
    _subscription: Subscription,
}

impl State {
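    /// Signs in and connects the client; the provider counts as authenticated
    /// once the connection is established.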
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
    }
}

impl CloudLanguageModelProvider {
    pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            status,
            settings: ZedDotDevSettings::default(),
            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
                cx.notify();
            }),
        });

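        // Mirror the client's connection status into `State` so `is_authenticated`
        // stays current; the task exits once the state model has been dropped.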
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        this.status = status;
                        cx.notify();
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
        Some(cx.observe(&self.state, |_, _, cx| {
            cx.notify();
        }))
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = HashMap::default();

        // Add the built-in (non-custom) models from CloudModel::iter().
        for model in CloudModel::iter() {
            if !matches!(model, CloudModel::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with the available models from settings; entries with the
        // same ID replace the defaults.
        for model in &self.state.read(cx).settings.available_models {
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    client: self.client.clone(),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

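    // Authentication is equivalent to being connected to zed.dev; there are no
    // separate credentials to store, which is why `reset_credentials` is a no-op.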
    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).status.is_connected()
    }

    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.state.read(cx).authenticate(cx)
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| AuthenticationPrompt {
            state: self.state.clone(),
        })
        .into()
    }

    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

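/// A single zed.dev-hosted model: pairs a `CloudModel` variant with the client
/// used to issue token-count and completion requests.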
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

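    // Tokens are counted locally for the OpenAI and Anthropic model families;
    // any other model falls back to a `CountTokensWithLanguageModel` RPC.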
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match &self.model {
            CloudModel::Gpt3Point5Turbo => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
            CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
            CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
            CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
            CloudModel::Gpt4OmniMini => {
                count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
            }
            CloudModel::Claude3_5Sonnet
            | CloudModel::Claude3Opus
            | CloudModel::Claude3Sonnet
            | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
            _ => {
                let request = self.client.request(proto::CountTokensWithLanguageModel {
                    model: self.model.id().to_string(),
                    messages: request
                        .messages
                        .iter()
                        .map(|message| message.to_proto())
                        .collect(),
                });
                async move {
                    let response = request.await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

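    // Streams a completion through the zed.dev RPC. Anthropic models (including
    // custom models whose names start with "anthropic/") have their requests
    // preprocessed first, and each streamed chunk is reduced to its text delta.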
    fn stream_completion(
        &self,
        mut request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Claude3Opus
            | CloudModel::Claude3Sonnet
            | CloudModel::Claude3Haiku
            | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
                preprocess_anthropic_request(&mut request)
            }
            _ => {}
        }

        let request = proto::CompleteWithLanguageModel {
            model: self.id.0.to_string(),
            messages: request
                .messages
                .iter()
                .map(|message| message.to_proto())
                .collect(),
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        };

        self.client
            .request_stream(request)
            .map_ok(|stream| {
                stream
                    .filter_map(|response| async move {
                        match response {
                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
                            Err(error) => Some(Err(error)),
                        }
                    })
                    .boxed()
            })
            .boxed()
    }
}

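/// View returned by `authentication_prompt`: offers a sign-in button that
/// triggers `State::authenticate`.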
struct AuthenticationPrompt {
    state: gpui::Model<State>,
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const LABEL: &str = "Generate and analyze code with language models. You can chat with the assistant in this panel or transform code inline.";

        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
            v_flex()
                .gap_2()
                .child(
                    Button::new("sign_in", "Sign in")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .style(ButtonStyle::Filled)
                        .full_width()
                        .on_click(cx.listener(move |this, _, cx| {
                            this.state.update(cx, |provider, cx| {
                                provider.authenticate(cx).detach_and_log_err(cx);
                                cx.notify();
                            });
                        })),
                )
                .child(
                    div().flex().w_full().items_center().child(
                        Label::new("Sign in to enable collaboration.")
                            .color(Color::Muted)
                            .size(LabelSize::Small),
                    ),
                ),
        )
    }
}