fake_provider.rs

  1use crate::{
  2    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  3    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  4    LanguageModelProviderState, LanguageModelRequest,
  5};
  6use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, AppContext, AsyncAppContext, Task};
  8use http_client::Result;
  9use parking_lot::Mutex;
 10use serde::Serialize;
 11use std::sync::Arc;
 12use ui::WindowContext;
 13
 14pub fn language_model_id() -> LanguageModelId {
 15    LanguageModelId::from("fake".to_string())
 16}
 17
 18pub fn language_model_name() -> LanguageModelName {
 19    LanguageModelName::from("Fake".to_string())
 20}
 21
 22pub fn provider_id() -> LanguageModelProviderId {
 23    LanguageModelProviderId::from("fake".to_string())
 24}
 25
 26pub fn provider_name() -> LanguageModelProviderName {
 27    LanguageModelProviderName::from("Fake".to_string())
 28}
 29
 30#[derive(Clone, Default)]
 31pub struct FakeLanguageModelProvider;
 32
 33impl LanguageModelProviderState for FakeLanguageModelProvider {
 34    type ObservableEntity = ();
 35
 36    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 37        None
 38    }
 39}
 40
 41impl LanguageModelProvider for FakeLanguageModelProvider {
 42    fn id(&self) -> LanguageModelProviderId {
 43        provider_id()
 44    }
 45
 46    fn name(&self) -> LanguageModelProviderName {
 47        provider_name()
 48    }
 49
 50    fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 51        vec![Arc::new(FakeLanguageModel::default())]
 52    }
 53
 54    fn is_authenticated(&self, _: &AppContext) -> bool {
 55        true
 56    }
 57
 58    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
 59        Task::ready(Ok(()))
 60    }
 61
 62    fn configuration_view(&self, _: &mut WindowContext) -> AnyView {
 63        unimplemented!()
 64    }
 65
 66    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
 67        Task::ready(Ok(()))
 68    }
 69}
 70
 71impl FakeLanguageModelProvider {
 72    pub fn test_model(&self) -> FakeLanguageModel {
 73        FakeLanguageModel::default()
 74    }
 75}
 76
 77#[derive(Debug, PartialEq)]
 78pub struct ToolUseRequest {
 79    pub request: LanguageModelRequest,
 80    pub name: String,
 81    pub description: String,
 82    pub schema: serde_json::Value,
 83}
 84
 85#[derive(Default)]
 86pub struct FakeLanguageModel {
 87    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
 88    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
 89}
 90
 91impl FakeLanguageModel {
 92    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
 93        self.current_completion_txs
 94            .lock()
 95            .iter()
 96            .map(|(request, _)| request.clone())
 97            .collect()
 98    }
 99
100    pub fn completion_count(&self) -> usize {
101        self.current_completion_txs.lock().len()
102    }
103
104    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
105        let current_completion_txs = self.current_completion_txs.lock();
106        let tx = current_completion_txs
107            .iter()
108            .find(|(req, _)| req == request)
109            .map(|(_, tx)| tx)
110            .unwrap();
111        tx.unbounded_send(chunk).unwrap();
112    }
113
114    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
115        self.current_completion_txs
116            .lock()
117            .retain(|(req, _)| req != request);
118    }
119
120    pub fn stream_last_completion_response(&self, chunk: String) {
121        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
122    }
123
124    pub fn end_last_completion_stream(&self) {
125        self.end_completion_stream(self.pending_completions().last().unwrap());
126    }
127
128    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
129        let response = serde_json::to_string(&response).unwrap();
130        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
131        let (_, tx) = current_tool_call_txs.pop().unwrap();
132        tx.unbounded_send(response).unwrap();
133    }
134}
135
136impl LanguageModel for FakeLanguageModel {
137    fn id(&self) -> LanguageModelId {
138        language_model_id()
139    }
140
141    fn name(&self) -> LanguageModelName {
142        language_model_name()
143    }
144
145    fn provider_id(&self) -> LanguageModelProviderId {
146        provider_id()
147    }
148
149    fn provider_name(&self) -> LanguageModelProviderName {
150        provider_name()
151    }
152
153    fn telemetry_id(&self) -> String {
154        "fake".to_string()
155    }
156
157    fn max_token_count(&self) -> usize {
158        1000000
159    }
160
161    fn count_tokens(
162        &self,
163        _: LanguageModelRequest,
164        _: &AppContext,
165    ) -> BoxFuture<'static, Result<usize>> {
166        futures::future::ready(Ok(0)).boxed()
167    }
168
169    fn stream_completion(
170        &self,
171        request: LanguageModelRequest,
172        _: &AsyncAppContext,
173    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
174        let (tx, rx) = mpsc::unbounded();
175        self.current_completion_txs.lock().push((request, tx));
176        async move {
177            Ok(rx
178                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
179                .boxed())
180        }
181        .boxed()
182    }
183
184    fn use_any_tool(
185        &self,
186        request: LanguageModelRequest,
187        name: String,
188        description: String,
189        schema: serde_json::Value,
190        _cx: &AsyncAppContext,
191    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192        let (tx, rx) = mpsc::unbounded();
193        let tool_call = ToolUseRequest {
194            request,
195            name,
196            description,
197            schema,
198        };
199        self.current_tool_use_txs.lock().push((tool_call, tx));
200        async move { Ok(rx.map(Ok).boxed()) }.boxed()
201    }
202
203    fn as_fake(&self) -> &Self {
204        self
205    }
206}