fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  5};
  6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  8use http_client::Result;
  9use parking_lot::Mutex;
 10use std::sync::Arc;
 11
 12pub fn language_model_id() -> LanguageModelId {
 13    LanguageModelId::from("fake".to_string())
 14}
 15
 16pub fn language_model_name() -> LanguageModelName {
 17    LanguageModelName::from("Fake".to_string())
 18}
 19
 20pub fn provider_id() -> LanguageModelProviderId {
 21    LanguageModelProviderId::from("fake".to_string())
 22}
 23
 24pub fn provider_name() -> LanguageModelProviderName {
 25    LanguageModelProviderName::from("Fake".to_string())
 26}
 27
 28#[derive(Clone, Default)]
 29pub struct FakeLanguageModelProvider;
 30
impl LanguageModelProviderState for FakeLanguageModelProvider {
    // The fake provider has no state to observe, so the unit type stands in.
    type ObservableEntity = ();

    /// Always returns `None`: there is no entity backing this provider.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}
 38
 39impl LanguageModelProvider for FakeLanguageModelProvider {
 40    fn id(&self) -> LanguageModelProviderId {
 41        provider_id()
 42    }
 43
 44    fn name(&self) -> LanguageModelProviderName {
 45        provider_name()
 46    }
 47
 48    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 49        Some(Arc::new(FakeLanguageModel::default()))
 50    }
 51
 52    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 53        Some(Arc::new(FakeLanguageModel::default()))
 54    }
 55
 56    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 57        vec![Arc::new(FakeLanguageModel::default())]
 58    }
 59
 60    fn is_authenticated(&self, _: &App) -> bool {
 61        true
 62    }
 63
 64    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 65        Task::ready(Ok(()))
 66    }
 67
 68    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 69        unimplemented!()
 70    }
 71
 72    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 73        Task::ready(Ok(()))
 74    }
 75}
 76
 77impl FakeLanguageModelProvider {
 78    pub fn test_model(&self) -> FakeLanguageModel {
 79        FakeLanguageModel::default()
 80    }
 81}
 82
 83#[derive(Debug, PartialEq)]
 84pub struct ToolUseRequest {
 85    pub request: LanguageModelRequest,
 86    pub name: String,
 87    pub description: String,
 88    pub schema: serde_json::Value,
 89}
 90
 91#[derive(Default)]
 92pub struct FakeLanguageModel {
 93    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 94}
 95
 96impl FakeLanguageModel {
 97    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 98        self.current_completion_txs
 99            .lock()
100            .iter()
101            .map(|(request, _)| request.clone())
102            .collect()
103    }
104
105    pub fn completion_count(&self) -> usize {
106        self.current_completion_txs.lock().len()
107    }
108
109    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
110        let current_completion_txs = self.current_completion_txs.lock();
111        let tx = current_completion_txs
112            .iter()
113            .find(|(req, _)| req == request)
114            .map(|(_, tx)| tx)
115            .unwrap();
116        tx.unbounded_send(chunk).unwrap();
117    }
118
119    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
120        self.current_completion_txs
121            .lock()
122            .retain(|(req, _)| req != request);
123    }
124
125    pub fn stream_last_completion_response(&self, chunk: String) {
126        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
127    }
128
129    pub fn end_last_completion_stream(&self) {
130        self.end_completion_stream(self.pending_completions().last().unwrap());
131    }
132}
133
134impl LanguageModel for FakeLanguageModel {
135    fn id(&self) -> LanguageModelId {
136        language_model_id()
137    }
138
139    fn name(&self) -> LanguageModelName {
140        language_model_name()
141    }
142
143    fn provider_id(&self) -> LanguageModelProviderId {
144        provider_id()
145    }
146
147    fn provider_name(&self) -> LanguageModelProviderName {
148        provider_name()
149    }
150
151    fn supports_tools(&self) -> bool {
152        false
153    }
154
155    fn telemetry_id(&self) -> String {
156        "fake".to_string()
157    }
158
159    fn max_token_count(&self) -> usize {
160        1000000
161    }
162
163    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
164        futures::future::ready(Ok(0)).boxed()
165    }
166
167    fn stream_completion(
168        &self,
169        request: LanguageModelRequest,
170        _: &AsyncApp,
171    ) -> BoxFuture<
172        'static,
173        Result<
174            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
175        >,
176    > {
177        let (tx, rx) = mpsc::unbounded();
178        self.current_completion_txs.lock().push((request, tx));
179        async move {
180            Ok(rx
181                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
182                .boxed())
183        }
184        .boxed()
185    }
186
187    fn as_fake(&self) -> &Self {
188        self
189    }
190}