completion.rs

use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
    LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use ui::Context;

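/// Registers the global [`LanguageModelCompletionProvider`] so it can be
/// looked up from anywhere via [`LanguageModelCompletionProvider::global`].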
pub fn init(cx: &mut AppContext) {
    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}

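/// Newtype wrapper that lets the provider model be stored as a GPUI global.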
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}

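/// Tracks the active [`LanguageModelProvider`] and [`LanguageModel`], and
/// funnels all completion requests through a semaphore so that at most
/// `MAX_CONCURRENT_COMPLETION_REQUESTS` run concurrently.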
pub struct LanguageModelCompletionProvider {
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    active_model: Option<Arc<dyn LanguageModel>>,
    request_limiter: Arc<Semaphore>,
}

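/// The maximum number of completion requests allowed in flight at once.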
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

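/// A streaming completion response that owns one of the rate limiter's
/// permits; the permit is released when the response is dropped, letting the
/// next queued request proceed.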
pub struct LanguageModelCompletionResponse {
    pub inner: BoxStream<'static, Result<String>>,
    _lock: SemaphoreGuardArc,
}

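// Polling simply delegates to the wrapped stream; the semaphore guard stays
// alive for as long as the response itself does.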
impl futures::Stream for LanguageModelCompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}

impl LanguageModelCompletionProvider {
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .clone()
    }

    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .read(cx)
    }

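    /// Installs a provider preconfigured with the first model registered in
    /// the test [`LanguageModelRegistry`].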
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) {
        let provider = cx.new_model(|cx| {
            let mut this = Self::new(cx);
            let available_model = LanguageModelRegistry::read_global(cx)
                .available_models(cx)
                .first()
                .unwrap()
                .clone();
            this.set_active_model(available_model, cx);
            this
        });
        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
    }

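    /// Creates a provider with no active model, observing the global
    /// [`LanguageModelRegistry`] so observers are re-notified whenever the
    /// registry changes.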
    pub fn new(cx: &mut ModelContext<Self>) -> Self {
        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
            cx.notify();
        })
        .detach();

        Self {
            active_provider: None,
            active_model: None,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        self.active_provider.clone()
    }

    pub fn set_active_provider(
        &mut self,
        provider_name: LanguageModelProviderName,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
        self.active_model = None;
        cx.notify();
    }

    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

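    /// Sets the active model (and its provider), doing nothing if the same
    /// model is already active.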
    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
        if self.active_model.as_ref().map_or(false, |m| {
            m.id() == model.id() && m.provider_name() == model.provider_name()
        }) {
            return;
        }

        self.active_provider =
            LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
        self.active_model = Some(model);
        cx.notify();
    }

    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.active_provider
            .as_ref()
            .map_or(false, |provider| provider.is_authenticated(cx))
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| {
                provider.reset_credentials(cx)
            })
    }

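    /// Counts the tokens `request` would consume on the active model, failing
    /// if no model is set.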
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        if let Some(model) = self.active_model() {
            model.count_tokens(request, cx)
        } else {
            std::future::ready(Err(anyhow!("No active model set"))).boxed()
        }
    }

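    /// Streams a completion from the active model. The returned task first
    /// acquires a permit from the request limiter, so no more than
    /// `MAX_CONCURRENT_COMPLETION_REQUESTS` completions run at once; the
    /// permit is held until the response stream is dropped.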
    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<LanguageModelCompletionResponse>> {
        if let Some(language_model) = self.active_model() {
            let rate_limiter = self.request_limiter.clone();
            cx.spawn(|cx| async move {
                let lock = rate_limiter.acquire_arc().await;
                let response = language_model.stream_completion(request, &cx).await?;
                Ok(LanguageModelCompletionResponse {
                    inner: response,
                    _lock: lock,
                })
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

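    /// Convenience wrapper around [`Self::stream_completion`] that collects
    /// the streamed chunks into a single string.
    ///
    /// A hypothetical call site might look like this (the `request` is
    /// assumed to be built elsewhere):
    ///
    /// ```ignore
    /// let completion = LanguageModelCompletionProvider::read_global(cx)
    ///     .complete(request, cx);
    /// ```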
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                completion.push_str(&chunk);
            }
            Ok(completion)
        })
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue some requests
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_model.pending_completions().into_iter().next().unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}