fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  5    LanguageModelToolChoice,
  6};
  7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use std::sync::Arc;
 12
 13pub fn language_model_id() -> LanguageModelId {
 14    LanguageModelId::from("fake".to_string())
 15}
 16
 17pub fn language_model_name() -> LanguageModelName {
 18    LanguageModelName::from("Fake".to_string())
 19}
 20
 21pub fn provider_id() -> LanguageModelProviderId {
 22    LanguageModelProviderId::from("fake".to_string())
 23}
 24
 25pub fn provider_name() -> LanguageModelProviderName {
 26    LanguageModelProviderName::from("Fake".to_string())
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct FakeLanguageModelProvider;
 31
 32impl LanguageModelProviderState for FakeLanguageModelProvider {
 33    type ObservableEntity = ();
 34
 35    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 36        None
 37    }
 38}
 39
 40impl LanguageModelProvider for FakeLanguageModelProvider {
 41    fn id(&self) -> LanguageModelProviderId {
 42        provider_id()
 43    }
 44
 45    fn name(&self) -> LanguageModelProviderName {
 46        provider_name()
 47    }
 48
 49    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 54        Some(Arc::new(FakeLanguageModel::default()))
 55    }
 56
 57    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 58        vec![Arc::new(FakeLanguageModel::default())]
 59    }
 60
 61    fn is_authenticated(&self, _: &App) -> bool {
 62        true
 63    }
 64
 65    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 66        Task::ready(Ok(()))
 67    }
 68
 69    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 70        unimplemented!()
 71    }
 72
 73    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 74        Task::ready(Ok(()))
 75    }
 76}
 77
 78impl FakeLanguageModelProvider {
 79    pub fn test_model(&self) -> FakeLanguageModel {
 80        FakeLanguageModel::default()
 81    }
 82}
 83
 84#[derive(Debug, PartialEq)]
 85pub struct ToolUseRequest {
 86    pub request: LanguageModelRequest,
 87    pub name: String,
 88    pub description: String,
 89    pub schema: serde_json::Value,
 90}
 91
 92#[derive(Default)]
 93pub struct FakeLanguageModel {
 94    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 95}
 96
 97impl FakeLanguageModel {
 98    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 99        self.current_completion_txs
100            .lock()
101            .iter()
102            .map(|(request, _)| request.clone())
103            .collect()
104    }
105
106    pub fn completion_count(&self) -> usize {
107        self.current_completion_txs.lock().len()
108    }
109
110    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
111        let current_completion_txs = self.current_completion_txs.lock();
112        let tx = current_completion_txs
113            .iter()
114            .find(|(req, _)| req == request)
115            .map(|(_, tx)| tx)
116            .unwrap();
117        tx.unbounded_send(chunk).unwrap();
118    }
119
120    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
121        self.current_completion_txs
122            .lock()
123            .retain(|(req, _)| req != request);
124    }
125
126    pub fn stream_last_completion_response(&self, chunk: String) {
127        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
128    }
129
130    pub fn end_last_completion_stream(&self) {
131        self.end_completion_stream(self.pending_completions().last().unwrap());
132    }
133}
134
135impl LanguageModel for FakeLanguageModel {
136    fn id(&self) -> LanguageModelId {
137        language_model_id()
138    }
139
140    fn name(&self) -> LanguageModelName {
141        language_model_name()
142    }
143
144    fn provider_id(&self) -> LanguageModelProviderId {
145        provider_id()
146    }
147
148    fn provider_name(&self) -> LanguageModelProviderName {
149        provider_name()
150    }
151
152    fn supports_tools(&self) -> bool {
153        false
154    }
155
156    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
157        false
158    }
159
160    fn supports_images(&self) -> bool {
161        false
162    }
163
164    fn telemetry_id(&self) -> String {
165        "fake".to_string()
166    }
167
168    fn max_token_count(&self) -> usize {
169        1000000
170    }
171
172    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
173        futures::future::ready(Ok(0)).boxed()
174    }
175
176    fn stream_completion(
177        &self,
178        request: LanguageModelRequest,
179        _: &AsyncApp,
180    ) -> BoxFuture<
181        'static,
182        Result<
183            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
184        >,
185    > {
186        let (tx, rx) = mpsc::unbounded();
187        self.current_completion_txs.lock().push((request, tx));
188        async move {
189            Ok(rx
190                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
191                .boxed())
192        }
193        .boxed()
194    }
195
196    fn as_fake(&self) -> &Self {
197        self
198    }
199}