// completion.rs — rate-limited facade over pluggable language-model completion backends.

  1mod anthropic;
  2mod cloud;
  3#[cfg(any(test, feature = "test-support"))]
  4mod fake;
  5mod ollama;
  6mod open_ai;
  7
  8pub use anthropic::*;
  9use anyhow::Result;
 10use client::Client;
 11pub use cloud::*;
 12#[cfg(any(test, feature = "test-support"))]
 13pub use fake::*;
 14use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
 15use gpui::{AnyView, AppContext, Task, WindowContext};
 16use language_model::{LanguageModel, LanguageModelRequest};
 17pub use ollama::*;
 18pub use open_ai::*;
 19use parking_lot::RwLock;
 20use smol::lock::{Semaphore, SemaphoreGuardArc};
 21use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
 22
/// A stream of completion text chunks paired with a rate-limiter permit.
///
/// Holding `_lock` keeps one slot of the global request limiter occupied for
/// as long as this response is alive; dropping the response releases the
/// permit so a queued completion request can proceed.
pub struct CompletionResponse {
    // Boxed stream of text chunks produced by the underlying provider.
    inner: BoxStream<'static, Result<String>>,
    // Semaphore permit, released on drop; enforces MAX_CONCURRENT_COMPLETION_REQUESTS.
    _lock: SemaphoreGuardArc,
}
 27
impl futures::Stream for CompletionResponse {
    type Item = Result<String>;

    /// Delegates polling to the inner boxed stream.
    ///
    /// `inner` is a `BoxStream`, which is `Unpin`, so re-pinning the mutable
    /// reference with `Pin::new` is sound.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}
 38
/// Object-safe interface implemented by each concrete completion backend
/// (anthropic, cloud, ollama, open_ai, and the test-only fake).
pub trait LanguageModelCompletionProvider: Send + Sync {
    /// Models this backend can currently serve.
    fn available_models(&self) -> Vec<LanguageModel>;
    /// Version counter of the settings this provider was built from
    /// (presumably bumped when settings change — confirm in implementations).
    fn settings_version(&self) -> usize;
    /// Whether the backend currently holds valid credentials.
    fn is_authenticated(&self) -> bool;
    /// Starts (re-)authentication; the task resolves when it completes.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    /// View shown when the user still needs to authenticate.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    /// Clears any stored credentials.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    /// The currently selected model.
    fn model(&self) -> LanguageModel;
    /// Estimates the number of tokens `request` would consume.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    /// Starts a completion, resolving to a stream of text chunks.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Downcast hook used by `CompletionProvider::update_current_as`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
 59
/// Upper bound on simultaneously in-flight completion requests, enforced by
/// the semaphore in `CompletionProvider::request_limiter`.
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

/// Rate-limited facade over the active [`LanguageModelCompletionProvider`].
///
/// Stored as a gpui global; completion traffic funnels through here so the
/// semaphore can cap concurrent requests.
pub struct CompletionProvider {
    // Active backend; swappable at runtime via `update_provider`.
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    // Needed to rebuild the provider in `update_provider`; may be absent (e.g. in tests).
    client: Option<Arc<Client>>,
    // Grants at most MAX_CONCURRENT_COMPLETION_REQUESTS permits.
    request_limiter: Arc<Semaphore>,
}
 67
 68impl CompletionProvider {
 69    pub fn new(
 70        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
 71        client: Option<Arc<Client>>,
 72    ) -> Self {
 73        Self {
 74            provider,
 75            client,
 76            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
 77        }
 78    }
 79
 80    pub fn available_models(&self) -> Vec<LanguageModel> {
 81        self.provider.read().available_models()
 82    }
 83
 84    pub fn settings_version(&self) -> usize {
 85        self.provider.read().settings_version()
 86    }
 87
 88    pub fn is_authenticated(&self) -> bool {
 89        self.provider.read().is_authenticated()
 90    }
 91
 92    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 93        self.provider.read().authenticate(cx)
 94    }
 95
 96    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 97        self.provider.read().authentication_prompt(cx)
 98    }
 99
100    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
101        self.provider.read().reset_credentials(cx)
102    }
103
104    pub fn model(&self) -> LanguageModel {
105        self.provider.read().model()
106    }
107
108    pub fn count_tokens(
109        &self,
110        request: LanguageModelRequest,
111        cx: &AppContext,
112    ) -> BoxFuture<'static, Result<usize>> {
113        self.provider.read().count_tokens(request, cx)
114    }
115
116    pub fn stream_completion(
117        &self,
118        request: LanguageModelRequest,
119        cx: &AppContext,
120    ) -> Task<Result<CompletionResponse>> {
121        let rate_limiter = self.request_limiter.clone();
122        let provider = self.provider.clone();
123        cx.foreground_executor().spawn(async move {
124            let lock = rate_limiter.acquire_arc().await;
125            let response = provider.read().stream_completion(request);
126            let response = response.await?;
127            Ok(CompletionResponse {
128                inner: response,
129                _lock: lock,
130            })
131        })
132    }
133
134    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
135        let response = self.stream_completion(request, cx);
136        cx.foreground_executor().spawn(async move {
137            let mut chunks = response.await?;
138            let mut completion = String::new();
139            while let Some(chunk) = chunks.next().await {
140                let chunk = chunk?;
141                completion.push_str(&chunk);
142            }
143            Ok(completion)
144        })
145    }
146
147    pub fn update_provider(
148        &mut self,
149        get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
150    ) {
151        if let Some(client) = &self.client {
152            self.provider = get_provider(Arc::clone(client));
153        } else {
154            log::warn!("completion provider cannot be updated because its client was not set");
155        }
156    }
157}
158
// Marker impl: lets a single CompletionProvider be stored in and fetched
// from gpui's global registry.
impl gpui::Global for CompletionProvider {}
160
161impl CompletionProvider {
162    pub fn global(cx: &AppContext) -> &Self {
163        cx.global::<Self>()
164    }
165
166    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
167        &mut self,
168        update: impl FnOnce(&mut T) -> R,
169    ) -> Option<R> {
170        let mut provider = self.provider.write();
171        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
172            Some(update(provider))
173        } else {
174            None
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
        MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    /// Verifies the semaphore in `CompletionProvider::stream_completion`:
    /// never more than MAX_CONCURRENT_COMPLETION_REQUESTS completions in
    /// flight at once, and finishing one admits a queued request.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        // No client is needed: update_provider is not exercised here.
        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue some requests
        // Twice the cap, so half must queue on the semaphore. The varying
        // temperature just makes each request distinguishable.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            // Drain each stream in the background so permits are released
            // as soon as the fake provider finishes the completion.
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        // Only the first MAX_CONCURRENT_COMPLETION_REQUESTS acquired permits.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_provider
            .pending_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // MAX_CONCURRENT_COMPLETION_REQUESTS - 1 requests remained queued
        // (2 * cap total, cap + 1 already finished above).
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
}