fake.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  3    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  4    LanguageModelRequest,
  5};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  9use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 10use http_client::Result;
 11use std::{
 12    future,
 13    sync::{Arc, Mutex},
 14};
 15use ui::WindowContext;
 16
 17pub fn language_model_id() -> LanguageModelId {
 18    LanguageModelId::from("fake".to_string())
 19}
 20
 21pub fn language_model_name() -> LanguageModelName {
 22    LanguageModelName::from("Fake".to_string())
 23}
 24
 25pub fn provider_id() -> LanguageModelProviderId {
 26    LanguageModelProviderId::from("fake".to_string())
 27}
 28
 29pub fn provider_name() -> LanguageModelProviderName {
 30    LanguageModelProviderName::from("Fake".to_string())
 31}
 32
 33#[derive(Clone, Default)]
 34pub struct FakeLanguageModelProvider {
 35    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 36}
 37
 38impl LanguageModelProviderState for FakeLanguageModelProvider {
 39    type ObservableEntity = ();
 40
 41    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 42        None
 43    }
 44}
 45
 46impl LanguageModelProvider for FakeLanguageModelProvider {
 47    fn id(&self) -> LanguageModelProviderId {
 48        provider_id()
 49    }
 50
 51    fn name(&self) -> LanguageModelProviderName {
 52        provider_name()
 53    }
 54
 55    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 56        vec![Arc::new(FakeLanguageModel {
 57            current_completion_txs: self.current_completion_txs.clone(),
 58        })]
 59    }
 60
 61    fn is_authenticated(&self, _: &AppContext) -> bool {
 62        true
 63    }
 64
 65    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
 66        Task::ready(Ok(()))
 67    }
 68
 69    fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
 70        unimplemented!()
 71    }
 72
 73    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
 74        Task::ready(Ok(()))
 75    }
 76}
 77
 78impl FakeLanguageModelProvider {
 79    pub fn test_model(&self) -> FakeLanguageModel {
 80        FakeLanguageModel {
 81            current_completion_txs: self.current_completion_txs.clone(),
 82        }
 83    }
 84}
 85
 86pub struct FakeLanguageModel {
 87    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 88}
 89
 90impl FakeLanguageModel {
 91    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 92        self.current_completion_txs
 93            .lock()
 94            .unwrap()
 95            .keys()
 96            .map(|k| serde_json::from_str(k).unwrap())
 97            .collect()
 98    }
 99
100    pub fn completion_count(&self) -> usize {
101        self.current_completion_txs.lock().unwrap().len()
102    }
103
104    pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
105        let json = serde_json::to_string(request).unwrap();
106        self.current_completion_txs
107            .lock()
108            .unwrap()
109            .get(&json)
110            .unwrap()
111            .unbounded_send(chunk)
112            .unwrap();
113    }
114
115    pub fn send_last_completion_chunk(&self, chunk: String) {
116        self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
117    }
118
119    pub fn finish_completion(&self, request: &LanguageModelRequest) {
120        self.current_completion_txs
121            .lock()
122            .unwrap()
123            .remove(&serde_json::to_string(request).unwrap())
124            .unwrap();
125    }
126
127    pub fn finish_last_completion(&self) {
128        self.finish_completion(self.pending_completions().last().unwrap());
129    }
130}
131
132impl LanguageModel for FakeLanguageModel {
133    fn id(&self) -> LanguageModelId {
134        language_model_id()
135    }
136
137    fn name(&self) -> LanguageModelName {
138        language_model_name()
139    }
140
141    fn provider_id(&self) -> LanguageModelProviderId {
142        provider_id()
143    }
144
145    fn provider_name(&self) -> LanguageModelProviderName {
146        provider_name()
147    }
148
149    fn telemetry_id(&self) -> String {
150        "fake".to_string()
151    }
152
153    fn max_token_count(&self) -> usize {
154        1000000
155    }
156
157    fn count_tokens(
158        &self,
159        _: LanguageModelRequest,
160        _: &AppContext,
161    ) -> BoxFuture<'static, Result<usize>> {
162        futures::future::ready(Ok(0)).boxed()
163    }
164
165    fn stream_completion(
166        &self,
167        request: LanguageModelRequest,
168        _: &AsyncAppContext,
169    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
170        let (tx, rx) = mpsc::unbounded();
171        self.current_completion_txs
172            .lock()
173            .unwrap()
174            .insert(serde_json::to_string(&request).unwrap(), tx);
175        async move { Ok(rx.map(Ok).boxed()) }.boxed()
176    }
177
178    fn use_any_tool(
179        &self,
180        _request: LanguageModelRequest,
181        _name: String,
182        _description: String,
183        _schema: serde_json::Value,
184        _cx: &AsyncAppContext,
185    ) -> BoxFuture<'static, Result<serde_json::Value>> {
186        future::ready(Err(anyhow!("not implemented"))).boxed()
187    }
188}