fake_provider.rs

  1use crate::{
  2    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
  3    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  4    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  5    LanguageModelRequest, LanguageModelToolChoice,
  6};
  7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use std::sync::Arc;
 12
 13#[derive(Clone)]
 14pub struct FakeLanguageModelProvider {
 15    id: LanguageModelProviderId,
 16    name: LanguageModelProviderName,
 17}
 18
 19impl Default for FakeLanguageModelProvider {
 20    fn default() -> Self {
 21        Self {
 22            id: LanguageModelProviderId::from("fake".to_string()),
 23            name: LanguageModelProviderName::from("Fake".to_string()),
 24        }
 25    }
 26}
 27
 28impl LanguageModelProviderState for FakeLanguageModelProvider {
 29    type ObservableEntity = ();
 30
 31    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 32        None
 33    }
 34}
 35
 36impl LanguageModelProvider for FakeLanguageModelProvider {
 37    fn id(&self) -> LanguageModelProviderId {
 38        self.id.clone()
 39    }
 40
 41    fn name(&self) -> LanguageModelProviderName {
 42        self.name.clone()
 43    }
 44
 45    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 46        Some(Arc::new(FakeLanguageModel::default()))
 47    }
 48
 49    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 54        vec![Arc::new(FakeLanguageModel::default())]
 55    }
 56
 57    fn is_authenticated(&self, _: &App) -> bool {
 58        true
 59    }
 60
 61    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 62        Task::ready(Ok(()))
 63    }
 64
 65    fn configuration_view(
 66        &self,
 67        _target_agent: ConfigurationViewTargetAgent,
 68        _window: &mut Window,
 69        _: &mut App,
 70    ) -> AnyView {
 71        unimplemented!()
 72    }
 73
 74    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 75        Task::ready(Ok(()))
 76    }
 77}
 78
 79impl FakeLanguageModelProvider {
 80    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
 81        Self { id, name }
 82    }
 83
 84    pub fn test_model(&self) -> FakeLanguageModel {
 85        FakeLanguageModel::default()
 86    }
 87}
 88
 89#[derive(Debug, PartialEq)]
 90pub struct ToolUseRequest {
 91    pub request: LanguageModelRequest,
 92    pub name: String,
 93    pub description: String,
 94    pub schema: serde_json::Value,
 95}
 96
 97pub struct FakeLanguageModel {
 98    provider_id: LanguageModelProviderId,
 99    provider_name: LanguageModelProviderName,
100    current_completion_txs: Mutex<
101        Vec<(
102            LanguageModelRequest,
103            mpsc::UnboundedSender<LanguageModelCompletionEvent>,
104        )>,
105    >,
106}
107
108impl Default for FakeLanguageModel {
109    fn default() -> Self {
110        Self {
111            provider_id: LanguageModelProviderId::from("fake".to_string()),
112            provider_name: LanguageModelProviderName::from("Fake".to_string()),
113            current_completion_txs: Mutex::new(Vec::new()),
114        }
115    }
116}
117
118impl FakeLanguageModel {
119    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
120        self.current_completion_txs
121            .lock()
122            .iter()
123            .map(|(request, _)| request.clone())
124            .collect()
125    }
126
127    pub fn completion_count(&self) -> usize {
128        self.current_completion_txs.lock().len()
129    }
130
131    pub fn send_completion_stream_text_chunk(
132        &self,
133        request: &LanguageModelRequest,
134        chunk: impl Into<String>,
135    ) {
136        self.send_completion_stream_event(
137            request,
138            LanguageModelCompletionEvent::Text(chunk.into()),
139        );
140    }
141
142    pub fn send_completion_stream_event(
143        &self,
144        request: &LanguageModelRequest,
145        event: impl Into<LanguageModelCompletionEvent>,
146    ) {
147        let current_completion_txs = self.current_completion_txs.lock();
148        let tx = current_completion_txs
149            .iter()
150            .find(|(req, _)| req == request)
151            .map(|(_, tx)| tx)
152            .unwrap();
153        tx.unbounded_send(event.into()).unwrap();
154    }
155
156    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
157        self.current_completion_txs
158            .lock()
159            .retain(|(req, _)| req != request);
160    }
161
162    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
163        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
164    }
165
166    pub fn send_last_completion_stream_event(
167        &self,
168        event: impl Into<LanguageModelCompletionEvent>,
169    ) {
170        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
171    }
172
173    pub fn end_last_completion_stream(&self) {
174        self.end_completion_stream(self.pending_completions().last().unwrap());
175    }
176}
177
178impl LanguageModel for FakeLanguageModel {
179    fn id(&self) -> LanguageModelId {
180        LanguageModelId::from("fake".to_string())
181    }
182
183    fn name(&self) -> LanguageModelName {
184        LanguageModelName::from("Fake".to_string())
185    }
186
187    fn provider_id(&self) -> LanguageModelProviderId {
188        self.provider_id.clone()
189    }
190
191    fn provider_name(&self) -> LanguageModelProviderName {
192        self.provider_name.clone()
193    }
194
195    fn supports_tools(&self) -> bool {
196        false
197    }
198
199    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
200        false
201    }
202
203    fn supports_images(&self) -> bool {
204        false
205    }
206
207    fn telemetry_id(&self) -> String {
208        "fake".to_string()
209    }
210
211    fn max_token_count(&self) -> u64 {
212        1000000
213    }
214
215    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
216        futures::future::ready(Ok(0)).boxed()
217    }
218
219    fn stream_completion(
220        &self,
221        request: LanguageModelRequest,
222        _: &AsyncApp,
223    ) -> BoxFuture<
224        'static,
225        Result<
226            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
227            LanguageModelCompletionError,
228        >,
229    > {
230        let (tx, rx) = mpsc::unbounded();
231        self.current_completion_txs.lock().push((request, tx));
232        async move { Ok(rx.map(Ok).boxed()) }.boxed()
233    }
234
235    fn as_fake(&self) -> &Self {
236        self
237    }
238}