completion_provider.rs

mod anthropic;
mod cloud;
#[cfg(any(test, feature = "test-support"))]
mod fake;
mod ollama;
mod open_ai;

pub use anthropic::*;
pub use cloud::*;
#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};

use crate::{
    assistant_settings::{AssistantProvider, AssistantSettings},
    LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};

/// Chooses which model to use for the OpenAI provider.
/// If the requested model is unavailable, falls back to the first available
/// model, or to the requested model when no models are available at all.
fn choose_openai_model(
    model: &::open_ai::Model,
    available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
    available_models
        .iter()
        .find(|&m| m == model)
        .or_else(|| available_models.first())
        .unwrap_or_else(|| model)
        .clone()
}

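/// Registers the global [`CompletionProvider`] built from the current
/// assistant settings, and keeps it in sync as the settings change.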
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    let provider = create_provider_from_settings(client.clone(), 0, cx);
    cx.set_global(CompletionProvider::new(provider, Some(client)));

    let mut settings_version = 0;
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            provider.update_settings(settings_version, cx);
        })
    })
    .detach();
}

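/// A stream of completion chunks that also holds a permit from the request
/// limiter; dropping the response releases the permit, letting a queued
/// request start.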
pub struct CompletionResponse {
    inner: BoxStream<'static, Result<String>>,
    _lock: SemaphoreGuardArc,
}

impl futures::Stream for CompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}

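/// The interface implemented by each completion backend (Anthropic, cloud,
/// Ollama, OpenAI, and the fake provider used in tests).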
pub trait LanguageModelCompletionProvider: Send + Sync {
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
    fn settings_version(&self) -> usize;
    fn is_authenticated(&self) -> bool;
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    fn model(&self) -> LanguageModel;
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    fn as_any_mut(&mut self) -> &mut dyn Any;
}

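/// The maximum number of completion requests allowed in flight at once;
/// additional requests wait in `stream_completion` until a permit is free.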
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

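/// A global handle to the active [`LanguageModelCompletionProvider`] that
/// adds rate limiting and settings-driven provider swapping on top of it.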
pub struct CompletionProvider {
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    client: Option<Arc<Client>>,
    request_limiter: Arc<Semaphore>,
}

impl CompletionProvider {
    pub fn new(
        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
        client: Option<Arc<Client>>,
    ) -> Self {
        Self {
            provider,
            client,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        self.provider.read().available_models(cx)
    }

    pub fn settings_version(&self) -> usize {
        self.provider.read().settings_version()
    }

    pub fn is_authenticated(&self) -> bool {
        self.provider.read().is_authenticated()
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.provider.read().authenticate(cx)
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        self.provider.read().authentication_prompt(cx)
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.provider.read().reset_credentials(cx)
    }

    pub fn model(&self) -> LanguageModel {
        self.provider.read().model()
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        self.provider.read().count_tokens(request, cx)
    }

    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<CompletionResponse>> {
        let rate_limiter = self.request_limiter.clone();
        let provider = self.provider.clone();
        cx.foreground_executor().spawn(async move {
            // Wait for a permit so that at most
            // MAX_CONCURRENT_COMPLETION_REQUESTS completions run concurrently.
            let lock = rate_limiter.acquire_arc().await;
            let response = provider.read().stream_completion(request);
            let response = response.await?;
            // The permit travels with the response and is released on drop.
            Ok(CompletionResponse {
                inner: response,
                _lock: lock,
            })
        })
    }

    /// Requests a completion and collects the streamed chunks into a single string.
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                completion.push_str(&chunk);
            }
            Ok(completion)
        })
    }
}

impl gpui::Global for CompletionProvider {}

impl CompletionProvider {
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

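    /// Applies `update` to the current provider if it is of type `T`;
    /// returns `None` when a provider of a different type is active.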
    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
        &mut self,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        let mut provider = self.provider.write();
        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
            Some(update(provider))
        } else {
            None
        }
    }

    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
        let updated = match &AssistantSettings::get_global(cx).provider {
            AssistantProvider::ZedDotDev { model } => self
                .update_current_as::<_, CloudCompletionProvider>(|provider| {
                    provider.update(model.clone(), version);
                }),
            AssistantProvider::OpenAi {
                model,
                api_url,
                low_speed_timeout_in_seconds,
                available_models,
            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                provider.update(
                    choose_openai_model(&model, &available_models),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Anthropic {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                );
            }),
            AssistantProvider::Ollama {
                model,
                api_url,
                low_speed_timeout_in_seconds,
            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                provider.update(
                    model.clone(),
                    api_url.clone(),
                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                    version,
                    cx,
                );
            }),
        };

        // The configured provider changed to a different kind, so the current
        // one cannot be updated in place; rebuild it from the new settings.
        if updated.is_none() {
            if let Some(client) = self.client.clone() {
                self.provider = create_provider_from_settings(client, version, cx);
            } else {
                log::warn!("completion provider cannot be created because client is not set");
            }
        }
    }
}
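/// Constructs a fresh provider for whichever backend the assistant settings
/// currently select.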
fn create_provider_from_settings(
    client: Arc<Client>,
    settings_version: usize,
    cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
    match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        )),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
            choose_openai_model(&model, &available_models),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        ))),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        ))),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        ))),
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
        FakeCompletionProvider, LanguageModelRequest,
    };

    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue twice as many requests as the limiter allows.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Take the first in-flight completion request and mark it as completed.
        let completion = fake_provider
            .pending_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
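
    // A minimal sketch of `choose_openai_model`'s fallback order, assuming
    // `open_ai::Model` exposes `ThreePointFiveTurbo` and `Four` variants
    // (substitute whichever variants the crate actually defines).
    #[test]
    fn test_choose_openai_model_fallback() {
        use ::open_ai::Model;

        use super::choose_openai_model;

        // The requested model is kept when it is listed as available.
        let chosen = choose_openai_model(&Model::Four, &[Model::ThreePointFiveTurbo, Model::Four]);
        assert!(chosen == Model::Four);

        // An unavailable model falls back to the first available one.
        let chosen = choose_openai_model(&Model::Four, &[Model::ThreePointFiveTurbo]);
        assert!(chosen == Model::ThreePointFiveTurbo);

        // With no available models, the requested model is returned unchanged.
        let chosen = choose_openai_model(&Model::Four, &[]);
        assert!(chosen == Model::Four);
    }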
}