fake.rs

  1use crate::{
  2    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  3    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  4    LanguageModelRequest,
  5};
  6use anyhow::Context as _;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    future::BoxFuture,
 10    stream::BoxStream,
 11    FutureExt, StreamExt,
 12};
 13use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 14use http_client::Result;
 15use parking_lot::Mutex;
 16use std::sync::Arc;
 17use ui::WindowContext;
 18
 19pub fn language_model_id() -> LanguageModelId {
 20    LanguageModelId::from("fake".to_string())
 21}
 22
 23pub fn language_model_name() -> LanguageModelName {
 24    LanguageModelName::from("Fake".to_string())
 25}
 26
 27pub fn provider_id() -> LanguageModelProviderId {
 28    LanguageModelProviderId::from("fake".to_string())
 29}
 30
 31pub fn provider_name() -> LanguageModelProviderName {
 32    LanguageModelProviderName::from("Fake".to_string())
 33}
 34
 35#[derive(Clone, Default)]
 36pub struct FakeLanguageModelProvider;
 37
 38impl LanguageModelProviderState for FakeLanguageModelProvider {
 39    type ObservableEntity = ();
 40
 41    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 42        None
 43    }
 44}
 45
 46impl LanguageModelProvider for FakeLanguageModelProvider {
 47    fn id(&self) -> LanguageModelProviderId {
 48        provider_id()
 49    }
 50
 51    fn name(&self) -> LanguageModelProviderName {
 52        provider_name()
 53    }
 54
 55    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 56        vec![Arc::new(FakeLanguageModel::default())]
 57    }
 58
 59    fn is_authenticated(&self, _: &AppContext) -> bool {
 60        true
 61    }
 62
 63    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
 64        Task::ready(Ok(()))
 65    }
 66
 67    fn configuration_view(&self, _: &mut WindowContext) -> AnyView {
 68        unimplemented!()
 69    }
 70
 71    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
 72        Task::ready(Ok(()))
 73    }
 74}
 75
 76impl FakeLanguageModelProvider {
 77    pub fn test_model(&self) -> FakeLanguageModel {
 78        FakeLanguageModel::default()
 79    }
 80}
 81
 82#[derive(Debug, PartialEq)]
 83pub struct ToolUseRequest {
 84    pub request: LanguageModelRequest,
 85    pub name: String,
 86    pub description: String,
 87    pub schema: serde_json::Value,
 88}
 89
 90#[derive(Default)]
 91pub struct FakeLanguageModel {
 92    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 93    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
 94}
 95
 96impl FakeLanguageModel {
 97    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 98        self.current_completion_txs
 99            .lock()
100            .iter()
101            .map(|(request, _)| request.clone())
102            .collect()
103    }
104
105    pub fn completion_count(&self) -> usize {
106        self.current_completion_txs.lock().len()
107    }
108
109    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
110        let current_completion_txs = self.current_completion_txs.lock();
111        let tx = current_completion_txs
112            .iter()
113            .find(|(req, _)| req == request)
114            .map(|(_, tx)| tx)
115            .unwrap();
116        tx.unbounded_send(chunk).unwrap();
117    }
118
119    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
120        self.current_completion_txs
121            .lock()
122            .retain(|(req, _)| req != request);
123    }
124
125    pub fn stream_last_completion_response(&self, chunk: String) {
126        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
127    }
128
129    pub fn end_last_completion_stream(&self) {
130        self.end_completion_stream(self.pending_completions().last().unwrap());
131    }
132
133    pub fn respond_to_tool_use(
134        &self,
135        tool_call: &ToolUseRequest,
136        response: Result<serde_json::Value>,
137    ) {
138        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
139        if let Some(index) = current_tool_call_txs
140            .iter()
141            .position(|(call, _)| call == tool_call)
142        {
143            let (_, tx) = current_tool_call_txs.remove(index);
144            tx.send(response).unwrap();
145        }
146    }
147
148    pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
149        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
150        let (_, tx) = current_tool_call_txs.pop().unwrap();
151        tx.send(response).unwrap();
152    }
153}
154
155impl LanguageModel for FakeLanguageModel {
156    fn id(&self) -> LanguageModelId {
157        language_model_id()
158    }
159
160    fn name(&self) -> LanguageModelName {
161        language_model_name()
162    }
163
164    fn provider_id(&self) -> LanguageModelProviderId {
165        provider_id()
166    }
167
168    fn provider_name(&self) -> LanguageModelProviderName {
169        provider_name()
170    }
171
172    fn telemetry_id(&self) -> String {
173        "fake".to_string()
174    }
175
176    fn max_token_count(&self) -> usize {
177        1000000
178    }
179
180    fn count_tokens(
181        &self,
182        _: LanguageModelRequest,
183        _: &AppContext,
184    ) -> BoxFuture<'static, Result<usize>> {
185        futures::future::ready(Ok(0)).boxed()
186    }
187
188    fn stream_completion(
189        &self,
190        request: LanguageModelRequest,
191        _: &AsyncAppContext,
192    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
193        let (tx, rx) = mpsc::unbounded();
194        self.current_completion_txs.lock().push((request, tx));
195        async move { Ok(rx.map(Ok).boxed()) }.boxed()
196    }
197
198    fn use_any_tool(
199        &self,
200        request: LanguageModelRequest,
201        name: String,
202        description: String,
203        schema: serde_json::Value,
204        _cx: &AsyncAppContext,
205    ) -> BoxFuture<'static, Result<serde_json::Value>> {
206        let (tx, rx) = oneshot::channel();
207        let tool_call = ToolUseRequest {
208            request,
209            name,
210            description,
211            schema,
212        };
213        self.current_tool_use_txs.lock().push((tool_call, tx));
214        async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
215    }
216
217    fn as_fake(&self) -> &Self {
218        self
219    }
220}