fake.rs

  1use std::sync::{Arc, Mutex};
  2
  3use collections::HashMap;
  4use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  5
  6use crate::{
  7    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  8    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  9    LanguageModelRequest,
 10};
 11use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 12use http_client::Result;
 13use ui::WindowContext;
 14
 15pub fn language_model_id() -> LanguageModelId {
 16    LanguageModelId::from("fake".to_string())
 17}
 18
 19pub fn language_model_name() -> LanguageModelName {
 20    LanguageModelName::from("Fake".to_string())
 21}
 22
 23pub fn provider_id() -> LanguageModelProviderId {
 24    LanguageModelProviderId::from("fake".to_string())
 25}
 26
 27pub fn provider_name() -> LanguageModelProviderName {
 28    LanguageModelProviderName::from("Fake".to_string())
 29}
 30
 31#[derive(Clone, Default)]
 32pub struct FakeLanguageModelProvider {
 33    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 34}
 35
 36impl LanguageModelProviderState for FakeLanguageModelProvider {
 37    fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 38        None
 39    }
 40}
 41
 42impl LanguageModelProvider for FakeLanguageModelProvider {
 43    fn id(&self) -> LanguageModelProviderId {
 44        provider_id()
 45    }
 46
 47    fn name(&self) -> LanguageModelProviderName {
 48        provider_name()
 49    }
 50
 51    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 52        vec![Arc::new(FakeLanguageModel {
 53            current_completion_txs: self.current_completion_txs.clone(),
 54        })]
 55    }
 56
 57    fn is_authenticated(&self, _: &AppContext) -> bool {
 58        true
 59    }
 60
 61    fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
 62        Task::ready(Ok(()))
 63    }
 64
 65    fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
 66        unimplemented!()
 67    }
 68
 69    fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
 70        Task::ready(Ok(()))
 71    }
 72}
 73
 74impl FakeLanguageModelProvider {
 75    pub fn test_model(&self) -> FakeLanguageModel {
 76        FakeLanguageModel {
 77            current_completion_txs: self.current_completion_txs.clone(),
 78        }
 79    }
 80}
 81
 82pub struct FakeLanguageModel {
 83    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 84}
 85
 86impl FakeLanguageModel {
 87    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 88        self.current_completion_txs
 89            .lock()
 90            .unwrap()
 91            .keys()
 92            .map(|k| serde_json::from_str(k).unwrap())
 93            .collect()
 94    }
 95
 96    pub fn completion_count(&self) -> usize {
 97        self.current_completion_txs.lock().unwrap().len()
 98    }
 99
100    pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
101        let json = serde_json::to_string(request).unwrap();
102        self.current_completion_txs
103            .lock()
104            .unwrap()
105            .get(&json)
106            .unwrap()
107            .unbounded_send(chunk)
108            .unwrap();
109    }
110
111    pub fn send_last_completion_chunk(&self, chunk: String) {
112        self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
113    }
114
115    pub fn finish_completion(&self, request: &LanguageModelRequest) {
116        self.current_completion_txs
117            .lock()
118            .unwrap()
119            .remove(&serde_json::to_string(request).unwrap())
120            .unwrap();
121    }
122
123    pub fn finish_last_completion(&self) {
124        self.finish_completion(self.pending_completions().last().unwrap());
125    }
126}
127
128impl LanguageModel for FakeLanguageModel {
129    fn id(&self) -> LanguageModelId {
130        language_model_id()
131    }
132
133    fn name(&self) -> LanguageModelName {
134        language_model_name()
135    }
136
137    fn provider_id(&self) -> LanguageModelProviderId {
138        provider_id()
139    }
140
141    fn provider_name(&self) -> LanguageModelProviderName {
142        provider_name()
143    }
144
145    fn telemetry_id(&self) -> String {
146        "fake".to_string()
147    }
148
149    fn max_token_count(&self) -> usize {
150        1000000
151    }
152
153    fn count_tokens(
154        &self,
155        _: LanguageModelRequest,
156        _: &AppContext,
157    ) -> BoxFuture<'static, Result<usize>> {
158        futures::future::ready(Ok(0)).boxed()
159    }
160
161    fn stream_completion(
162        &self,
163        request: LanguageModelRequest,
164        _: &AsyncAppContext,
165    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
166        let (tx, rx) = mpsc::unbounded();
167        self.current_completion_txs
168            .lock()
169            .unwrap()
170            .insert(serde_json::to_string(&request).unwrap(), tx);
171        async move { Ok(rx.map(Ok).boxed()) }.boxed()
172    }
173}