// completion.rs — global LanguageModelCompletionProvider with rate-limited completion requests.

  1use anyhow::{anyhow, Result};
  2use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
  3use gpui::{AppContext, Global, Model, ModelContext, Task};
  4use language_model::{
  5    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
  6    LanguageModelRequest, LanguageModelTool,
  7};
  8use smol::{
  9    future::FutureExt,
 10    lock::{Semaphore, SemaphoreGuardArc},
 11};
 12use std::{future, pin::Pin, sync::Arc, task::Poll};
 13use ui::Context;
 14
 15pub fn init(cx: &mut AppContext) {
 16    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
 17    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
 18}
 19
/// Newtype wrapper that registers the completion provider as a gpui global,
/// so it can be fetched anywhere via `cx.global::<...>()`.
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}
 23
/// Tracks the currently selected language-model provider/model pair and
/// throttles concurrent completion requests via a semaphore.
pub struct LanguageModelCompletionProvider {
    // Provider backing `active_model`; kept in sync by `set_active_model`/`set_active_provider`.
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    // The model completions are sent to; `None` until one is selected.
    active_model: Option<Arc<dyn LanguageModel>>,
    // Caps in-flight completion requests; guards are held by each response stream.
    request_limiter: Arc<Semaphore>,
}

/// Maximum number of completion requests allowed to be in flight at once.
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
 31
/// A streaming completion response. Holds a semaphore guard for the lifetime
/// of the stream, so the rate-limiter slot is released only when the response
/// is fully consumed or dropped.
pub struct LanguageModelCompletionResponse {
    // Underlying chunk stream produced by the active model.
    inner: BoxStream<'static, Result<String>>,
    // Released on drop, freeing one concurrent-request slot.
    _lock: SemaphoreGuardArc,
}
 36
 37impl futures::Stream for LanguageModelCompletionResponse {
 38    type Item = Result<String>;
 39
 40    fn poll_next(
 41        mut self: Pin<&mut Self>,
 42        cx: &mut std::task::Context<'_>,
 43    ) -> Poll<Option<Self::Item>> {
 44        Pin::new(&mut self.inner).poll_next(cx)
 45    }
 46}
 47
 48impl LanguageModelCompletionProvider {
 49    pub fn global(cx: &AppContext) -> Model<Self> {
 50        cx.global::<GlobalLanguageModelCompletionProvider>()
 51            .0
 52            .clone()
 53    }
 54
 55    pub fn read_global(cx: &AppContext) -> &Self {
 56        cx.global::<GlobalLanguageModelCompletionProvider>()
 57            .0
 58            .read(cx)
 59    }
 60
 61    #[cfg(any(test, feature = "test-support"))]
 62    pub fn test(cx: &mut AppContext) {
 63        let provider = cx.new_model(|cx| {
 64            let mut this = Self::new(cx);
 65            let available_model = LanguageModelRegistry::read_global(cx)
 66                .available_models(cx)
 67                .first()
 68                .unwrap()
 69                .clone();
 70            this.set_active_model(available_model, cx);
 71            this
 72        });
 73        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
 74    }
 75
 76    pub fn new(cx: &mut ModelContext<Self>) -> Self {
 77        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
 78            cx.notify();
 79        })
 80        .detach();
 81
 82        Self {
 83            active_provider: None,
 84            active_model: None,
 85            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
 86        }
 87    }
 88
 89    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
 90        self.active_provider.clone()
 91    }
 92
 93    pub fn set_active_provider(
 94        &mut self,
 95        provider_id: LanguageModelProviderId,
 96        cx: &mut ModelContext<Self>,
 97    ) {
 98        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
 99        self.active_model = None;
100        cx.notify();
101    }
102
103    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
104        self.active_model.clone()
105    }
106
107    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
108        if self.active_model.as_ref().map_or(false, |m| {
109            m.id() == model.id() && m.provider_id() == model.provider_id()
110        }) {
111            return;
112        }
113
114        self.active_provider =
115            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
116        self.active_model = Some(model.clone());
117
118        if let Some(provider) = self.active_provider.as_ref() {
119            provider.load_model(model, cx);
120        }
121
122        cx.notify();
123    }
124
125    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
126        self.active_provider
127            .as_ref()
128            .map_or(false, |provider| provider.is_authenticated(cx))
129    }
130
131    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
132        self.active_provider
133            .as_ref()
134            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
135    }
136
137    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
138        self.active_provider
139            .as_ref()
140            .map_or(Task::ready(Ok(())), |provider| {
141                provider.reset_credentials(cx)
142            })
143    }
144
145    pub fn count_tokens(
146        &self,
147        request: LanguageModelRequest,
148        cx: &AppContext,
149    ) -> BoxFuture<'static, Result<usize>> {
150        if let Some(model) = self.active_model() {
151            model.count_tokens(request, cx)
152        } else {
153            future::ready(Err(anyhow!("no active model"))).boxed()
154        }
155    }
156
157    pub fn stream_completion(
158        &self,
159        request: LanguageModelRequest,
160        cx: &AppContext,
161    ) -> Task<Result<LanguageModelCompletionResponse>> {
162        if let Some(language_model) = self.active_model() {
163            let rate_limiter = self.request_limiter.clone();
164            cx.spawn(|cx| async move {
165                let lock = rate_limiter.acquire_arc().await;
166                let response = language_model.stream_completion(request, &cx).await?;
167                Ok(LanguageModelCompletionResponse {
168                    inner: response,
169                    _lock: lock,
170                })
171            })
172        } else {
173            Task::ready(Err(anyhow!("No active model set")))
174        }
175    }
176
177    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
178        let response = self.stream_completion(request, cx);
179        cx.foreground_executor().spawn(async move {
180            let mut chunks = response.await?;
181            let mut completion = String::new();
182            while let Some(chunk) = chunks.next().await {
183                let chunk = chunk?;
184                completion.push_str(&chunk);
185            }
186            Ok(completion)
187        })
188    }
189
190    pub fn use_tool<T: LanguageModelTool>(
191        &self,
192        request: LanguageModelRequest,
193        cx: &AppContext,
194    ) -> Task<Result<T>> {
195        if let Some(language_model) = self.active_model() {
196            cx.spawn(|cx| async move {
197                let schema = schemars::schema_for!(T);
198                let schema_json = serde_json::to_value(&schema).unwrap();
199                let request =
200                    language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
201                let response = request.await?;
202                Ok(serde_json::from_value(response)?)
203            })
204        } else {
205            Task::ready(Err(anyhow!("No active model set")))
206        }
207    }
208
209    pub fn active_model_telemetry_id(&self) -> Option<String> {
210        self.active_model.as_ref().map(|m| m.telemetry_id())
211    }
212}
213
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

    // Verifies the semaphore-based rate limiting in `stream_completion`:
    // no more than MAX_CONCURRENT_COMPLETION_REQUESTS requests reach the model
    // at once, and finishing a request lets a queued one proceed.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue some requests
        // Issue twice the limit; distinct temperatures make each request unique.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            // Drain each response on the background executor so its semaphore
            // guard is dropped when the stream ends.
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();
        // Only the first `limit` requests should have reached the fake model.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_model.pending_completions().into_iter().next().unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // Only limit-1 remain queued at this point (2*limit issued, limit+1 finished).
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}