//! fake_provider.rs — test-only fake implementations of the language-model provider traits.

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use anyhow::anyhow;
  8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 10use http_client::Result;
 11use parking_lot::Mutex;
 12use smol::stream::StreamExt;
 13use std::sync::{
 14    Arc,
 15    atomic::{AtomicBool, Ordering::SeqCst},
 16};
 17
/// A test-only [`LanguageModelProvider`] that serves in-memory
/// [`FakeLanguageModel`]s instead of talking to a real backend.
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    // Models returned by `provided_models`; the first entry doubles as both
    // the default and the default "fast" model.
    models: Vec<Arc<dyn LanguageModel>>,
}
 24
 25impl Default for FakeLanguageModelProvider {
 26    fn default() -> Self {
 27        Self {
 28            id: LanguageModelProviderId::from("fake".to_string()),
 29            name: LanguageModelProviderName::from("Fake".to_string()),
 30            models: vec![Arc::new(FakeLanguageModel::default())],
 31        }
 32    }
 33}
 34
impl LanguageModelProviderState for FakeLanguageModelProvider {
    // The fake provider has no observable internal state, so the observable
    // entity is the unit type and no entity is ever produced.
    type ObservableEntity = ();

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}
 42
 43impl LanguageModelProvider for FakeLanguageModelProvider {
 44    fn id(&self) -> LanguageModelProviderId {
 45        self.id.clone()
 46    }
 47
 48    fn name(&self) -> LanguageModelProviderName {
 49        self.name.clone()
 50    }
 51
 52    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 53        self.models.first().cloned()
 54    }
 55
 56    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 57        self.models.first().cloned()
 58    }
 59
 60    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 61        self.models.clone()
 62    }
 63
 64    fn is_authenticated(&self, _: &App) -> bool {
 65        true
 66    }
 67
 68    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 69        Task::ready(Ok(()))
 70    }
 71
 72    fn configuration_view(
 73        &self,
 74        _target_agent: ConfigurationViewTargetAgent,
 75        _window: &mut Window,
 76        _: &mut App,
 77    ) -> AnyView {
 78        unimplemented!()
 79    }
 80
 81    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 82        Task::ready(Ok(()))
 83    }
 84}
 85
 86impl FakeLanguageModelProvider {
 87    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 88        Self {
 89            id,
 90            name,
 91            models: vec![Arc::new(FakeLanguageModel::default())],
 92        }
 93    }
 94
 95    pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
 96        self.models = models;
 97        self
 98    }
 99
100    pub fn test_model(&self) -> FakeLanguageModel {
101        FakeLanguageModel::default()
102    }
103}
104
/// Captures a tool-use request made against a fake model so tests can assert
/// on the originating request plus the tool's name, description, and JSON
/// input schema.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    pub request: LanguageModelRequest,
    pub name: String,
    pub description: String,
    pub schema: serde_json::Value,
}
112
113pub struct FakeLanguageModel {
114    id: LanguageModelId,
115    name: LanguageModelName,
116    provider_id: LanguageModelProviderId,
117    provider_name: LanguageModelProviderName,
118    current_completion_txs: Mutex<
119        Vec<(
120            LanguageModelRequest,
121            mpsc::UnboundedSender<
122                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
123            >,
124        )>,
125    >,
126    forbid_requests: AtomicBool,
127    supports_thinking: AtomicBool,
128}
129
130impl Default for FakeLanguageModel {
131    fn default() -> Self {
132        Self {
133            id: LanguageModelId::from("fake".to_string()),
134            name: LanguageModelName::from("Fake".to_string()),
135            provider_id: LanguageModelProviderId::from("fake".to_string()),
136            provider_name: LanguageModelProviderName::from("Fake".to_string()),
137            current_completion_txs: Mutex::new(Vec::new()),
138            forbid_requests: AtomicBool::new(false),
139            supports_thinking: AtomicBool::new(false),
140        }
141    }
142}
143
144impl FakeLanguageModel {
145    pub fn with_id_and_thinking(
146        provider_id: &str,
147        id: &str,
148        name: &str,
149        supports_thinking: bool,
150    ) -> Self {
151        Self {
152            id: LanguageModelId::from(id.to_string()),
153            name: LanguageModelName::from(name.to_string()),
154            provider_id: LanguageModelProviderId::from(provider_id.to_string()),
155            supports_thinking: AtomicBool::new(supports_thinking),
156            ..Default::default()
157        }
158    }
159
160    pub fn allow_requests(&self) {
161        self.forbid_requests.store(false, SeqCst);
162    }
163
164    pub fn forbid_requests(&self) {
165        self.forbid_requests.store(true, SeqCst);
166    }
167
168    pub fn set_supports_thinking(&self, supports: bool) {
169        self.supports_thinking.store(supports, SeqCst);
170    }
171
172    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
173        self.current_completion_txs
174            .lock()
175            .iter()
176            .map(|(request, _)| request.clone())
177            .collect()
178    }
179
180    pub fn completion_count(&self) -> usize {
181        self.current_completion_txs.lock().len()
182    }
183
184    pub fn send_completion_stream_text_chunk(
185        &self,
186        request: &LanguageModelRequest,
187        chunk: impl Into<String>,
188    ) {
189        self.send_completion_stream_event(
190            request,
191            LanguageModelCompletionEvent::Text(chunk.into()),
192        );
193    }
194
195    pub fn send_completion_stream_event(
196        &self,
197        request: &LanguageModelRequest,
198        event: impl Into<LanguageModelCompletionEvent>,
199    ) {
200        let current_completion_txs = self.current_completion_txs.lock();
201        let tx = current_completion_txs
202            .iter()
203            .find(|(req, _)| req == request)
204            .map(|(_, tx)| tx)
205            .unwrap();
206        tx.unbounded_send(Ok(event.into())).unwrap();
207    }
208
209    pub fn send_completion_stream_error(
210        &self,
211        request: &LanguageModelRequest,
212        error: impl Into<LanguageModelCompletionError>,
213    ) {
214        let current_completion_txs = self.current_completion_txs.lock();
215        let tx = current_completion_txs
216            .iter()
217            .find(|(req, _)| req == request)
218            .map(|(_, tx)| tx)
219            .unwrap();
220        tx.unbounded_send(Err(error.into())).unwrap();
221    }
222
223    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
224        self.current_completion_txs
225            .lock()
226            .retain(|(req, _)| req != request);
227    }
228
229    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
230        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
231    }
232
233    pub fn send_last_completion_stream_event(
234        &self,
235        event: impl Into<LanguageModelCompletionEvent>,
236    ) {
237        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
238    }
239
240    pub fn send_last_completion_stream_error(
241        &self,
242        error: impl Into<LanguageModelCompletionError>,
243    ) {
244        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
245    }
246
247    pub fn end_last_completion_stream(&self) {
248        self.end_completion_stream(self.pending_completions().last().unwrap());
249    }
250}
251
252impl LanguageModel for FakeLanguageModel {
253    fn id(&self) -> LanguageModelId {
254        self.id.clone()
255    }
256
257    fn name(&self) -> LanguageModelName {
258        self.name.clone()
259    }
260
261    fn provider_id(&self) -> LanguageModelProviderId {
262        self.provider_id.clone()
263    }
264
265    fn provider_name(&self) -> LanguageModelProviderName {
266        self.provider_name.clone()
267    }
268
269    fn supports_tools(&self) -> bool {
270        false
271    }
272
273    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
274        false
275    }
276
277    fn supports_images(&self) -> bool {
278        false
279    }
280
281    fn supports_thinking(&self) -> bool {
282        self.supports_thinking.load(SeqCst)
283    }
284
285    fn telemetry_id(&self) -> String {
286        "fake".to_string()
287    }
288
289    fn max_token_count(&self) -> u64 {
290        1000000
291    }
292
293    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
294        futures::future::ready(Ok(0)).boxed()
295    }
296
297    fn stream_completion(
298        &self,
299        request: LanguageModelRequest,
300        _: &AsyncApp,
301    ) -> BoxFuture<
302        'static,
303        Result<
304            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
305            LanguageModelCompletionError,
306        >,
307    > {
308        if self.forbid_requests.load(SeqCst) {
309            async move {
310                Err(LanguageModelCompletionError::Other(anyhow!(
311                    "requests are forbidden"
312                )))
313            }
314            .boxed()
315        } else {
316            let (tx, rx) = mpsc::unbounded();
317            self.current_completion_txs.lock().push((request, tx));
318            async move { Ok(rx.boxed()) }.boxed()
319        }
320    }
321
322    fn as_fake(&self) -> &Self {
323        self
324    }
325}