fake_provider.rs

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use anyhow::anyhow;
  8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 10use http_client::Result;
 11use parking_lot::Mutex;
 12use smol::stream::StreamExt;
 13use std::sync::{
 14    Arc,
 15    atomic::{AtomicBool, Ordering::SeqCst},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModelProvider {
 20    id: LanguageModelProviderId,
 21    name: LanguageModelProviderName,
 22    models: Vec<Arc<dyn LanguageModel>>,
 23}
 24
 25impl Default for FakeLanguageModelProvider {
 26    fn default() -> Self {
 27        Self {
 28            id: LanguageModelProviderId::from("fake".to_string()),
 29            name: LanguageModelProviderName::from("Fake".to_string()),
 30            models: vec![Arc::new(FakeLanguageModel::default())],
 31        }
 32    }
 33}
 34
 35impl LanguageModelProviderState for FakeLanguageModelProvider {
 36    type ObservableEntity = ();
 37
 38    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 39        None
 40    }
 41}
 42
 43impl LanguageModelProvider for FakeLanguageModelProvider {
 44    fn id(&self) -> LanguageModelProviderId {
 45        self.id.clone()
 46    }
 47
 48    fn name(&self) -> LanguageModelProviderName {
 49        self.name.clone()
 50    }
 51
 52    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 53        self.models.first().cloned()
 54    }
 55
 56    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 57        self.models.first().cloned()
 58    }
 59
 60    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 61        self.models.clone()
 62    }
 63
 64    fn is_authenticated(&self, _: &App) -> bool {
 65        true
 66    }
 67
 68    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 69        Task::ready(Ok(()))
 70    }
 71
 72    fn configuration_view(
 73        &self,
 74        _target_agent: ConfigurationViewTargetAgent,
 75        _window: &mut Window,
 76        _: &mut App,
 77    ) -> AnyView {
 78        unimplemented!()
 79    }
 80
 81    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 82        Task::ready(Ok(()))
 83    }
 84}
 85
 86impl FakeLanguageModelProvider {
 87    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 88        Self {
 89            id,
 90            name,
 91            models: vec![Arc::new(FakeLanguageModel::default())],
 92        }
 93    }
 94
 95    pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
 96        self.models = models;
 97        self
 98    }
 99
100    pub fn test_model(&self) -> FakeLanguageModel {
101        FakeLanguageModel::default()
102    }
103}
104
105#[derive(Debug, PartialEq)]
106pub struct ToolUseRequest {
107    pub request: LanguageModelRequest,
108    pub name: String,
109    pub description: String,
110    pub schema: serde_json::Value,
111}
112
113pub struct FakeLanguageModel {
114    id: LanguageModelId,
115    name: LanguageModelName,
116    provider_id: LanguageModelProviderId,
117    provider_name: LanguageModelProviderName,
118    current_completion_txs: Mutex<
119        Vec<(
120            LanguageModelRequest,
121            mpsc::UnboundedSender<
122                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
123            >,
124        )>,
125    >,
126    forbid_requests: AtomicBool,
127    supports_thinking: AtomicBool,
128    supports_streaming_tools: AtomicBool,
129}
130
131impl Default for FakeLanguageModel {
132    fn default() -> Self {
133        Self {
134            id: LanguageModelId::from("fake".to_string()),
135            name: LanguageModelName::from("Fake".to_string()),
136            provider_id: LanguageModelProviderId::from("fake".to_string()),
137            provider_name: LanguageModelProviderName::from("Fake".to_string()),
138            current_completion_txs: Mutex::new(Vec::new()),
139            forbid_requests: AtomicBool::new(false),
140            supports_thinking: AtomicBool::new(false),
141            supports_streaming_tools: AtomicBool::new(false),
142        }
143    }
144}
145
146impl FakeLanguageModel {
147    pub fn with_id_and_thinking(
148        provider_id: &str,
149        id: &str,
150        name: &str,
151        supports_thinking: bool,
152    ) -> Self {
153        Self {
154            id: LanguageModelId::from(id.to_string()),
155            name: LanguageModelName::from(name.to_string()),
156            provider_id: LanguageModelProviderId::from(provider_id.to_string()),
157            supports_thinking: AtomicBool::new(supports_thinking),
158            ..Default::default()
159        }
160    }
161
162    pub fn allow_requests(&self) {
163        self.forbid_requests.store(false, SeqCst);
164    }
165
166    pub fn forbid_requests(&self) {
167        self.forbid_requests.store(true, SeqCst);
168    }
169
170    pub fn set_supports_thinking(&self, supports: bool) {
171        self.supports_thinking.store(supports, SeqCst);
172    }
173
174    pub fn set_supports_streaming_tools(&self, supports: bool) {
175        self.supports_streaming_tools.store(supports, SeqCst);
176    }
177
178    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
179        self.current_completion_txs
180            .lock()
181            .iter()
182            .map(|(request, _)| request.clone())
183            .collect()
184    }
185
186    pub fn completion_count(&self) -> usize {
187        self.current_completion_txs.lock().len()
188    }
189
190    pub fn send_completion_stream_text_chunk(
191        &self,
192        request: &LanguageModelRequest,
193        chunk: impl Into<String>,
194    ) {
195        self.send_completion_stream_event(
196            request,
197            LanguageModelCompletionEvent::Text(chunk.into()),
198        );
199    }
200
201    pub fn send_completion_stream_event(
202        &self,
203        request: &LanguageModelRequest,
204        event: impl Into<LanguageModelCompletionEvent>,
205    ) {
206        let current_completion_txs = self.current_completion_txs.lock();
207        let tx = current_completion_txs
208            .iter()
209            .find(|(req, _)| req == request)
210            .map(|(_, tx)| tx)
211            .unwrap();
212        tx.unbounded_send(Ok(event.into())).unwrap();
213    }
214
215    pub fn send_completion_stream_error(
216        &self,
217        request: &LanguageModelRequest,
218        error: impl Into<LanguageModelCompletionError>,
219    ) {
220        let current_completion_txs = self.current_completion_txs.lock();
221        let tx = current_completion_txs
222            .iter()
223            .find(|(req, _)| req == request)
224            .map(|(_, tx)| tx)
225            .unwrap();
226        tx.unbounded_send(Err(error.into())).unwrap();
227    }
228
229    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
230        self.current_completion_txs
231            .lock()
232            .retain(|(req, _)| req != request);
233    }
234
235    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
236        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
237    }
238
239    pub fn send_last_completion_stream_event(
240        &self,
241        event: impl Into<LanguageModelCompletionEvent>,
242    ) {
243        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
244    }
245
246    pub fn send_last_completion_stream_error(
247        &self,
248        error: impl Into<LanguageModelCompletionError>,
249    ) {
250        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
251    }
252
253    pub fn end_last_completion_stream(&self) {
254        self.end_completion_stream(self.pending_completions().last().unwrap());
255    }
256}
257
258impl LanguageModel for FakeLanguageModel {
259    fn id(&self) -> LanguageModelId {
260        self.id.clone()
261    }
262
263    fn name(&self) -> LanguageModelName {
264        self.name.clone()
265    }
266
267    fn provider_id(&self) -> LanguageModelProviderId {
268        self.provider_id.clone()
269    }
270
271    fn provider_name(&self) -> LanguageModelProviderName {
272        self.provider_name.clone()
273    }
274
275    fn supports_tools(&self) -> bool {
276        false
277    }
278
279    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
280        false
281    }
282
283    fn supports_images(&self) -> bool {
284        false
285    }
286
287    fn supports_thinking(&self) -> bool {
288        self.supports_thinking.load(SeqCst)
289    }
290
291    fn supports_streaming_tools(&self) -> bool {
292        self.supports_streaming_tools.load(SeqCst)
293    }
294
295    fn telemetry_id(&self) -> String {
296        "fake".to_string()
297    }
298
299    fn max_token_count(&self) -> u64 {
300        1000000
301    }
302
303    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
304        futures::future::ready(Ok(0)).boxed()
305    }
306
307    fn stream_completion(
308        &self,
309        request: LanguageModelRequest,
310        _: &AsyncApp,
311    ) -> BoxFuture<
312        'static,
313        Result<
314            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
315            LanguageModelCompletionError,
316        >,
317    > {
318        if self.forbid_requests.load(SeqCst) {
319            async move {
320                Err(LanguageModelCompletionError::Other(anyhow!(
321                    "requests are forbidden"
322                )))
323            }
324            .boxed()
325        } else {
326            let (tx, rx) = mpsc::unbounded();
327            self.current_completion_txs.lock().push((request, tx));
328            async move { Ok(rx.boxed()) }.boxed()
329        }
330    }
331
332    fn as_fake(&self) -> &Self {
333        self
334    }
335}