completion_provider.rs

mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod ollama;
mod open_ai;

pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};

use crate::{
    assistant_settings::{AssistantProvider, AssistantSettings},
    LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::time::Duration;
use std::{any::Any, sync::Arc};

/// Chooses which model to use for the OpenAI provider.
/// If the requested model is not in `available_models`, falls back to the first available model, or keeps the requested model when the list is empty.
fn choose_openai_model(
    model: &::open_ai::Model,
    available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
    available_models
        .iter()
        .find(|&m| m == model)
        .or_else(|| available_models.first())
        .unwrap_or_else(|| model)
        .clone()
}

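/// Creates the global [`CompletionProvider`] from the current assistant settings
/// and keeps it in sync with subsequent settings changes.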
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    let provider = create_provider_from_settings(client.clone(), 0, cx);
    cx.set_global(CompletionProvider::new(provider, Some(client)));

    let mut settings_version = 0;
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            provider.update_settings(settings_version, cx);
        })
    })
    .detach();
}

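/// A streaming completion paired with the semaphore guard that reserves one of the
/// concurrent request slots; the slot is released when the response is dropped.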
pub struct CompletionResponse {
    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
    _lock: SemaphoreGuardArc,
}

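/// The interface each completion backend (Anthropic, cloud, Ollama, OpenAI, and the
/// test-only fake) implements so that [`CompletionProvider`] can drive it.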
pub trait LanguageModelCompletionProvider: Send + Sync {
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
    fn settings_version(&self) -> usize;
    fn is_authenticated(&self) -> bool;
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    fn model(&self) -> LanguageModel;
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    fn as_any_mut(&mut self) -> &mut dyn Any;
}

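/// The maximum number of completion requests allowed to run concurrently; additional
/// requests wait on the semaphore in [`CompletionProvider::complete`].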
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

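/// Global handle to the active completion backend. Holds the client used to rebuild
/// the provider when the configured backend changes, and the semaphore that limits
/// concurrent completion requests.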
pub struct CompletionProvider {
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    client: Option<Arc<Client>>,
    request_limiter: Arc<Semaphore>,
}

impl CompletionProvider {
    pub fn new(
        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
        client: Option<Arc<Client>>,
    ) -> Self {
        Self {
            provider,
            client,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        self.provider.read().available_models(cx)
    }

    pub fn settings_version(&self) -> usize {
        self.provider.read().settings_version()
    }

    pub fn is_authenticated(&self) -> bool {
        self.provider.read().is_authenticated()
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.provider.read().authenticate(cx)
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        self.provider.read().authentication_prompt(cx)
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.provider.read().reset_credentials(cx)
    }

    pub fn model(&self) -> LanguageModel {
        self.provider.read().model()
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        self.provider.read().count_tokens(request, cx)
    }

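    /// Starts a completion request on the background executor, waiting until one of
    /// the [`MAX_CONCURRENT_COMPLETION_REQUESTS`] semaphore slots is free before
    /// asking the underlying provider for the response stream.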
    pub fn complete(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<CompletionResponse> {
        let rate_limiter = self.request_limiter.clone();
        let provider = self.provider.clone();
        cx.background_executor().spawn(async move {
            let lock = rate_limiter.acquire_arc().await;
            let response = provider.read().complete(request);
            CompletionResponse {
                inner: response,
                _lock: lock,
            }
        })
    }
}

impl gpui::Global for CompletionProvider {}

impl CompletionProvider {
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

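    /// Runs `update` against the active provider if it is of concrete type `T`;
    /// returns `None` when a different provider type is currently active.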
    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
        &mut self,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        let mut provider = self.provider.write();
        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
            Some(update(provider))
        } else {
            None
        }
    }

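    /// Applies the latest assistant settings: updates the active provider in place
    /// when its type matches the configured one, otherwise rebuilds the provider
    /// from the settings.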
    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
        let updated = match &AssistantSettings::get_global(cx).provider {
            AssistantProvider::ZedDotDev { model } => self
                .update_current_as::<_, CloudCompletionProvider>(|provider| {
                    provider.update(model.clone(), version);
                }),
            AssistantProvider::OpenAi {
                model,
                api_url,
                low_speed_timeout_in_seconds,
                available_models,
            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                provider.update(
                    choose_openai_model(&model, &available_models),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Anthropic {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Ollama {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                    cx,
                );
            }),
        };

        // The configured provider type changed, so replace the current provider with a new one.
        if updated.is_none() {
            if let Some(client) = self.client.clone() {
                self.provider = create_provider_from_settings(client, version, cx);
            } else {
                log::warn!("completion provider cannot be created because client is not set");
            }
        }
    }
}

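/// Constructs the provider implementation selected by the assistant settings,
/// passing along the client or its HTTP client as each backend requires.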
fn create_provider_from_settings(
    client: Arc<Client>,
    settings_version: usize,
    cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
    match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        )),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
            choose_openai_model(&model, &available_models),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        ))),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        ))),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        ))),
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
        FakeCompletionProvider, LanguageModelRequest,
    };

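    // Verifies that `complete` never runs more than MAX_CONCURRENT_COMPLETION_REQUESTS
    // requests at once, and that finishing an in-flight request lets a queued one start.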
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue some requests
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.complete(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let response = response.await;
                    let mut stream = response.inner.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_provider
            .running_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_provider.running_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.running_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
}