fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  3    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  4    LanguageModelProviderState, LanguageModelRequest,
  5};
  6use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  8use http_client::Result;
  9use parking_lot::Mutex;
 10use serde::Serialize;
 11use std::sync::Arc;
 12
 13pub fn language_model_id() -> LanguageModelId {
 14    LanguageModelId::from("fake".to_string())
 15}
 16
 17pub fn language_model_name() -> LanguageModelName {
 18    LanguageModelName::from("Fake".to_string())
 19}
 20
 21pub fn provider_id() -> LanguageModelProviderId {
 22    LanguageModelProviderId::from("fake".to_string())
 23}
 24
 25pub fn provider_name() -> LanguageModelProviderName {
 26    LanguageModelProviderName::from("Fake".to_string())
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct FakeLanguageModelProvider;
 31
 32impl LanguageModelProviderState for FakeLanguageModelProvider {
 33    type ObservableEntity = ();
 34
 35    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 36        None
 37    }
 38}
 39
 40impl LanguageModelProvider for FakeLanguageModelProvider {
 41    fn id(&self) -> LanguageModelProviderId {
 42        provider_id()
 43    }
 44
 45    fn name(&self) -> LanguageModelProviderName {
 46        provider_name()
 47    }
 48
 49    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 50        vec![Arc::new(FakeLanguageModel::default())]
 51    }
 52
 53    fn is_authenticated(&self, _: &App) -> bool {
 54        true
 55    }
 56
 57    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 58        Task::ready(Ok(()))
 59    }
 60
 61    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 62        unimplemented!()
 63    }
 64
 65    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 66        Task::ready(Ok(()))
 67    }
 68}
 69
 70impl FakeLanguageModelProvider {
 71    pub fn test_model(&self) -> FakeLanguageModel {
 72        FakeLanguageModel::default()
 73    }
 74}
 75
 76#[derive(Debug, PartialEq)]
 77pub struct ToolUseRequest {
 78    pub request: LanguageModelRequest,
 79    pub name: String,
 80    pub description: String,
 81    pub schema: serde_json::Value,
 82}
 83
 84#[derive(Default)]
 85pub struct FakeLanguageModel {
 86    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 87    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
 88}
 89
 90impl FakeLanguageModel {
 91    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 92        self.current_completion_txs
 93            .lock()
 94            .iter()
 95            .map(|(request, _)| request.clone())
 96            .collect()
 97    }
 98
 99    pub fn completion_count(&self) -> usize {
100        self.current_completion_txs.lock().len()
101    }
102
103    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
104        let current_completion_txs = self.current_completion_txs.lock();
105        let tx = current_completion_txs
106            .iter()
107            .find(|(req, _)| req == request)
108            .map(|(_, tx)| tx)
109            .unwrap();
110        tx.unbounded_send(chunk).unwrap();
111    }
112
113    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
114        self.current_completion_txs
115            .lock()
116            .retain(|(req, _)| req != request);
117    }
118
119    pub fn stream_last_completion_response(&self, chunk: String) {
120        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
121    }
122
123    pub fn end_last_completion_stream(&self) {
124        self.end_completion_stream(self.pending_completions().last().unwrap());
125    }
126
127    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
128        let response = serde_json::to_string(&response).unwrap();
129        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
130        let (_, tx) = current_tool_call_txs.pop().unwrap();
131        tx.unbounded_send(response).unwrap();
132    }
133}
134
135impl LanguageModel for FakeLanguageModel {
136    fn id(&self) -> LanguageModelId {
137        language_model_id()
138    }
139
140    fn name(&self) -> LanguageModelName {
141        language_model_name()
142    }
143
144    fn provider_id(&self) -> LanguageModelProviderId {
145        provider_id()
146    }
147
148    fn provider_name(&self) -> LanguageModelProviderName {
149        provider_name()
150    }
151
152    fn telemetry_id(&self) -> String {
153        "fake".to_string()
154    }
155
156    fn max_token_count(&self) -> usize {
157        1000000
158    }
159
160    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
161        futures::future::ready(Ok(0)).boxed()
162    }
163
164    fn stream_completion(
165        &self,
166        request: LanguageModelRequest,
167        _: &AsyncApp,
168    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
169        let (tx, rx) = mpsc::unbounded();
170        self.current_completion_txs.lock().push((request, tx));
171        async move {
172            Ok(rx
173                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
174                .boxed())
175        }
176        .boxed()
177    }
178
179    fn use_any_tool(
180        &self,
181        request: LanguageModelRequest,
182        name: String,
183        description: String,
184        schema: serde_json::Value,
185        _cx: &AsyncApp,
186    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
187        let (tx, rx) = mpsc::unbounded();
188        let tool_call = ToolUseRequest {
189            request,
190            name,
191            description,
192            schema,
193        };
194        self.current_tool_use_txs.lock().push((tool_call, tx));
195        async move { Ok(rx.map(Ok).boxed()) }.boxed()
196    }
197
198    fn as_fake(&self) -> &Self {
199        self
200    }
201}