fake_provider.rs

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use anyhow::anyhow;
  8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 10use http_client::Result;
 11use parking_lot::Mutex;
 12use smol::stream::StreamExt;
 13use std::sync::{
 14    Arc,
 15    atomic::{AtomicBool, Ordering::SeqCst},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModelProvider {
 20    id: LanguageModelProviderId,
 21    name: LanguageModelProviderName,
 22}
 23
 24impl Default for FakeLanguageModelProvider {
 25    fn default() -> Self {
 26        Self {
 27            id: LanguageModelProviderId::from("fake".to_string()),
 28            name: LanguageModelProviderName::from("Fake".to_string()),
 29        }
 30    }
 31}
 32
 33impl LanguageModelProviderState for FakeLanguageModelProvider {
 34    type ObservableEntity = ();
 35
 36    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 37        None
 38    }
 39}
 40
 41impl LanguageModelProvider for FakeLanguageModelProvider {
 42    fn id(&self) -> LanguageModelProviderId {
 43        self.id.clone()
 44    }
 45
 46    fn name(&self) -> LanguageModelProviderName {
 47        self.name.clone()
 48    }
 49
 50    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 51        Some(Arc::new(FakeLanguageModel::default()))
 52    }
 53
 54    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 55        Some(Arc::new(FakeLanguageModel::default()))
 56    }
 57
 58    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 59        vec![Arc::new(FakeLanguageModel::default())]
 60    }
 61
 62    fn is_authenticated(&self, _: &App) -> bool {
 63        true
 64    }
 65
 66    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 67        Task::ready(Ok(()))
 68    }
 69
 70    fn configuration_view(
 71        &self,
 72        _target_agent: ConfigurationViewTargetAgent,
 73        _window: &mut Window,
 74        _: &mut App,
 75    ) -> AnyView {
 76        unimplemented!()
 77    }
 78
 79    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 80        Task::ready(Ok(()))
 81    }
 82}
 83
 84impl FakeLanguageModelProvider {
 85    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 86        Self { id, name }
 87    }
 88
 89    pub fn test_model(&self) -> FakeLanguageModel {
 90        FakeLanguageModel::default()
 91    }
 92}
 93
 94#[derive(Debug, PartialEq)]
 95pub struct ToolUseRequest {
 96    pub request: LanguageModelRequest,
 97    pub name: String,
 98    pub description: String,
 99    pub schema: serde_json::Value,
100}
101
102pub struct FakeLanguageModel {
103    provider_id: LanguageModelProviderId,
104    provider_name: LanguageModelProviderName,
105    current_completion_txs: Mutex<
106        Vec<(
107            LanguageModelRequest,
108            mpsc::UnboundedSender<
109                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
110            >,
111        )>,
112    >,
113    forbid_requests: AtomicBool,
114}
115
116impl Default for FakeLanguageModel {
117    fn default() -> Self {
118        Self {
119            provider_id: LanguageModelProviderId::from("fake".to_string()),
120            provider_name: LanguageModelProviderName::from("Fake".to_string()),
121            current_completion_txs: Mutex::new(Vec::new()),
122            forbid_requests: AtomicBool::new(false),
123        }
124    }
125}
126
127impl FakeLanguageModel {
128    pub fn allow_requests(&self) {
129        self.forbid_requests.store(false, SeqCst);
130    }
131
132    pub fn forbid_requests(&self) {
133        self.forbid_requests.store(true, SeqCst);
134    }
135
136    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
137        self.current_completion_txs
138            .lock()
139            .iter()
140            .map(|(request, _)| request.clone())
141            .collect()
142    }
143
144    pub fn completion_count(&self) -> usize {
145        self.current_completion_txs.lock().len()
146    }
147
148    pub fn send_completion_stream_text_chunk(
149        &self,
150        request: &LanguageModelRequest,
151        chunk: impl Into<String>,
152    ) {
153        self.send_completion_stream_event(
154            request,
155            LanguageModelCompletionEvent::Text(chunk.into()),
156        );
157    }
158
159    pub fn send_completion_stream_event(
160        &self,
161        request: &LanguageModelRequest,
162        event: impl Into<LanguageModelCompletionEvent>,
163    ) {
164        let current_completion_txs = self.current_completion_txs.lock();
165        let tx = current_completion_txs
166            .iter()
167            .find(|(req, _)| req == request)
168            .map(|(_, tx)| tx)
169            .unwrap();
170        tx.unbounded_send(Ok(event.into())).unwrap();
171    }
172
173    pub fn send_completion_stream_error(
174        &self,
175        request: &LanguageModelRequest,
176        error: impl Into<LanguageModelCompletionError>,
177    ) {
178        let current_completion_txs = self.current_completion_txs.lock();
179        let tx = current_completion_txs
180            .iter()
181            .find(|(req, _)| req == request)
182            .map(|(_, tx)| tx)
183            .unwrap();
184        tx.unbounded_send(Err(error.into())).unwrap();
185    }
186
187    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
188        self.current_completion_txs
189            .lock()
190            .retain(|(req, _)| req != request);
191    }
192
193    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
194        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
195    }
196
197    pub fn send_last_completion_stream_event(
198        &self,
199        event: impl Into<LanguageModelCompletionEvent>,
200    ) {
201        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
202    }
203
204    pub fn send_last_completion_stream_error(
205        &self,
206        error: impl Into<LanguageModelCompletionError>,
207    ) {
208        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
209    }
210
211    pub fn end_last_completion_stream(&self) {
212        self.end_completion_stream(self.pending_completions().last().unwrap());
213    }
214}
215
216impl LanguageModel for FakeLanguageModel {
217    fn id(&self) -> LanguageModelId {
218        LanguageModelId::from("fake".to_string())
219    }
220
221    fn name(&self) -> LanguageModelName {
222        LanguageModelName::from("Fake".to_string())
223    }
224
225    fn provider_id(&self) -> LanguageModelProviderId {
226        self.provider_id.clone()
227    }
228
229    fn provider_name(&self) -> LanguageModelProviderName {
230        self.provider_name.clone()
231    }
232
233    fn supports_tools(&self) -> bool {
234        false
235    }
236
237    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
238        false
239    }
240
241    fn supports_images(&self) -> bool {
242        false
243    }
244
245    fn telemetry_id(&self) -> String {
246        "fake".to_string()
247    }
248
249    fn max_token_count(&self) -> u64 {
250        1000000
251    }
252
253    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
254        futures::future::ready(Ok(0)).boxed()
255    }
256
257    fn stream_completion(
258        &self,
259        request: LanguageModelRequest,
260        _: &AsyncApp,
261    ) -> BoxFuture<
262        'static,
263        Result<
264            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
265            LanguageModelCompletionError,
266        >,
267    > {
268        if self.forbid_requests.load(SeqCst) {
269            async move {
270                Err(LanguageModelCompletionError::Other(anyhow!(
271                    "requests are forbidden"
272                )))
273            }
274            .boxed()
275        } else {
276            let (tx, rx) = mpsc::unbounded();
277            self.current_completion_txs.lock().push((request, tx));
278            async move { Ok(rx.boxed()) }.boxed()
279        }
280    }
281
282    fn as_fake(&self) -> &Self {
283        self
284    }
285}