1mod anthropic;
2mod cloud;
3#[cfg(any(test, feature = "test-support"))]
4mod fake;
5mod ollama;
6mod open_ai;
7
8pub use anthropic::*;
9pub use cloud::*;
10#[cfg(any(test, feature = "test-support"))]
11pub use fake::*;
12pub use ollama::*;
13pub use open_ai::*;
14use parking_lot::RwLock;
15use smol::lock::{Semaphore, SemaphoreGuardArc};
16
17use crate::{
18 assistant_settings::{AssistantProvider, AssistantSettings},
19 LanguageModel, LanguageModelRequest,
20};
21use anyhow::Result;
22use client::Client;
23use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
24use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
25use settings::{Settings, SettingsStore};
26use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
27
28/// Choose which model to use for openai provider.
29/// If the model is not available, try to use the first available model, or fallback to the original model.
30fn choose_openai_model(
31 model: &::open_ai::Model,
32 available_models: &[::open_ai::Model],
33) -> ::open_ai::Model {
34 available_models
35 .iter()
36 .find(|&m| m == model)
37 .or_else(|| available_models.first())
38 .unwrap_or_else(|| model)
39 .clone()
40}
41
42pub fn init(client: Arc<Client>, cx: &mut AppContext) {
43 let provider = create_provider_from_settings(client.clone(), 0, cx);
44 cx.set_global(CompletionProvider::new(provider, Some(client)));
45
46 let mut settings_version = 0;
47 cx.observe_global::<SettingsStore>(move |cx| {
48 settings_version += 1;
49 cx.update_global::<CompletionProvider, _>(|provider, cx| {
50 provider.update_settings(settings_version, cx);
51 })
52 })
53 .detach();
54}
55
/// A stream of completion text chunks that holds a slot in the global
/// request limiter for as long as it is alive.
pub struct CompletionResponse {
    /// The provider's underlying stream of completion chunks.
    inner: BoxStream<'static, Result<String>>,
    /// Semaphore guard; dropping it frees a concurrent-request slot.
    _lock: SemaphoreGuardArc,
}
60
61impl futures::Stream for CompletionResponse {
62 type Item = Result<String>;
63
64 fn poll_next(
65 mut self: Pin<&mut Self>,
66 cx: &mut std::task::Context<'_>,
67 ) -> Poll<Option<Self::Item>> {
68 Pin::new(&mut self.inner).poll_next(cx)
69 }
70}
71
/// Interface implemented by each concrete completion backend
/// (Anthropic, OpenAI, Ollama, Zed cloud, and the test fake).
pub trait LanguageModelCompletionProvider: Send + Sync {
    /// Models this provider currently offers.
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
    /// The settings version this provider was last configured with.
    fn settings_version(&self) -> usize;
    /// Whether the provider currently has valid credentials.
    fn is_authenticated(&self) -> bool;
    /// Attempt to authenticate with the backend.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    /// A view prompting the user to authenticate.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    /// Clear any stored credentials.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    /// The currently selected model.
    fn model(&self) -> LanguageModel;
    /// Count the tokens the given request would consume.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    /// Start a completion, yielding a stream of text chunks.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Enables downcasting to the concrete provider type
    /// (used by `CompletionProvider::update_current_as`).
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
92
93const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
94
/// Global wrapper around the active completion backend, adding a
/// concurrent-request limit and settings-driven provider swapping.
pub struct CompletionProvider {
    /// The active backend, behind a lock so it can be reconfigured or replaced.
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    /// Kept so the provider can be rebuilt when settings select a different
    /// backend; `None` in tests that construct the wrapper directly.
    client: Option<Arc<Client>>,
    /// Limits in-flight completions to `MAX_CONCURRENT_COMPLETION_REQUESTS`.
    request_limiter: Arc<Semaphore>,
}
100
101impl CompletionProvider {
102 pub fn new(
103 provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
104 client: Option<Arc<Client>>,
105 ) -> Self {
106 Self {
107 provider,
108 client,
109 request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
110 }
111 }
112
113 pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
114 self.provider.read().available_models(cx)
115 }
116
117 pub fn settings_version(&self) -> usize {
118 self.provider.read().settings_version()
119 }
120
121 pub fn is_authenticated(&self) -> bool {
122 self.provider.read().is_authenticated()
123 }
124
125 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
126 self.provider.read().authenticate(cx)
127 }
128
129 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
130 self.provider.read().authentication_prompt(cx)
131 }
132
133 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
134 self.provider.read().reset_credentials(cx)
135 }
136
137 pub fn model(&self) -> LanguageModel {
138 self.provider.read().model()
139 }
140
141 pub fn count_tokens(
142 &self,
143 request: LanguageModelRequest,
144 cx: &AppContext,
145 ) -> BoxFuture<'static, Result<usize>> {
146 self.provider.read().count_tokens(request, cx)
147 }
148
149 pub fn stream_completion(
150 &self,
151 request: LanguageModelRequest,
152 cx: &AppContext,
153 ) -> Task<Result<CompletionResponse>> {
154 let rate_limiter = self.request_limiter.clone();
155 let provider = self.provider.clone();
156 cx.foreground_executor().spawn(async move {
157 let lock = rate_limiter.acquire_arc().await;
158 let response = provider.read().stream_completion(request);
159 let response = response.await?;
160 Ok(CompletionResponse {
161 inner: response,
162 _lock: lock,
163 })
164 })
165 }
166
167 pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
168 let response = self.stream_completion(request, cx);
169 cx.foreground_executor().spawn(async move {
170 let mut chunks = response.await?;
171 let mut completion = String::new();
172 while let Some(chunk) = chunks.next().await {
173 let chunk = chunk?;
174 completion.push_str(&chunk);
175 }
176 Ok(completion)
177 })
178 }
179}
180
181impl gpui::Global for CompletionProvider {}
182
183impl CompletionProvider {
184 pub fn global(cx: &AppContext) -> &Self {
185 cx.global::<Self>()
186 }
187
188 pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
189 &mut self,
190 update: impl FnOnce(&mut T) -> R,
191 ) -> Option<R> {
192 let mut provider = self.provider.write();
193 if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
194 Some(update(provider))
195 } else {
196 None
197 }
198 }
199
200 pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
201 let updated = match &AssistantSettings::get_global(cx).provider {
202 AssistantProvider::ZedDotDev { model } => self
203 .update_current_as::<_, CloudCompletionProvider>(|provider| {
204 provider.update(model.clone(), version);
205 }),
206 AssistantProvider::OpenAi {
207 model,
208 api_url,
209 low_speed_timeout_in_seconds,
210 available_models,
211 } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
212 provider.update(
213 choose_openai_model(&model, &available_models),
214 api_url.clone(),
215 low_speed_timeout_in_seconds.map(Duration::from_secs),
216 version,
217 );
218 }),
219 AssistantProvider::Anthropic {
220 model,
221 api_url,
222 low_speed_timeout_in_seconds,
223 } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
224 provider.update(
225 model.clone(),
226 api_url.clone(),
227 low_speed_timeout_in_seconds.map(Duration::from_secs),
228 version,
229 );
230 }),
231 AssistantProvider::Ollama {
232 model,
233 api_url,
234 low_speed_timeout_in_seconds,
235 } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
236 provider.update(
237 model.clone(),
238 api_url.clone(),
239 low_speed_timeout_in_seconds.map(Duration::from_secs),
240 version,
241 cx,
242 );
243 }),
244 };
245
246 // Previously configured provider was changed to another one
247 if updated.is_none() {
248 if let Some(client) = self.client.clone() {
249 self.provider = create_provider_from_settings(client, version, cx);
250 } else {
251 log::warn!("completion provider cannot be created because client is not set");
252 }
253 }
254 }
255}
256
257fn create_provider_from_settings(
258 client: Arc<Client>,
259 settings_version: usize,
260 cx: &mut AppContext,
261) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
262 match &AssistantSettings::get_global(cx).provider {
263 AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
264 CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
265 )),
266 AssistantProvider::OpenAi {
267 model,
268 api_url,
269 low_speed_timeout_in_seconds,
270 available_models,
271 } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
272 choose_openai_model(&model, &available_models),
273 api_url.clone(),
274 client.http_client(),
275 low_speed_timeout_in_seconds.map(Duration::from_secs),
276 settings_version,
277 ))),
278 AssistantProvider::Anthropic {
279 model,
280 api_url,
281 low_speed_timeout_in_seconds,
282 } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
283 model.clone(),
284 api_url.clone(),
285 client.http_client(),
286 low_speed_timeout_in_seconds.map(Duration::from_secs),
287 settings_version,
288 ))),
289 AssistantProvider::Ollama {
290 model,
291 api_url,
292 low_speed_timeout_in_seconds,
293 } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
294 model.clone(),
295 api_url.clone(),
296 client.http_client(),
297 low_speed_timeout_in_seconds.map(Duration::from_secs),
298 settings_version,
299 cx,
300 ))),
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
        FakeCompletionProvider, LanguageModelRequest,
    };

    /// Verifies that `CompletionProvider` never lets more than
    /// `MAX_CONCURRENT_COMPLETION_REQUESTS` completions run at once, and that
    /// finishing a completion releases its slot to a queued request.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue some requests
        // (twice the limit, so half must wait on the semaphore).
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.stream_completion(
                LanguageModelRequest {
                    // Distinct temperature per request, to tell them apart.
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        // Only the first MAX_CONCURRENT_COMPLETION_REQUESTS may be in flight.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_provider
            .pending_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // The remaining queued requests (limit * 2 minus those already done)
        // now start.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
}