fake.rs

  1use std::sync::{Arc, Mutex};
  2
  3use collections::HashMap;
  4use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  5
  6use crate::{
  7    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  8    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  9};
 10use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 11use http::Result;
 12use ui::WindowContext;
 13
 14pub fn language_model_id() -> LanguageModelId {
 15    LanguageModelId::from("fake".to_string())
 16}
 17
 18pub fn language_model_name() -> LanguageModelName {
 19    LanguageModelName::from("Fake".to_string())
 20}
 21
 22pub fn provider_name() -> LanguageModelProviderName {
 23    LanguageModelProviderName::from("fake".to_string())
 24}
 25
 26#[derive(Clone, Default)]
 27pub struct FakeLanguageModelProvider {
 28    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 29}
 30
 31impl LanguageModelProviderState for FakeLanguageModelProvider {
 32    fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 33        None
 34    }
 35}
 36
 37impl LanguageModelProvider for FakeLanguageModelProvider {
 38    fn name(&self) -> LanguageModelProviderName {
 39        provider_name()
 40    }
 41
 42    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 43        vec![Arc::new(FakeLanguageModel {
 44            current_completion_txs: self.current_completion_txs.clone(),
 45        })]
 46    }
 47
 48    fn is_authenticated(&self, _: &AppContext) -> bool {
 49        true
 50    }
 51
 52    fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
 53        Task::ready(Ok(()))
 54    }
 55
 56    fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
 57        unimplemented!()
 58    }
 59
 60    fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
 61        Task::ready(Ok(()))
 62    }
 63}
 64
 65impl FakeLanguageModelProvider {
 66    pub fn test_model(&self) -> FakeLanguageModel {
 67        FakeLanguageModel {
 68            current_completion_txs: self.current_completion_txs.clone(),
 69        }
 70    }
 71}
 72
 73pub struct FakeLanguageModel {
 74    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 75}
 76
 77impl FakeLanguageModel {
 78    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 79        self.current_completion_txs
 80            .lock()
 81            .unwrap()
 82            .keys()
 83            .map(|k| serde_json::from_str(k).unwrap())
 84            .collect()
 85    }
 86
 87    pub fn completion_count(&self) -> usize {
 88        self.current_completion_txs.lock().unwrap().len()
 89    }
 90
 91    pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
 92        let json = serde_json::to_string(request).unwrap();
 93        self.current_completion_txs
 94            .lock()
 95            .unwrap()
 96            .get(&json)
 97            .unwrap()
 98            .unbounded_send(chunk)
 99            .unwrap();
100    }
101
102    pub fn send_last_completion_chunk(&self, chunk: String) {
103        self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
104    }
105
106    pub fn finish_completion(&self, request: &LanguageModelRequest) {
107        self.current_completion_txs
108            .lock()
109            .unwrap()
110            .remove(&serde_json::to_string(request).unwrap())
111            .unwrap();
112    }
113
114    pub fn finish_last_completion(&self) {
115        self.finish_completion(self.pending_completions().last().unwrap());
116    }
117}
118
119impl LanguageModel for FakeLanguageModel {
120    fn id(&self) -> LanguageModelId {
121        language_model_id()
122    }
123
124    fn name(&self) -> LanguageModelName {
125        language_model_name()
126    }
127
128    fn provider_name(&self) -> LanguageModelProviderName {
129        provider_name()
130    }
131
132    fn telemetry_id(&self) -> String {
133        "fake".to_string()
134    }
135
136    fn max_token_count(&self) -> usize {
137        1000000
138    }
139
140    fn count_tokens(
141        &self,
142        _: LanguageModelRequest,
143        _: &AppContext,
144    ) -> BoxFuture<'static, Result<usize>> {
145        futures::future::ready(Ok(0)).boxed()
146    }
147
148    fn stream_completion(
149        &self,
150        request: LanguageModelRequest,
151        _: &AsyncAppContext,
152    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
153        let (tx, rx) = mpsc::unbounded();
154        self.current_completion_txs
155            .lock()
156            .unwrap()
157            .insert(serde_json::to_string(&request).unwrap(), tx);
158        async move { Ok(rx.map(Ok).boxed()) }.boxed()
159    }
160}