fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  5    LanguageModelToolChoice, LanguageModelToolUse, StopReason,
  6};
  7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  9use http_client::Result;
 10use parking_lot::Mutex;
 11use std::sync::Arc;
 12
 13pub fn language_model_id() -> LanguageModelId {
 14    LanguageModelId::from("fake".to_string())
 15}
 16
 17pub fn language_model_name() -> LanguageModelName {
 18    LanguageModelName::from("Fake".to_string())
 19}
 20
 21pub fn provider_id() -> LanguageModelProviderId {
 22    LanguageModelProviderId::from("fake".to_string())
 23}
 24
 25pub fn provider_name() -> LanguageModelProviderName {
 26    LanguageModelProviderName::from("Fake".to_string())
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct FakeLanguageModelProvider;
 31
 32impl LanguageModelProviderState for FakeLanguageModelProvider {
 33    type ObservableEntity = ();
 34
 35    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 36        None
 37    }
 38}
 39
 40impl LanguageModelProvider for FakeLanguageModelProvider {
 41    fn id(&self) -> LanguageModelProviderId {
 42        provider_id()
 43    }
 44
 45    fn name(&self) -> LanguageModelProviderName {
 46        provider_name()
 47    }
 48
 49    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 54        Some(Arc::new(FakeLanguageModel::default()))
 55    }
 56
 57    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 58        vec![Arc::new(FakeLanguageModel::default())]
 59    }
 60
 61    fn is_authenticated(&self, _: &App) -> bool {
 62        true
 63    }
 64
 65    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 66        Task::ready(Ok(()))
 67    }
 68
 69    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 70        unimplemented!()
 71    }
 72
 73    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 74        Task::ready(Ok(()))
 75    }
 76}
 77
 78impl FakeLanguageModelProvider {
 79    pub fn test_model(&self) -> FakeLanguageModel {
 80        FakeLanguageModel::default()
 81    }
 82}
 83
 84#[derive(Debug, PartialEq)]
 85pub struct ToolUseRequest {
 86    pub request: LanguageModelRequest,
 87    pub name: String,
 88    pub description: String,
 89    pub schema: serde_json::Value,
 90}
 91
 92#[derive(Default)]
 93pub struct FakeLanguageModel {
 94    current_completion_txs: Mutex<
 95        Vec<(
 96            LanguageModelRequest,
 97            mpsc::UnboundedSender<LanguageModelCompletionEvent>,
 98        )>,
 99    >,
100}
101
102impl FakeLanguageModel {
103    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
104        self.current_completion_txs
105            .lock()
106            .iter()
107            .map(|(request, _)| request.clone())
108            .collect()
109    }
110
111    pub fn completion_count(&self) -> usize {
112        self.current_completion_txs.lock().len()
113    }
114
115    pub fn stream_completion_response(
116        &self,
117        request: &LanguageModelRequest,
118        stream: impl Into<FakeLanguageModelStream>,
119    ) {
120        let current_completion_txs = self.current_completion_txs.lock();
121        let tx = current_completion_txs
122            .iter()
123            .find(|(req, _)| req == request)
124            .map(|(_, tx)| tx)
125            .unwrap();
126        for event in stream.into().events {
127            tx.unbounded_send(event).unwrap();
128        }
129    }
130
131    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
132        self.current_completion_txs
133            .lock()
134            .retain(|(req, _)| req != request);
135    }
136
137    pub fn stream_last_completion_response(&self, chunk: impl Into<FakeLanguageModelStream>) {
138        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
139    }
140
141    pub fn end_last_completion_stream(&self) {
142        self.end_completion_stream(self.pending_completions().last().unwrap());
143    }
144}
145
146pub struct FakeLanguageModelStream {
147    events: Vec<LanguageModelCompletionEvent>,
148}
149
150impl<T: Into<String>> From<T> for FakeLanguageModelStream {
151    fn from(chunk: T) -> Self {
152        Self {
153            events: vec![LanguageModelCompletionEvent::Text(chunk.into())],
154        }
155    }
156}
157
158impl From<LanguageModelToolUse> for FakeLanguageModelStream {
159    fn from(tool_use: LanguageModelToolUse) -> Self {
160        Self {
161            events: vec![
162                LanguageModelCompletionEvent::ToolUse(tool_use),
163                LanguageModelCompletionEvent::Stop(StopReason::ToolUse),
164            ],
165        }
166    }
167}
168
169impl LanguageModel for FakeLanguageModel {
170    fn id(&self) -> LanguageModelId {
171        language_model_id()
172    }
173
174    fn name(&self) -> LanguageModelName {
175        language_model_name()
176    }
177
178    fn provider_id(&self) -> LanguageModelProviderId {
179        provider_id()
180    }
181
182    fn provider_name(&self) -> LanguageModelProviderName {
183        provider_name()
184    }
185
186    fn supports_tools(&self) -> bool {
187        false
188    }
189
190    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
191        false
192    }
193
194    fn supports_images(&self) -> bool {
195        false
196    }
197
198    fn telemetry_id(&self) -> String {
199        "fake".to_string()
200    }
201
202    fn max_token_count(&self) -> u64 {
203        1000000
204    }
205
206    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
207        futures::future::ready(Ok(0)).boxed()
208    }
209
210    fn stream_completion(
211        &self,
212        request: LanguageModelRequest,
213        _: &AsyncApp,
214    ) -> BoxFuture<
215        'static,
216        Result<
217            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
218            LanguageModelCompletionError,
219        >,
220    > {
221        let (tx, rx) = mpsc::unbounded();
222        self.current_completion_txs.lock().push((request, tx));
223        async move { Ok(rx.map(Ok).boxed()) }.boxed()
224    }
225
226    fn as_fake(&self) -> &Self {
227        self
228    }
229}