fake.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  3    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  4    LanguageModelRequest,
  5};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  9use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 10use http_client::Result;
 11use std::{
 12    future,
 13    sync::{Arc, Mutex},
 14};
 15use ui::WindowContext;
 16
 17pub fn language_model_id() -> LanguageModelId {
 18    LanguageModelId::from("fake".to_string())
 19}
 20
 21pub fn language_model_name() -> LanguageModelName {
 22    LanguageModelName::from("Fake".to_string())
 23}
 24
 25pub fn provider_id() -> LanguageModelProviderId {
 26    LanguageModelProviderId::from("fake".to_string())
 27}
 28
 29pub fn provider_name() -> LanguageModelProviderName {
 30    LanguageModelProviderName::from("Fake".to_string())
 31}
 32
 33#[derive(Clone, Default)]
 34pub struct FakeLanguageModelProvider {
 35    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 36}
 37
 38impl LanguageModelProviderState for FakeLanguageModelProvider {
 39    fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 40        None
 41    }
 42}
 43
 44impl LanguageModelProvider for FakeLanguageModelProvider {
 45    fn id(&self) -> LanguageModelProviderId {
 46        provider_id()
 47    }
 48
 49    fn name(&self) -> LanguageModelProviderName {
 50        provider_name()
 51    }
 52
 53    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 54        vec![Arc::new(FakeLanguageModel {
 55            current_completion_txs: self.current_completion_txs.clone(),
 56        })]
 57    }
 58
 59    fn is_authenticated(&self, _: &AppContext) -> bool {
 60        true
 61    }
 62
 63    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
 64        Task::ready(Ok(()))
 65    }
 66
 67    fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
 68        unimplemented!()
 69    }
 70
 71    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
 72        Task::ready(Ok(()))
 73    }
 74}
 75
 76impl FakeLanguageModelProvider {
 77    pub fn test_model(&self) -> FakeLanguageModel {
 78        FakeLanguageModel {
 79            current_completion_txs: self.current_completion_txs.clone(),
 80        }
 81    }
 82}
 83
 84pub struct FakeLanguageModel {
 85    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 86}
 87
 88impl FakeLanguageModel {
 89    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 90        self.current_completion_txs
 91            .lock()
 92            .unwrap()
 93            .keys()
 94            .map(|k| serde_json::from_str(k).unwrap())
 95            .collect()
 96    }
 97
 98    pub fn completion_count(&self) -> usize {
 99        self.current_completion_txs.lock().unwrap().len()
100    }
101
102    pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
103        let json = serde_json::to_string(request).unwrap();
104        self.current_completion_txs
105            .lock()
106            .unwrap()
107            .get(&json)
108            .unwrap()
109            .unbounded_send(chunk)
110            .unwrap();
111    }
112
113    pub fn send_last_completion_chunk(&self, chunk: String) {
114        self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
115    }
116
117    pub fn finish_completion(&self, request: &LanguageModelRequest) {
118        self.current_completion_txs
119            .lock()
120            .unwrap()
121            .remove(&serde_json::to_string(request).unwrap())
122            .unwrap();
123    }
124
125    pub fn finish_last_completion(&self) {
126        self.finish_completion(self.pending_completions().last().unwrap());
127    }
128}
129
130impl LanguageModel for FakeLanguageModel {
131    fn id(&self) -> LanguageModelId {
132        language_model_id()
133    }
134
135    fn name(&self) -> LanguageModelName {
136        language_model_name()
137    }
138
139    fn provider_id(&self) -> LanguageModelProviderId {
140        provider_id()
141    }
142
143    fn provider_name(&self) -> LanguageModelProviderName {
144        provider_name()
145    }
146
147    fn telemetry_id(&self) -> String {
148        "fake".to_string()
149    }
150
151    fn max_token_count(&self) -> usize {
152        1000000
153    }
154
155    fn count_tokens(
156        &self,
157        _: LanguageModelRequest,
158        _: &AppContext,
159    ) -> BoxFuture<'static, Result<usize>> {
160        futures::future::ready(Ok(0)).boxed()
161    }
162
163    fn stream_completion(
164        &self,
165        request: LanguageModelRequest,
166        _: &AsyncAppContext,
167    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
168        let (tx, rx) = mpsc::unbounded();
169        self.current_completion_txs
170            .lock()
171            .unwrap()
172            .insert(serde_json::to_string(&request).unwrap(), tx);
173        async move { Ok(rx.map(Ok).boxed()) }.boxed()
174    }
175
176    fn use_any_tool(
177        &self,
178        _request: LanguageModelRequest,
179        _name: String,
180        _description: String,
181        _schema: serde_json::Value,
182        _cx: &AsyncAppContext,
183    ) -> BoxFuture<'static, Result<serde_json::Value>> {
184        future::ready(Err(anyhow!("not implemented"))).boxed()
185    }
186}