fake.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  4use gpui::{AnyView, AppContext, Task};
  5use std::sync::Arc;
  6use ui::WindowContext;
  7
  8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
  9
 10#[derive(Clone, Default)]
 11pub struct FakeCompletionProvider {
 12    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 13}
 14
 15impl FakeCompletionProvider {
 16    pub fn setup_test(cx: &mut AppContext) -> Self {
 17        use crate::CompletionProvider;
 18        use parking_lot::RwLock;
 19
 20        let this = Self::default();
 21        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
 22        cx.set_global(provider);
 23        this
 24    }
 25
 26    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
 27        self.current_completion_txs
 28            .lock()
 29            .keys()
 30            .map(|k| serde_json::from_str(k).unwrap())
 31            .collect()
 32    }
 33
 34    pub fn completion_count(&self) -> usize {
 35        self.current_completion_txs.lock().len()
 36    }
 37
 38    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
 39        let json = serde_json::to_string(request).unwrap();
 40        self.current_completion_txs
 41            .lock()
 42            .get(&json)
 43            .unwrap()
 44            .unbounded_send(chunk)
 45            .unwrap();
 46    }
 47
 48    pub fn finish_completion(&self, request: &LanguageModelRequest) {
 49        self.current_completion_txs
 50            .lock()
 51            .remove(&serde_json::to_string(request).unwrap());
 52    }
 53}
 54
 55impl LanguageModelCompletionProvider for FakeCompletionProvider {
 56    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
 57        vec![LanguageModel::default()]
 58    }
 59
 60    fn settings_version(&self) -> usize {
 61        0
 62    }
 63
 64    fn is_authenticated(&self) -> bool {
 65        true
 66    }
 67
 68    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
 69        Task::ready(Ok(()))
 70    }
 71
 72    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
 73        unimplemented!()
 74    }
 75
 76    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
 77        Task::ready(Ok(()))
 78    }
 79
 80    fn model(&self) -> LanguageModel {
 81        LanguageModel::default()
 82    }
 83
 84    fn count_tokens(
 85        &self,
 86        _request: LanguageModelRequest,
 87        _cx: &AppContext,
 88    ) -> BoxFuture<'static, Result<usize>> {
 89        futures::future::ready(Ok(0)).boxed()
 90    }
 91
 92    fn complete(
 93        &self,
 94        _request: LanguageModelRequest,
 95    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 96        let (tx, rx) = mpsc::unbounded();
 97        self.current_completion_txs
 98            .lock()
 99            .insert(serde_json::to_string(&_request).unwrap(), tx);
100        async move { Ok(rx.map(Ok).boxed()) }.boxed()
101    }
102
103    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
104        self
105    }
106}