fake.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  4use gpui::{AnyView, AppContext, Task};
  5use std::sync::Arc;
  6use ui::WindowContext;
  7
  8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
  9
 10#[derive(Clone, Default)]
 11pub struct FakeCompletionProvider {
 12    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 13}
 14
 15impl FakeCompletionProvider {
 16    pub fn setup_test(cx: &mut AppContext) -> Self {
 17        use crate::CompletionProvider;
 18        use parking_lot::RwLock;
 19
 20        let this = Self::default();
 21        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
 22        cx.set_global(provider);
 23        this
 24    }
 25
 26    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 27        self.current_completion_txs
 28            .lock()
 29            .keys()
 30            .map(|k| serde_json::from_str(k).unwrap())
 31            .collect()
 32    }
 33
 34    pub fn completion_count(&self) -> usize {
 35        self.current_completion_txs.lock().len()
 36    }
 37
 38    pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
 39        let json = serde_json::to_string(request).unwrap();
 40        self.current_completion_txs
 41            .lock()
 42            .get(&json)
 43            .unwrap()
 44            .unbounded_send(chunk)
 45            .unwrap();
 46    }
 47
 48    pub fn send_last_completion_chunk(&self, chunk: String) {
 49        self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
 50    }
 51
 52    pub fn finish_completion(&self, request: &LanguageModelRequest) {
 53        self.current_completion_txs
 54            .lock()
 55            .remove(&serde_json::to_string(request).unwrap())
 56            .unwrap();
 57    }
 58
 59    pub fn finish_last_completion(&self) {
 60        self.finish_completion(self.pending_completions().last().unwrap());
 61    }
 62}
 63
 64impl LanguageModelCompletionProvider for FakeCompletionProvider {
 65    fn available_models(&self) -> Vec<LanguageModel> {
 66        vec![LanguageModel::default()]
 67    }
 68
 69    fn settings_version(&self) -> usize {
 70        0
 71    }
 72
 73    fn is_authenticated(&self) -> bool {
 74        true
 75    }
 76
 77    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
 78        Task::ready(Ok(()))
 79    }
 80
 81    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
 82        unimplemented!()
 83    }
 84
 85    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
 86        Task::ready(Ok(()))
 87    }
 88
 89    fn model(&self) -> LanguageModel {
 90        LanguageModel::default()
 91    }
 92
 93    fn count_tokens(
 94        &self,
 95        _request: LanguageModelRequest,
 96        _cx: &AppContext,
 97    ) -> BoxFuture<'static, Result<usize>> {
 98        futures::future::ready(Ok(0)).boxed()
 99    }
100
101    fn stream_completion(
102        &self,
103        _request: LanguageModelRequest,
104    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
105        let (tx, rx) = mpsc::unbounded();
106        self.current_completion_txs
107            .lock()
108            .insert(serde_json::to_string(&_request).unwrap(), tx);
109        async move { Ok(rx.map(Ok).boxed()) }.boxed()
110    }
111
112    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
113        self
114    }
115}