use crate::{
    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, LanguageModelToolChoice,
};
use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering::SeqCst},
};

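/// A stub [`LanguageModelProvider`] for tests. It always reports itself as
/// authenticated and vends [`FakeLanguageModel`]s whose output is driven
/// manually by the test.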
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
}

impl Default for FakeLanguageModelProvider {
    fn default() -> Self {
        Self {
            id: LanguageModelProviderId::from("fake".to_string()),
            name: LanguageModelProviderName::from("Fake".to_string()),
        }
    }
}

impl LanguageModelProviderState for FakeLanguageModelProvider {
    type ObservableEntity = ();

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}

impl LanguageModelProvider for FakeLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelProviderName {
        self.name.clone()
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(Arc::new(FakeLanguageModel::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(Arc::new(FakeLanguageModel::default()))
    }

    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(FakeLanguageModel::default())]
    }

    fn is_authenticated(&self, _: &App) -> bool {
        true
    }

    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: ConfigurationViewTargetAgent,
        _window: &mut Window,
        _: &mut App,
    ) -> AnyView {
        unimplemented!()
    }

    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

impl FakeLanguageModelProvider {
    pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
        Self { id, name }
    }

    pub fn test_model(&self) -> FakeLanguageModel {
        FakeLanguageModel::default()
    }
}

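/// The arguments of a tool-use call, bundled together so tests can assert on
/// the originating request and the tool's name, description, and JSON schema.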
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    pub request: LanguageModelRequest,
    pub name: String,
    pub description: String,
    pub schema: serde_json::Value,
}

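/// A [`LanguageModel`] whose completions are driven manually from tests.
/// Each `stream_completion` call registers an unbounded channel; the test then
/// pushes events, errors, or end-of-stream through the `send_*` and `end_*`
/// helpers below.
///
/// A rough usage sketch (marked `ignore` because it assumes an async test
/// harness that provides an `AsyncApp`, and that `LanguageModelRequest`
/// implements `Default`; both are assumptions for illustration only):
///
/// ```ignore
/// let model = FakeLanguageModel::default();
/// // Opening a completion registers it synchronously, even before the
/// // returned future is awaited by the consumer.
/// let completion = model.stream_completion(LanguageModelRequest::default(), &async_cx);
/// assert_eq!(model.completion_count(), 1);
/// // Drive the stream from the test side.
/// model.send_last_completion_stream_text_chunk("Hello");
/// model.end_last_completion_stream();
/// // The consumer awaits `completion` and reads the events from the stream.
/// ```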
pub struct FakeLanguageModel {
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    current_completion_txs: Mutex<
        Vec<(
            LanguageModelRequest,
            mpsc::UnboundedSender<
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
        )>,
    >,
    forbid_requests: AtomicBool,
}

impl Default for FakeLanguageModel {
    fn default() -> Self {
        Self {
            provider_id: LanguageModelProviderId::from("fake".to_string()),
            provider_name: LanguageModelProviderName::from("Fake".to_string()),
            current_completion_txs: Mutex::new(Vec::new()),
            forbid_requests: AtomicBool::new(false),
        }
    }
}

impl FakeLanguageModel {
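    /// Allows `stream_completion` requests again after [`Self::forbid_requests`].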
    pub fn allow_requests(&self) {
        self.forbid_requests.store(false, SeqCst);
    }

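    /// Causes subsequent `stream_completion` calls to return an error instead
    /// of opening a stream.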
    pub fn forbid_requests(&self) {
        self.forbid_requests.store(true, SeqCst);
    }

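    /// Returns the request for every completion stream that is currently open.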
    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
        self.current_completion_txs
            .lock()
            .iter()
            .map(|(request, _)| request.clone())
            .collect()
    }

    pub fn completion_count(&self) -> usize {
        self.current_completion_txs.lock().len()
    }

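    /// Sends a text chunk on the stream that was opened for `request`.
    /// Panics if no such stream is open.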
    pub fn send_completion_stream_text_chunk(
        &self,
        request: &LanguageModelRequest,
        chunk: impl Into<String>,
    ) {
        self.send_completion_stream_event(
            request,
            LanguageModelCompletionEvent::Text(chunk.into()),
        );
    }

    pub fn send_completion_stream_event(
        &self,
        request: &LanguageModelRequest,
        event: impl Into<LanguageModelCompletionEvent>,
    ) {
        let current_completion_txs = self.current_completion_txs.lock();
        let tx = current_completion_txs
            .iter()
            .find(|(req, _)| req == request)
            .map(|(_, tx)| tx)
            .unwrap();
        tx.unbounded_send(Ok(event.into())).unwrap();
    }

    pub fn send_completion_stream_error(
        &self,
        request: &LanguageModelRequest,
        error: impl Into<LanguageModelCompletionError>,
    ) {
        let current_completion_txs = self.current_completion_txs.lock();
        let tx = current_completion_txs
            .iter()
            .find(|(req, _)| req == request)
            .map(|(_, tx)| tx)
            .unwrap();
        tx.unbounded_send(Err(error.into())).unwrap();
    }

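    /// Closes the stream that was opened for `request`, ending the completion.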
    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
        self.current_completion_txs
            .lock()
            .retain(|(req, _)| req != request);
    }

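    /// Like [`Self::send_completion_stream_text_chunk`], but targets the most
    /// recently opened completion stream. Panics if none is open.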
    pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
        self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
    }

    pub fn send_last_completion_stream_event(
        &self,
        event: impl Into<LanguageModelCompletionEvent>,
    ) {
        self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
    }

    pub fn send_last_completion_stream_error(
        &self,
        error: impl Into<LanguageModelCompletionError>,
    ) {
        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
    }

    pub fn end_last_completion_stream(&self) {
        self.end_completion_stream(self.pending_completions().last().unwrap());
    }
}

impl LanguageModel for FakeLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId::from("fake".to_string())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from("Fake".to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        "fake".to_string()
    }

    fn max_token_count(&self) -> u64 {
        1_000_000
    }

    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
        futures::future::ready(Ok(0)).boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        if self.forbid_requests.load(SeqCst) {
            async move {
                Err(LanguageModelCompletionError::Other(anyhow!(
                    "requests are forbidden"
                )))
            }
            .boxed()
        } else {
            let (tx, rx) = mpsc::unbounded();
            self.current_completion_txs.lock().push((request, tx));
            async move { Ok(rx.boxed()) }.boxed()
        }
    }

    fn as_fake(&self) -> &Self {
        self
    }
}