fake_provider.rs

  1use crate::{
  2    AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  3    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  4    LanguageModelProviderState, LanguageModelRequest,
  5};
  6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
  8use http_client::Result;
  9use parking_lot::Mutex;
 10use serde::Serialize;
 11use std::sync::Arc;
 12
 13pub fn language_model_id() -> LanguageModelId {
 14    LanguageModelId::from("fake".to_string())
 15}
 16
 17pub fn language_model_name() -> LanguageModelName {
 18    LanguageModelName::from("Fake".to_string())
 19}
 20
 21pub fn provider_id() -> LanguageModelProviderId {
 22    LanguageModelProviderId::from("fake".to_string())
 23}
 24
 25pub fn provider_name() -> LanguageModelProviderName {
 26    LanguageModelProviderName::from("Fake".to_string())
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct FakeLanguageModelProvider;
 31
 32impl LanguageModelProviderState for FakeLanguageModelProvider {
 33    type ObservableEntity = ();
 34
 35    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 36        None
 37    }
 38}
 39
 40impl LanguageModelProvider for FakeLanguageModelProvider {
 41    fn id(&self) -> LanguageModelProviderId {
 42        provider_id()
 43    }
 44
 45    fn name(&self) -> LanguageModelProviderName {
 46        provider_name()
 47    }
 48
 49    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 50        Some(Arc::new(FakeLanguageModel::default()))
 51    }
 52
 53    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 54        vec![Arc::new(FakeLanguageModel::default())]
 55    }
 56
 57    fn is_authenticated(&self, _: &App) -> bool {
 58        true
 59    }
 60
 61    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
 62        Task::ready(Ok(()))
 63    }
 64
 65    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
 66        unimplemented!()
 67    }
 68
 69    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
 70        Task::ready(Ok(()))
 71    }
 72}
 73
 74impl FakeLanguageModelProvider {
 75    pub fn test_model(&self) -> FakeLanguageModel {
 76        FakeLanguageModel::default()
 77    }
 78}
 79
 80#[derive(Debug, PartialEq)]
 81pub struct ToolUseRequest {
 82    pub request: LanguageModelRequest,
 83    pub name: String,
 84    pub description: String,
 85    pub schema: serde_json::Value,
 86}
 87
 88#[derive(Default)]
 89pub struct FakeLanguageModel {
 90    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 91    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
 92}
 93
 94impl FakeLanguageModel {
 95    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 96        self.current_completion_txs
 97            .lock()
 98            .iter()
 99            .map(|(request, _)| request.clone())
100            .collect()
101    }
102
103    pub fn completion_count(&self) -> usize {
104        self.current_completion_txs.lock().len()
105    }
106
107    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
108        let current_completion_txs = self.current_completion_txs.lock();
109        let tx = current_completion_txs
110            .iter()
111            .find(|(req, _)| req == request)
112            .map(|(_, tx)| tx)
113            .unwrap();
114        tx.unbounded_send(chunk).unwrap();
115    }
116
117    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
118        self.current_completion_txs
119            .lock()
120            .retain(|(req, _)| req != request);
121    }
122
123    pub fn stream_last_completion_response(&self, chunk: String) {
124        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
125    }
126
127    pub fn end_last_completion_stream(&self) {
128        self.end_completion_stream(self.pending_completions().last().unwrap());
129    }
130
131    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
132        let response = serde_json::to_string(&response).unwrap();
133        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
134        let (_, tx) = current_tool_call_txs.pop().unwrap();
135        tx.unbounded_send(response).unwrap();
136    }
137}
138
139impl LanguageModel for FakeLanguageModel {
140    fn id(&self) -> LanguageModelId {
141        language_model_id()
142    }
143
144    fn name(&self) -> LanguageModelName {
145        language_model_name()
146    }
147
148    fn provider_id(&self) -> LanguageModelProviderId {
149        provider_id()
150    }
151
152    fn provider_name(&self) -> LanguageModelProviderName {
153        provider_name()
154    }
155
156    fn telemetry_id(&self) -> String {
157        "fake".to_string()
158    }
159
160    fn max_token_count(&self) -> usize {
161        1000000
162    }
163
164    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
165        futures::future::ready(Ok(0)).boxed()
166    }
167
168    fn stream_completion(
169        &self,
170        request: LanguageModelRequest,
171        _: &AsyncApp,
172    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
173        let (tx, rx) = mpsc::unbounded();
174        self.current_completion_txs.lock().push((request, tx));
175        async move {
176            Ok(rx
177                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
178                .boxed())
179        }
180        .boxed()
181    }
182
183    fn use_any_tool(
184        &self,
185        request: LanguageModelRequest,
186        name: String,
187        description: String,
188        schema: serde_json::Value,
189        _cx: &AsyncApp,
190    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
191        let (tx, rx) = mpsc::unbounded();
192        let tool_call = ToolUseRequest {
193            request,
194            name,
195            description,
196            schema,
197        };
198        self.current_tool_use_txs.lock().push((tool_call, tx));
199        async move { Ok(rx.map(Ok).boxed()) }.boxed()
200    }
201
202    fn as_fake(&self) -> &Self {
203        self
204    }
205}