fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  3    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  4    LanguageModelProviderState, LanguageModelRequest,
  5};
  6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  8use http_client::Result;
  9use parking_lot::Mutex;
 10use std::sync::Arc;
 11
 12pub fn language_model_id() -> LanguageModelId {
 13    LanguageModelId::from("fake".to_string())
 14}
 15
 16pub fn language_model_name() -> LanguageModelName {
 17    LanguageModelName::from("Fake".to_string())
 18}
 19
 20pub fn provider_id() -> LanguageModelProviderId {
 21    LanguageModelProviderId::from("fake".to_string())
 22}
 23
 24pub fn provider_name() -> LanguageModelProviderName {
 25    LanguageModelProviderName::from("Fake".to_string())
 26}
 27
 28#[derive(Clone, Default)]
 29pub struct FakeLanguageModelProvider;
 30
 31impl LanguageModelProviderState for FakeLanguageModelProvider {
 32    type ObservableEntity = ();
 33
 34    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 35        None
 36    }
 37}
 38
 39impl LanguageModelProvider for FakeLanguageModelProvider {
 40    fn id(&self) -> LanguageModelProviderId {
 41        provider_id()
 42    }
 43
 44    fn name(&self) -> LanguageModelProviderName {
 45        provider_name()
 46    }
 47
 48    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 49        Some(Arc::new(FakeLanguageModel::default()))
 50    }
 51
 52    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 53        vec![Arc::new(FakeLanguageModel::default())]
 54    }
 55
 56    fn is_authenticated(&self, _: &App) -> bool {
 57        true
 58    }
 59
 60    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 61        Task::ready(Ok(()))
 62    }
 63
 64    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 65        unimplemented!()
 66    }
 67
 68    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 69        Task::ready(Ok(()))
 70    }
 71}
 72
 73impl FakeLanguageModelProvider {
 74    pub fn test_model(&self) -> FakeLanguageModel {
 75        FakeLanguageModel::default()
 76    }
 77}
 78
 79#[derive(Debug, PartialEq)]
 80pub struct ToolUseRequest {
 81    pub request: LanguageModelRequest,
 82    pub name: String,
 83    pub description: String,
 84    pub schema: serde_json::Value,
 85}
 86
 87#[derive(Default)]
 88pub struct FakeLanguageModel {
 89    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 90}
 91
 92impl FakeLanguageModel {
 93    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 94        self.current_completion_txs
 95            .lock()
 96            .iter()
 97            .map(|(request, _)| request.clone())
 98            .collect()
 99    }
100
101    pub fn completion_count(&self) -> usize {
102        self.current_completion_txs.lock().len()
103    }
104
105    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
106        let current_completion_txs = self.current_completion_txs.lock();
107        let tx = current_completion_txs
108            .iter()
109            .find(|(req, _)| req == request)
110            .map(|(_, tx)| tx)
111            .unwrap();
112        tx.unbounded_send(chunk).unwrap();
113    }
114
115    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
116        self.current_completion_txs
117            .lock()
118            .retain(|(req, _)| req != request);
119    }
120
121    pub fn stream_last_completion_response(&self, chunk: String) {
122        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
123    }
124
125    pub fn end_last_completion_stream(&self) {
126        self.end_completion_stream(self.pending_completions().last().unwrap());
127    }
128}
129
130impl LanguageModel for FakeLanguageModel {
131    fn id(&self) -> LanguageModelId {
132        language_model_id()
133    }
134
135    fn name(&self) -> LanguageModelName {
136        language_model_name()
137    }
138
139    fn provider_id(&self) -> LanguageModelProviderId {
140        provider_id()
141    }
142
143    fn provider_name(&self) -> LanguageModelProviderName {
144        provider_name()
145    }
146
147    fn supports_tools(&self) -> bool {
148        false
149    }
150
151    fn telemetry_id(&self) -> String {
152        "fake".to_string()
153    }
154
155    fn max_token_count(&self) -> usize {
156        1000000
157    }
158
159    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
160        futures::future::ready(Ok(0)).boxed()
161    }
162
163    fn stream_completion(
164        &self,
165        request: LanguageModelRequest,
166        _: &AsyncApp,
167    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
168        let (tx, rx) = mpsc::unbounded();
169        self.current_completion_txs.lock().push((request, tx));
170        async move {
171            Ok(rx
172                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
173                .boxed())
174        }
175        .boxed()
176    }
177
178    fn as_fake(&self) -> &Self {
179        self
180    }
181}