completion.rs
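
//! Owns the globally selected language model and provider, and funnels
//! completion requests through a semaphore so that at most a fixed number
//! run concurrently.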

use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
    LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use ui::Context;

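/// Registers the global [`LanguageModelCompletionProvider`]. Call once at
/// application startup, after the language model registry is initialized.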
pub fn init(cx: &mut AppContext) {
    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}

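/// Newtype wrapper that lets the provider model be stored as a GPUI global.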
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}

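/// Tracks the active [`LanguageModelProvider`] and [`LanguageModel`], and
/// caps the number of completion requests that may be in flight at once.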
pub struct LanguageModelCompletionProvider {
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    active_model: Option<Arc<dyn LanguageModel>>,
    /// Semaphore bounding the number of concurrent completion requests.
    request_limiter: Arc<Semaphore>,
}

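/// Maximum number of completion requests that may run concurrently.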
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

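/// A stream of completion chunks. It holds a semaphore permit for as long
/// as it is alive; dropping the response releases the permit and admits the
/// next queued request.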
pub struct LanguageModelCompletionResponse {
    inner: BoxStream<'static, Result<String>>,
    /// Permit released on drop, freeing a slot in the request limiter.
    _lock: SemaphoreGuardArc,
}

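// Polling simply delegates to the boxed inner stream; the semaphore permit
// rides along in `_lock` until the response itself is dropped.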
impl futures::Stream for LanguageModelCompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}

impl LanguageModelCompletionProvider {
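    /// Returns a handle to the global provider.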
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .clone()
    }

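    /// Returns a reference to the global provider's state.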
    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .read(cx)
    }

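    /// Installs a global provider preconfigured with the first model in the
    /// registry. Available in tests and with the `test-support` feature.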
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) {
        let provider = cx.new_model(|cx| {
            let mut this = Self::new(cx);
            let available_model = LanguageModelRegistry::read_global(cx)
                .available_models(cx)
                .first()
                .unwrap()
                .clone();
            this.set_active_model(available_model, cx);
            this
        });
        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
    }

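    /// Creates a provider with no active model. Observes the
    /// [`LanguageModelRegistry`] and re-notifies observers whenever it
    /// changes.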
    pub fn new(cx: &mut ModelContext<Self>) -> Self {
        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
            cx.notify();
        })
        .detach();

        Self {
            active_provider: None,
            active_model: None,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        self.active_provider.clone()
    }

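    /// Switches the active provider by id, clearing the active model.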
    pub fn set_active_provider(
        &mut self,
        provider_id: LanguageModelProviderId,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
        self.active_model = None;
        cx.notify();
    }

    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

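    /// Activates `model` and its provider, eagerly asking the provider to
    /// load the model. No-ops if the model is already active.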
    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
        // Nothing to do if this exact model is already active.
        if self.active_model.as_ref().map_or(false, |m| {
            m.id() == model.id() && m.provider_id() == model.provider_id()
        }) {
            return;
        }

        self.active_provider =
            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
        self.active_model = Some(model.clone());

        // Give the provider a chance to load the model ahead of first use.
        if let Some(provider) = self.active_provider.as_ref() {
            provider.load_model(model, cx);
        }

        cx.notify();
    }

    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.active_provider
            .as_ref()
            .map_or(false, |provider| provider.is_authenticated(cx))
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| {
                provider.reset_credentials(cx)
            })
    }

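    /// Returns a future that counts the tokens `request` would use, or
    /// `None` when no model is active.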
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Option<BoxFuture<'static, Result<usize>>> {
        self.active_model()
            .map(|model| model.count_tokens(request, cx))
    }

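    /// Streams completion chunks for `request` from the active model.
    ///
    /// At most `MAX_CONCURRENT_COMPLETION_REQUESTS` responses may be in
    /// flight at once; each response holds a semaphore permit that is
    /// released when the stream is dropped. Fails immediately if no model
    /// is active.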
    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<LanguageModelCompletionResponse>> {
        if let Some(language_model) = self.active_model() {
            let rate_limiter = self.request_limiter.clone();
            cx.spawn(|cx| async move {
                // Wait for a free slot; the permit lives as long as the
                // response stream does.
                let lock = rate_limiter.acquire_arc().await;
                let response = language_model.stream_completion(request, &cx).await?;
                Ok(LanguageModelCompletionResponse {
                    inner: response,
                    _lock: lock,
                })
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

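    /// Requests a completion and concatenates the streamed chunks into a
    /// single string.
    ///
    /// A minimal usage sketch (the default request below is a placeholder;
    /// real callers populate [`LanguageModelRequest`] with messages):
    ///
    /// ```ignore
    /// let task = LanguageModelCompletionProvider::read_global(cx)
    ///     .complete(LanguageModelRequest::default(), cx);
    /// cx.spawn(|_| async move {
    ///     let text = task.await?;
    ///     println!("{text}");
    ///     anyhow::Ok(())
    /// })
    /// .detach();
    /// ```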
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            // Drain the stream, concatenating chunks into one string.
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                completion.push_str(&chunk);
            }
            Ok(completion)
        })
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

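    // Spawns twice as many requests as the limiter allows against a fake
    // model and checks that at most MAX_CONCURRENT_COMPLETION_REQUESTS are
    // ever in flight, with finished requests admitting queued ones.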
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        // Install a registry backed by a fake provider and grab its model.
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

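        // Build a completion provider with the fake model activated.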
        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue twice as many requests as the limiter permits.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_model.pending_completions().into_iter().next().unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}