fake_provider.rs

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use anyhow::anyhow;
  8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
  9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 10use http_client::Result;
 11use parking_lot::Mutex;
 12use std::sync::{
 13    Arc,
 14    atomic::{AtomicBool, Ordering::SeqCst},
 15};
 16
 17#[derive(Clone)]
 18pub struct FakeLanguageModelProvider {
 19    id: LanguageModelProviderId,
 20    name: LanguageModelProviderName,
 21    models: Vec<Arc<dyn LanguageModel>>,
 22}
 23
 24impl Default for FakeLanguageModelProvider {
 25    fn default() -> Self {
 26        Self {
 27            id: LanguageModelProviderId::from("fake".to_string()),
 28            name: LanguageModelProviderName::from("Fake".to_string()),
 29            models: vec![Arc::new(FakeLanguageModel::default())],
 30        }
 31    }
 32}
 33
 34impl LanguageModelProviderState for FakeLanguageModelProvider {
 35    type ObservableEntity = ();
 36
 37    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 38        None
 39    }
 40}
 41
 42impl LanguageModelProvider for FakeLanguageModelProvider {
 43    fn id(&self) -> LanguageModelProviderId {
 44        self.id.clone()
 45    }
 46
 47    fn name(&self) -> LanguageModelProviderName {
 48        self.name.clone()
 49    }
 50
 51    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 52        self.models.first().cloned()
 53    }
 54
 55    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 56        self.models.first().cloned()
 57    }
 58
 59    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 60        self.models.clone()
 61    }
 62
 63    fn is_authenticated(&self, _: &App) -> bool {
 64        true
 65    }
 66
 67    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 68        Task::ready(Ok(()))
 69    }
 70
 71    fn configuration_view(
 72        &self,
 73        _target_agent: ConfigurationViewTargetAgent,
 74        _window: &mut Window,
 75        _: &mut App,
 76    ) -> AnyView {
 77        unimplemented!()
 78    }
 79
 80    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 81        Task::ready(Ok(()))
 82    }
 83}
 84
 85impl FakeLanguageModelProvider {
 86    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 87        Self {
 88            id,
 89            name,
 90            models: vec![Arc::new(FakeLanguageModel::default())],
 91        }
 92    }
 93
 94    pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
 95        self.models = models;
 96        self
 97    }
 98
 99    pub fn test_model(&self) -> FakeLanguageModel {
100        FakeLanguageModel::default()
101    }
102}
103
104#[derive(Debug, PartialEq)]
105pub struct ToolUseRequest {
106    pub request: LanguageModelRequest,
107    pub name: String,
108    pub description: String,
109    pub schema: serde_json::Value,
110}
111
112pub struct FakeLanguageModel {
113    id: LanguageModelId,
114    name: LanguageModelName,
115    provider_id: LanguageModelProviderId,
116    provider_name: LanguageModelProviderName,
117    current_completion_txs: Mutex<
118        Vec<(
119            LanguageModelRequest,
120            mpsc::UnboundedSender<
121                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
122            >,
123        )>,
124    >,
125    forbid_requests: AtomicBool,
126    supports_thinking: AtomicBool,
127    supports_streaming_tools: AtomicBool,
128}
129
130impl Default for FakeLanguageModel {
131    fn default() -> Self {
132        Self {
133            id: LanguageModelId::from("fake".to_string()),
134            name: LanguageModelName::from("Fake".to_string()),
135            provider_id: LanguageModelProviderId::from("fake".to_string()),
136            provider_name: LanguageModelProviderName::from("Fake".to_string()),
137            current_completion_txs: Mutex::new(Vec::new()),
138            forbid_requests: AtomicBool::new(false),
139            supports_thinking: AtomicBool::new(false),
140            supports_streaming_tools: AtomicBool::new(false),
141        }
142    }
143}
144
145impl FakeLanguageModel {
146    pub fn with_id_and_thinking(
147        provider_id: &str,
148        id: &str,
149        name: &str,
150        supports_thinking: bool,
151    ) -> Self {
152        Self {
153            id: LanguageModelId::from(id.to_string()),
154            name: LanguageModelName::from(name.to_string()),
155            provider_id: LanguageModelProviderId::from(provider_id.to_string()),
156            supports_thinking: AtomicBool::new(supports_thinking),
157            ..Default::default()
158        }
159    }
160
161    pub fn allow_requests(&self) {
162        self.forbid_requests.store(false, SeqCst);
163    }
164
165    pub fn forbid_requests(&self) {
166        self.forbid_requests.store(true, SeqCst);
167    }
168
169    pub fn set_supports_thinking(&self, supports: bool) {
170        self.supports_thinking.store(supports, SeqCst);
171    }
172
173    pub fn set_supports_streaming_tools(&self, supports: bool) {
174        self.supports_streaming_tools.store(supports, SeqCst);
175    }
176
177    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
178        self.current_completion_txs
179            .lock()
180            .iter()
181            .map(|(request, _)| request.clone())
182            .collect()
183    }
184
185    pub fn completion_count(&self) -> usize {
186        self.current_completion_txs.lock().len()
187    }
188
189    pub fn send_completion_stream_text_chunk(
190        &self,
191        request: &LanguageModelRequest,
192        chunk: impl Into<String>,
193    ) {
194        self.send_completion_stream_event(
195            request,
196            LanguageModelCompletionEvent::Text(chunk.into()),
197        );
198    }
199
200    pub fn send_completion_stream_event(
201        &self,
202        request: &LanguageModelRequest,
203        event: impl Into<LanguageModelCompletionEvent>,
204    ) {
205        let current_completion_txs = self.current_completion_txs.lock();
206        let tx = current_completion_txs
207            .iter()
208            .find(|(req, _)| req == request)
209            .map(|(_, tx)| tx)
210            .unwrap();
211        tx.unbounded_send(Ok(event.into())).unwrap();
212    }
213
214    pub fn send_completion_stream_error(
215        &self,
216        request: &LanguageModelRequest,
217        error: impl Into<LanguageModelCompletionError>,
218    ) {
219        let current_completion_txs = self.current_completion_txs.lock();
220        let tx = current_completion_txs
221            .iter()
222            .find(|(req, _)| req == request)
223            .map(|(_, tx)| tx)
224            .unwrap();
225        tx.unbounded_send(Err(error.into())).unwrap();
226    }
227
228    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
229        self.current_completion_txs
230            .lock()
231            .retain(|(req, _)| req != request);
232    }
233
234    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
235        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
236    }
237
238    pub fn send_last_completion_stream_event(
239        &self,
240        event: impl Into<LanguageModelCompletionEvent>,
241    ) {
242        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
243    }
244
245    pub fn send_last_completion_stream_error(
246        &self,
247        error: impl Into<LanguageModelCompletionError>,
248    ) {
249        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
250    }
251
252    pub fn end_last_completion_stream(&self) {
253        self.end_completion_stream(self.pending_completions().last().unwrap());
254    }
255}
256
257impl LanguageModel for FakeLanguageModel {
258    fn id(&self) -> LanguageModelId {
259        self.id.clone()
260    }
261
262    fn name(&self) -> LanguageModelName {
263        self.name.clone()
264    }
265
266    fn provider_id(&self) -> LanguageModelProviderId {
267        self.provider_id.clone()
268    }
269
270    fn provider_name(&self) -> LanguageModelProviderName {
271        self.provider_name.clone()
272    }
273
274    fn supports_tools(&self) -> bool {
275        false
276    }
277
278    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
279        false
280    }
281
282    fn supports_images(&self) -> bool {
283        false
284    }
285
286    fn supports_thinking(&self) -> bool {
287        self.supports_thinking.load(SeqCst)
288    }
289
290    fn supports_streaming_tools(&self) -> bool {
291        self.supports_streaming_tools.load(SeqCst)
292    }
293
294    fn telemetry_id(&self) -> String {
295        "fake".to_string()
296    }
297
298    fn max_token_count(&self) -> u64 {
299        1000000
300    }
301
302    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
303        futures::future::ready(Ok(0)).boxed()
304    }
305
306    fn stream_completion(
307        &self,
308        request: LanguageModelRequest,
309        _: &AsyncApp,
310    ) -> BoxFuture<
311        'static,
312        Result<
313            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
314            LanguageModelCompletionError,
315        >,
316    > {
317        if self.forbid_requests.load(SeqCst) {
318            async move {
319                Err(LanguageModelCompletionError::Other(anyhow!(
320                    "requests are forbidden"
321                )))
322            }
323            .boxed()
324        } else {
325            let (tx, rx) = mpsc::unbounded();
326            self.current_completion_txs.lock().push((request, tx));
327            async move { Ok(rx.boxed()) }.boxed()
328        }
329    }
330
331    fn as_fake(&self) -> &Self {
332        self
333    }
334}