fake_provider.rs

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use smol::stream::StreamExt;
 12use std::sync::Arc;
 13
 14#[derive(Clone)]
 15pub struct FakeLanguageModelProvider {
 16    id: LanguageModelProviderId,
 17    name: LanguageModelProviderName,
 18}
 19
 20impl Default for FakeLanguageModelProvider {
 21    fn default() -> Self {
 22        Self {
 23            id: LanguageModelProviderId::from("fake".to_string()),
 24            name: LanguageModelProviderName::from("Fake".to_string()),
 25        }
 26    }
 27}
 28
 29impl LanguageModelProviderState for FakeLanguageModelProvider {
 30    type ObservableEntity = ();
 31
 32    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 33        None
 34    }
 35}
 36
 37impl LanguageModelProvider for FakeLanguageModelProvider {
 38    fn id(&self) -> LanguageModelProviderId {
 39        self.id.clone()
 40    }
 41
 42    fn name(&self) -> LanguageModelProviderName {
 43        self.name.clone()
 44    }
 45
 46    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 47        Some(Arc::new(FakeLanguageModel::default()))
 48    }
 49
 50    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 51        Some(Arc::new(FakeLanguageModel::default()))
 52    }
 53
 54    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 55        vec![Arc::new(FakeLanguageModel::default())]
 56    }
 57
 58    fn is_authenticated(&self, _: &App) -> bool {
 59        true
 60    }
 61
 62    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 63        Task::ready(Ok(()))
 64    }
 65
 66    fn configuration_view(
 67        &self,
 68        _target_agent: ConfigurationViewTargetAgent,
 69        _window: &mut Window,
 70        _: &mut App,
 71    ) -> AnyView {
 72        unimplemented!()
 73    }
 74
 75    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 76        Task::ready(Ok(()))
 77    }
 78}
 79
 80impl FakeLanguageModelProvider {
 81    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 82        Self { id, name }
 83    }
 84
 85    pub fn test_model(&self) -> FakeLanguageModel {
 86        FakeLanguageModel::default()
 87    }
 88}
 89
 90#[derive(Debug, PartialEq)]
 91pub struct ToolUseRequest {
 92    pub request: LanguageModelRequest,
 93    pub name: String,
 94    pub description: String,
 95    pub schema: serde_json::Value,
 96}
 97
 98pub struct FakeLanguageModel {
 99    provider_id: LanguageModelProviderId,
100    provider_name: LanguageModelProviderName,
101    current_completion_txs: Mutex<
102        Vec<(
103            LanguageModelRequest,
104            mpsc::UnboundedSender<
105                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
106            >,
107        )>,
108    >,
109}
110
111impl Default for FakeLanguageModel {
112    fn default() -> Self {
113        Self {
114            provider_id: LanguageModelProviderId::from("fake".to_string()),
115            provider_name: LanguageModelProviderName::from("Fake".to_string()),
116            current_completion_txs: Mutex::new(Vec::new()),
117        }
118    }
119}
120
121impl FakeLanguageModel {
122    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
123        self.current_completion_txs
124            .lock()
125            .iter()
126            .map(|(request, _)| request.clone())
127            .collect()
128    }
129
130    pub fn completion_count(&self) -> usize {
131        self.current_completion_txs.lock().len()
132    }
133
134    pub fn send_completion_stream_text_chunk(
135        &self,
136        request: &LanguageModelRequest,
137        chunk: impl Into<String>,
138    ) {
139        self.send_completion_stream_event(
140            request,
141            LanguageModelCompletionEvent::Text(chunk.into()),
142        );
143    }
144
145    pub fn send_completion_stream_event(
146        &self,
147        request: &LanguageModelRequest,
148        event: impl Into<LanguageModelCompletionEvent>,
149    ) {
150        let current_completion_txs = self.current_completion_txs.lock();
151        let tx = current_completion_txs
152            .iter()
153            .find(|(req, _)| req == request)
154            .map(|(_, tx)| tx)
155            .unwrap();
156        tx.unbounded_send(Ok(event.into())).unwrap();
157    }
158
159    pub fn send_completion_stream_error(
160        &self,
161        request: &LanguageModelRequest,
162        error: impl Into<LanguageModelCompletionError>,
163    ) {
164        let current_completion_txs = self.current_completion_txs.lock();
165        let tx = current_completion_txs
166            .iter()
167            .find(|(req, _)| req == request)
168            .map(|(_, tx)| tx)
169            .unwrap();
170        tx.unbounded_send(Err(error.into())).unwrap();
171    }
172
173    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
174        self.current_completion_txs
175            .lock()
176            .retain(|(req, _)| req != request);
177    }
178
179    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
180        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
181    }
182
183    pub fn send_last_completion_stream_event(
184        &self,
185        event: impl Into<LanguageModelCompletionEvent>,
186    ) {
187        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
188    }
189
190    pub fn send_last_completion_stream_error(
191        &self,
192        error: impl Into<LanguageModelCompletionError>,
193    ) {
194        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
195    }
196
197    pub fn end_last_completion_stream(&self) {
198        self.end_completion_stream(self.pending_completions().last().unwrap());
199    }
200}
201
202impl LanguageModel for FakeLanguageModel {
203    fn id(&self) -> LanguageModelId {
204        LanguageModelId::from("fake".to_string())
205    }
206
207    fn name(&self) -> LanguageModelName {
208        LanguageModelName::from("Fake".to_string())
209    }
210
211    fn provider_id(&self) -> LanguageModelProviderId {
212        self.provider_id.clone()
213    }
214
215    fn provider_name(&self) -> LanguageModelProviderName {
216        self.provider_name.clone()
217    }
218
219    fn supports_tools(&self) -> bool {
220        false
221    }
222
223    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
224        false
225    }
226
227    fn supports_images(&self) -> bool {
228        false
229    }
230
231    fn telemetry_id(&self) -> String {
232        "fake".to_string()
233    }
234
235    fn max_token_count(&self) -> u64 {
236        1000000
237    }
238
239    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
240        futures::future::ready(Ok(0)).boxed()
241    }
242
243    fn stream_completion(
244        &self,
245        request: LanguageModelRequest,
246        _: &AsyncApp,
247    ) -> BoxFuture<
248        'static,
249        Result<
250            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
251            LanguageModelCompletionError,
252        >,
253    > {
254        let (tx, rx) = mpsc::unbounded();
255        self.current_completion_txs.lock().push((request, tx));
256        async move { Ok(rx.boxed()) }.boxed()
257    }
258
259    fn as_fake(&self) -> &Self {
260        self
261    }
262}