1mod anthropic;
2mod cloud;
3#[cfg(test)]
4mod fake;
5mod ollama;
6mod open_ai;
7
8pub use anthropic::*;
9pub use cloud::*;
10#[cfg(test)]
11pub use fake::*;
12pub use ollama::*;
13pub use open_ai::*;
14use parking_lot::RwLock;
15use smol::lock::{Semaphore, SemaphoreGuardArc};
16
17use crate::{
18 assistant_settings::{AssistantProvider, AssistantSettings},
19 LanguageModel, LanguageModelRequest,
20};
21use anyhow::Result;
22use client::Client;
23use futures::{future::BoxFuture, stream::BoxStream};
24use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
25use settings::{Settings, SettingsStore};
26use std::time::Duration;
27use std::{any::Any, sync::Arc};
28
29/// Choose which model to use for openai provider.
30/// If the model is not available, try to use the first available model, or fallback to the original model.
31fn choose_openai_model(
32 model: &::open_ai::Model,
33 available_models: &[::open_ai::Model],
34) -> ::open_ai::Model {
35 available_models
36 .iter()
37 .find(|&m| m == model)
38 .or_else(|| available_models.first())
39 .unwrap_or_else(|| model)
40 .clone()
41}
42
43pub fn init(client: Arc<Client>, cx: &mut AppContext) {
44 let provider = create_provider_from_settings(client.clone(), 0, cx);
45 cx.set_global(CompletionProvider::new(provider, Some(client)));
46
47 let mut settings_version = 0;
48 cx.observe_global::<SettingsStore>(move |cx| {
49 settings_version += 1;
50 cx.update_global::<CompletionProvider, _>(|provider, cx| {
51 provider.update_settings(settings_version, cx);
52 })
53 })
54 .detach();
55}
56
/// A pending streaming completion paired with the rate-limiter permit that
/// was acquired to start it. Dropping the response releases the permit,
/// letting the next queued completion request proceed.
pub struct CompletionResponse {
    /// Resolves to the stream of completion text chunks.
    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
    /// Held for the lifetime of the response so that at most
    /// `MAX_CONCURRENT_COMPLETION_REQUESTS` completions run at once.
    _lock: SemaphoreGuardArc,
}
61
/// Interface implemented by each completion backend (Anthropic, OpenAI,
/// Ollama, Zed cloud, and the test fake) behind `CompletionProvider`.
pub trait LanguageModelCompletionProvider: Send + Sync {
    /// Models currently selectable for this provider.
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
    /// Settings version this provider was last configured with.
    fn settings_version(&self) -> usize;
    fn is_authenticated(&self) -> bool;
    /// Start authentication (e.g. credential lookup or sign-in).
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    /// View to show the user when authentication is required.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    /// The currently selected model.
    fn model(&self) -> LanguageModel;
    /// Count the tokens the given request would consume.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    /// Start a completion, yielding a stream of response text chunks.
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Enables downcasting to the concrete provider type so settings updates
    /// can reach it (see `CompletionProvider::update_current_as`).
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
82
/// Maximum number of completion requests allowed in flight at once; enforced
/// by `CompletionProvider::request_limiter`.
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
84
/// Global facade over the active completion backend. Delegates calls to the
/// inner `LanguageModelCompletionProvider` and rate-limits `complete` calls.
pub struct CompletionProvider {
    /// The active backend; swapped out when settings select a different kind.
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    /// Kept so a new backend can be constructed on settings changes; `None`
    /// in tests, where provider recreation is skipped with a warning.
    client: Option<Arc<Client>>,
    /// Semaphore bounding concurrent completion requests.
    request_limiter: Arc<Semaphore>,
}
90
91impl CompletionProvider {
92 pub fn new(
93 provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
94 client: Option<Arc<Client>>,
95 ) -> Self {
96 Self {
97 provider,
98 client,
99 request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
100 }
101 }
102
103 pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
104 self.provider.read().available_models(cx)
105 }
106
107 pub fn settings_version(&self) -> usize {
108 self.provider.read().settings_version()
109 }
110
111 pub fn is_authenticated(&self) -> bool {
112 self.provider.read().is_authenticated()
113 }
114
115 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
116 self.provider.read().authenticate(cx)
117 }
118
119 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
120 self.provider.read().authentication_prompt(cx)
121 }
122
123 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
124 self.provider.read().reset_credentials(cx)
125 }
126
127 pub fn model(&self) -> LanguageModel {
128 self.provider.read().model()
129 }
130
131 pub fn count_tokens(
132 &self,
133 request: LanguageModelRequest,
134 cx: &AppContext,
135 ) -> BoxFuture<'static, Result<usize>> {
136 self.provider.read().count_tokens(request, cx)
137 }
138
139 pub fn complete(
140 &self,
141 request: LanguageModelRequest,
142 cx: &AppContext,
143 ) -> Task<CompletionResponse> {
144 let rate_limiter = self.request_limiter.clone();
145 let provider = self.provider.clone();
146 cx.background_executor().spawn(async move {
147 let lock = rate_limiter.acquire_arc().await;
148 let response = provider.read().complete(request);
149 CompletionResponse {
150 inner: response,
151 _lock: lock,
152 }
153 })
154 }
155}
156
/// Registers `CompletionProvider` as a gpui global so it can be reached via
/// `cx.global()` / `cx.update_global()` from anywhere in the app.
impl gpui::Global for CompletionProvider {}
158
impl CompletionProvider {
    /// Returns the app-wide `CompletionProvider` instance.
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

    /// Runs `update` against the inner backend if it is currently of concrete
    /// type `T`; returns `None` (without calling `update`) when the active
    /// backend is some other type.
    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
        &mut self,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        let mut provider = self.provider.write();
        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
            Some(update(provider))
        } else {
            None
        }
    }

    /// Applies the latest `AssistantSettings` to the active backend. When the
    /// configured provider kind matches the active one it is updated in
    /// place; otherwise a brand-new backend is constructed from settings
    /// (which requires `self.client` to be set).
    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
        // Each arm is `Some(())` only if the active backend downcasts to the
        // configured kind; `None` signals a provider-kind switch.
        let updated = match &AssistantSettings::get_global(cx).provider {
            AssistantProvider::ZedDotDev { model } => self
                .update_current_as::<_, CloudCompletionProvider>(|provider| {
                    provider.update(model.clone(), version);
                }),
            AssistantProvider::OpenAi {
                model,
                api_url,
                low_speed_timeout_in_seconds,
                available_models,
            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                provider.update(
                    choose_openai_model(&model, &available_models),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Anthropic {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Ollama {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                    cx,
                );
            }),
        };

        // Previously configured provider was changed to another one
        if updated.is_none() {
            if let Some(client) = self.client.clone() {
                self.provider = create_provider_from_settings(client, version, cx);
            } else {
                log::warn!("completion provider cannot be created because client is not set");
            }
        }
    }
}
232
233fn create_provider_from_settings(
234 client: Arc<Client>,
235 settings_version: usize,
236 cx: &mut AppContext,
237) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
238 match &AssistantSettings::get_global(cx).provider {
239 AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
240 CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
241 )),
242 AssistantProvider::OpenAi {
243 model,
244 api_url,
245 low_speed_timeout_in_seconds,
246 available_models,
247 } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
248 choose_openai_model(&model, &available_models),
249 api_url.clone(),
250 client.http_client(),
251 low_speed_timeout_in_seconds.map(Duration::from_secs),
252 settings_version,
253 ))),
254 AssistantProvider::Anthropic {
255 model,
256 api_url,
257 low_speed_timeout_in_seconds,
258 } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
259 model.clone(),
260 api_url.clone(),
261 client.http_client(),
262 low_speed_timeout_in_seconds.map(Duration::from_secs),
263 settings_version,
264 ))),
265 AssistantProvider::Ollama {
266 model,
267 api_url,
268 low_speed_timeout_in_seconds,
269 } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
270 model.clone(),
271 api_url.clone(),
272 client.http_client(),
273 low_speed_timeout_in_seconds.map(Duration::from_secs),
274 settings_version,
275 cx,
276 ))),
277 }
278}
279
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
        FakeCompletionProvider, LanguageModelRequest,
    };

    /// Verifies that `CompletionProvider::complete` never lets more than
    /// `MAX_CONCURRENT_COMPLETION_REQUESTS` requests run at once, and that
    /// finishing a request frees its permit for a queued one.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue some requests
        // (twice the limit, so half of them must queue behind the semaphore;
        // the varying temperature just makes each request distinct)
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.complete(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let response = response.await;
                    let mut stream = response.inner.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        // Only the first batch should have started.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_provider
            .running_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_provider.running_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // The remaining queued requests (total minus the limit, minus the one
        // already admitted above) should now be in flight.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.running_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
}