fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  5    LanguageModelToolChoice,
  6};
  7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use std::sync::Arc;
 12
 13pub fn language_model_id() -> LanguageModelId {
 14    LanguageModelId::from("fake".to_string())
 15}
 16
 17pub fn language_model_name() -> LanguageModelName {
 18    LanguageModelName::from("Fake".to_string())
 19}
 20
 21pub fn provider_id() -> LanguageModelProviderId {
 22    LanguageModelProviderId::from("fake".to_string())
 23}
 24
 25pub fn provider_name() -> LanguageModelProviderName {
 26    LanguageModelProviderName::from("Fake".to_string())
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct FakeLanguageModelProvider;
 31
 32impl LanguageModelProviderState for FakeLanguageModelProvider {
 33    type ObservableEntity = ();
 34
 35    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 36        None
 37    }
 38}
 39
 40impl LanguageModelProvider for FakeLanguageModelProvider {
 41    fn id(&self) -> LanguageModelProviderId {
 42        provider_id()
 43    }
 44
 45    fn name(&self) -> LanguageModelProviderName {
 46        provider_name()
 47    }
 48
 49    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 54        Some(Arc::new(FakeLanguageModel::default()))
 55    }
 56
 57    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 58        vec![Arc::new(FakeLanguageModel::default())]
 59    }
 60
 61    fn is_authenticated(&self, _: &App) -> bool {
 62        true
 63    }
 64
 65    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 66        Task::ready(Ok(()))
 67    }
 68
 69    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 70        unimplemented!()
 71    }
 72
 73    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 74        Task::ready(Ok(()))
 75    }
 76}
 77
 78impl FakeLanguageModelProvider {
 79    pub fn test_model(&self) -> FakeLanguageModel {
 80        FakeLanguageModel::default()
 81    }
 82}
 83
 84#[derive(Debug, PartialEq)]
 85pub struct ToolUseRequest {
 86    pub request: LanguageModelRequest,
 87    pub name: String,
 88    pub description: String,
 89    pub schema: serde_json::Value,
 90}
 91
 92#[derive(Default)]
 93pub struct FakeLanguageModel {
 94    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 95}
 96
 97impl FakeLanguageModel {
 98    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 99        self.current_completion_txs
100            .lock()
101            .iter()
102            .map(|(request, _)| request.clone())
103            .collect()
104    }
105
106    pub fn completion_count(&self) -> usize {
107        self.current_completion_txs.lock().len()
108    }
109
110    pub fn stream_completion_response(
111        &self,
112        request: &LanguageModelRequest,
113        chunk: impl Into<String>,
114    ) {
115        let current_completion_txs = self.current_completion_txs.lock();
116        let tx = current_completion_txs
117            .iter()
118            .find(|(req, _)| req == request)
119            .map(|(_, tx)| tx)
120            .unwrap();
121        tx.unbounded_send(chunk.into()).unwrap();
122    }
123
124    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
125        self.current_completion_txs
126            .lock()
127            .retain(|(req, _)| req != request);
128    }
129
130    pub fn stream_last_completion_response(&self, chunk: impl Into<String>) {
131        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
132    }
133
134    pub fn end_last_completion_stream(&self) {
135        self.end_completion_stream(self.pending_completions().last().unwrap());
136    }
137}
138
139impl LanguageModel for FakeLanguageModel {
140    fn id(&self) -> LanguageModelId {
141        language_model_id()
142    }
143
144    fn name(&self) -> LanguageModelName {
145        language_model_name()
146    }
147
148    fn provider_id(&self) -> LanguageModelProviderId {
149        provider_id()
150    }
151
152    fn provider_name(&self) -> LanguageModelProviderName {
153        provider_name()
154    }
155
156    fn supports_tools(&self) -> bool {
157        false
158    }
159
160    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
161        false
162    }
163
164    fn supports_images(&self) -> bool {
165        false
166    }
167
168    fn telemetry_id(&self) -> String {
169        "fake".to_string()
170    }
171
172    fn max_token_count(&self) -> usize {
173        1000000
174    }
175
176    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
177        futures::future::ready(Ok(0)).boxed()
178    }
179
180    fn stream_completion(
181        &self,
182        request: LanguageModelRequest,
183        _: &AsyncApp,
184    ) -> BoxFuture<
185        'static,
186        Result<
187            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
188        >,
189    > {
190        let (tx, rx) = mpsc::unbounded();
191        self.current_completion_txs.lock().push((request, tx));
192        async move {
193            Ok(rx
194                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
195                .boxed())
196        }
197        .boxed()
198    }
199
200    fn as_fake(&self) -> &Self {
201        self
202    }
203}