fake.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  4use gpui::{AnyView, AppContext, Task};
  5use std::sync::Arc;
  6use ui::WindowContext;
  7
  8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
  9
 10#[derive(Clone, Default)]
 11pub struct FakeCompletionProvider {
 12    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 13}
 14
 15impl FakeCompletionProvider {
 16    #[cfg(test)]
 17    pub fn setup_test(cx: &mut AppContext) -> Self {
 18        use crate::CompletionProvider;
 19        use parking_lot::RwLock;
 20
 21        let this = Self::default();
 22        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
 23        cx.set_global(provider);
 24        this
 25    }
 26
 27    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
 28        self.current_completion_txs
 29            .lock()
 30            .keys()
 31            .map(|k| serde_json::from_str(k).unwrap())
 32            .collect()
 33    }
 34
 35    pub fn completion_count(&self) -> usize {
 36        self.current_completion_txs.lock().len()
 37    }
 38
 39    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
 40        let json = serde_json::to_string(request).unwrap();
 41        self.current_completion_txs
 42            .lock()
 43            .get(&json)
 44            .unwrap()
 45            .unbounded_send(chunk)
 46            .unwrap();
 47    }
 48
 49    pub fn finish_completion(&self, request: &LanguageModelRequest) {
 50        self.current_completion_txs
 51            .lock()
 52            .remove(&serde_json::to_string(request).unwrap());
 53    }
 54}
 55
 56impl LanguageModelCompletionProvider for FakeCompletionProvider {
 57    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
 58        vec![LanguageModel::default()]
 59    }
 60
 61    fn settings_version(&self) -> usize {
 62        0
 63    }
 64
 65    fn is_authenticated(&self) -> bool {
 66        true
 67    }
 68
 69    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
 70        Task::ready(Ok(()))
 71    }
 72
 73    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
 74        unimplemented!()
 75    }
 76
 77    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
 78        Task::ready(Ok(()))
 79    }
 80
 81    fn model(&self) -> LanguageModel {
 82        LanguageModel::default()
 83    }
 84
 85    fn count_tokens(
 86        &self,
 87        _request: LanguageModelRequest,
 88        _cx: &AppContext,
 89    ) -> BoxFuture<'static, Result<usize>> {
 90        futures::future::ready(Ok(0)).boxed()
 91    }
 92
 93    fn complete(
 94        &self,
 95        _request: LanguageModelRequest,
 96    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 97        let (tx, rx) = mpsc::unbounded();
 98        self.current_completion_txs
 99            .lock()
100            .insert(serde_json::to_string(&_request).unwrap(), tx);
101        async move { Ok(rx.map(Ok).boxed()) }.boxed()
102    }
103
104    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
105        self
106    }
107}