fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  5    LanguageModelToolChoice,
  6};
  7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use std::sync::Arc;
 12
 13#[derive(Clone)]
 14pub struct FakeLanguageModelProvider {
 15    id: LanguageModelProviderId,
 16    name: LanguageModelProviderName,
 17}
 18
 19impl Default for FakeLanguageModelProvider {
 20    fn default() -> Self {
 21        Self {
 22            id: LanguageModelProviderId::from("fake".to_string()),
 23            name: LanguageModelProviderName::from("Fake".to_string()),
 24        }
 25    }
 26}
 27
 28impl LanguageModelProviderState for FakeLanguageModelProvider {
 29    type ObservableEntity = ();
 30
 31    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 32        None
 33    }
 34}
 35
 36impl LanguageModelProvider for FakeLanguageModelProvider {
 37    fn id(&self) -> LanguageModelProviderId {
 38        self.id.clone()
 39    }
 40
 41    fn name(&self) -> LanguageModelProviderName {
 42        self.name.clone()
 43    }
 44
 45    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 46        Some(Arc::new(FakeLanguageModel::default()))
 47    }
 48
 49    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 54        vec![Arc::new(FakeLanguageModel::default())]
 55    }
 56
 57    fn is_authenticated(&self, _: &App) -> bool {
 58        true
 59    }
 60
 61    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 62        Task::ready(Ok(()))
 63    }
 64
 65    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 66        unimplemented!()
 67    }
 68
 69    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 70        Task::ready(Ok(()))
 71    }
 72}
 73
 74impl FakeLanguageModelProvider {
 75    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 76        Self { id, name }
 77    }
 78
 79    pub fn test_model(&self) -> FakeLanguageModel {
 80        FakeLanguageModel::default()
 81    }
 82}
 83
 84#[derive(Debug, PartialEq)]
 85pub struct ToolUseRequest {
 86    pub request: LanguageModelRequest,
 87    pub name: String,
 88    pub description: String,
 89    pub schema: serde_json::Value,
 90}
 91
 92pub struct FakeLanguageModel {
 93    provider_id: LanguageModelProviderId,
 94    provider_name: LanguageModelProviderName,
 95    current_completion_txs: Mutex<
 96        Vec<(
 97            LanguageModelRequest,
 98            mpsc::UnboundedSender<LanguageModelCompletionEvent>,
 99        )>,
100    >,
101}
102
103impl Default for FakeLanguageModel {
104    fn default() -> Self {
105        dbg!("default......");
106        eprintln!("{}", std::backtrace::Backtrace::force_capture());
107        Self {
108            provider_id: LanguageModelProviderId::from("fake".to_string()),
109            provider_name: LanguageModelProviderName::from("Fake".to_string()),
110            current_completion_txs: Mutex::new(Vec::new()),
111        }
112    }
113}
114
115impl FakeLanguageModel {
116    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
117        self.current_completion_txs
118            .lock()
119            .iter()
120            .map(|(request, _)| request.clone())
121            .collect()
122    }
123
124    pub fn completion_count(&self) -> usize {
125        self.current_completion_txs.lock().len()
126    }
127
128    pub fn send_completion_stream_text_chunk(
129        &self,
130        request: &LanguageModelRequest,
131        chunk: impl Into<String>,
132    ) {
133        self.send_completion_stream_event(
134            request,
135            LanguageModelCompletionEvent::Text(chunk.into()),
136        );
137    }
138
139    pub fn send_completion_stream_event(
140        &self,
141        request: &LanguageModelRequest,
142        event: impl Into<LanguageModelCompletionEvent>,
143    ) {
144        let current_completion_txs = self.current_completion_txs.lock();
145        let tx = current_completion_txs
146            .iter()
147            .find(|(req, _)| req == request)
148            .map(|(_, tx)| tx)
149            .unwrap();
150        tx.unbounded_send(event.into()).unwrap();
151    }
152
153    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
154        dbg!("remove...");
155        self.current_completion_txs
156            .lock()
157            .retain(|(req, _)| req != request);
158    }
159
160    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
161        dbg!("read...");
162        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
163    }
164
165    pub fn send_last_completion_stream_event(
166        &self,
167        event: impl Into<LanguageModelCompletionEvent>,
168    ) {
169        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
170    }
171
172    pub fn end_last_completion_stream(&self) {
173        self.end_completion_stream(self.pending_completions().last().unwrap());
174    }
175}
176
177impl LanguageModel for FakeLanguageModel {
178    fn id(&self) -> LanguageModelId {
179        LanguageModelId::from("fake".to_string())
180    }
181
182    fn name(&self) -> LanguageModelName {
183        LanguageModelName::from("Fake".to_string())
184    }
185
186    fn provider_id(&self) -> LanguageModelProviderId {
187        self.provider_id.clone()
188    }
189
190    fn provider_name(&self) -> LanguageModelProviderName {
191        self.provider_name.clone()
192    }
193
194    fn supports_tools(&self) -> bool {
195        false
196    }
197
198    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
199        false
200    }
201
202    fn supports_images(&self) -> bool {
203        false
204    }
205
206    fn telemetry_id(&self) -> String {
207        "fake".to_string()
208    }
209
210    fn max_token_count(&self) -> u64 {
211        1000000
212    }
213
214    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
215        futures::future::ready(Ok(0)).boxed()
216    }
217
218    fn stream_completion(
219        &self,
220        request: LanguageModelRequest,
221        _: &AsyncApp,
222    ) -> BoxFuture<
223        'static,
224        Result<
225            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
226            LanguageModelCompletionError,
227        >,
228    > {
229        let (tx, rx) = mpsc::unbounded();
230        dbg!("insert...");
231        self.current_completion_txs.lock().push((request, tx));
232        async move { Ok(rx.map(Ok).boxed()) }.boxed()
233    }
234
235    fn as_fake(&self) -> &Self {
236        self
237    }
238}