1use crate::{
2 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
3 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
4 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
5 LanguageModelRequest, LanguageModelToolChoice,
6};
7use anyhow::anyhow;
8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
10use http_client::Result;
11use parking_lot::Mutex;
12use smol::stream::StreamExt;
13use std::sync::{
14 Arc,
15 atomic::{AtomicBool, Ordering::SeqCst},
16};
17
18#[derive(Clone)]
19pub struct FakeLanguageModelProvider {
20 id: LanguageModelProviderId,
21 name: LanguageModelProviderName,
22 models: Vec<Arc<dyn LanguageModel>>,
23}
24
25impl Default for FakeLanguageModelProvider {
26 fn default() -> Self {
27 Self {
28 id: LanguageModelProviderId::from("fake".to_string()),
29 name: LanguageModelProviderName::from("Fake".to_string()),
30 models: vec![Arc::new(FakeLanguageModel::default())],
31 }
32 }
33}
34
35impl LanguageModelProviderState for FakeLanguageModelProvider {
36 type ObservableEntity = ();
37
38 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
39 None
40 }
41}
42
43impl LanguageModelProvider for FakeLanguageModelProvider {
44 fn id(&self) -> LanguageModelProviderId {
45 self.id.clone()
46 }
47
48 fn name(&self) -> LanguageModelProviderName {
49 self.name.clone()
50 }
51
52 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
53 self.models.first().cloned()
54 }
55
56 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
57 self.models.first().cloned()
58 }
59
60 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
61 self.models.clone()
62 }
63
64 fn is_authenticated(&self, _: &App) -> bool {
65 true
66 }
67
68 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
69 Task::ready(Ok(()))
70 }
71
72 fn configuration_view(
73 &self,
74 _target_agent: ConfigurationViewTargetAgent,
75 _window: &mut Window,
76 _: &mut App,
77 ) -> AnyView {
78 unimplemented!()
79 }
80
81 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
82 Task::ready(Ok(()))
83 }
84}
85
86impl FakeLanguageModelProvider {
87 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
88 Self {
89 id,
90 name,
91 models: vec![Arc::new(FakeLanguageModel::default())],
92 }
93 }
94
95 pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
96 self.models = models;
97 self
98 }
99
100 pub fn test_model(&self) -> FakeLanguageModel {
101 FakeLanguageModel::default()
102 }
103}
104
105#[derive(Debug, PartialEq)]
106pub struct ToolUseRequest {
107 pub request: LanguageModelRequest,
108 pub name: String,
109 pub description: String,
110 pub schema: serde_json::Value,
111}
112
113pub struct FakeLanguageModel {
114 id: LanguageModelId,
115 name: LanguageModelName,
116 provider_id: LanguageModelProviderId,
117 provider_name: LanguageModelProviderName,
118 current_completion_txs: Mutex<
119 Vec<(
120 LanguageModelRequest,
121 mpsc::UnboundedSender<
122 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
123 >,
124 )>,
125 >,
126 forbid_requests: AtomicBool,
127 supports_thinking: AtomicBool,
128}
129
130impl Default for FakeLanguageModel {
131 fn default() -> Self {
132 Self {
133 id: LanguageModelId::from("fake".to_string()),
134 name: LanguageModelName::from("Fake".to_string()),
135 provider_id: LanguageModelProviderId::from("fake".to_string()),
136 provider_name: LanguageModelProviderName::from("Fake".to_string()),
137 current_completion_txs: Mutex::new(Vec::new()),
138 forbid_requests: AtomicBool::new(false),
139 supports_thinking: AtomicBool::new(false),
140 }
141 }
142}
143
144impl FakeLanguageModel {
145 pub fn with_id_and_thinking(
146 provider_id: &str,
147 id: &str,
148 name: &str,
149 supports_thinking: bool,
150 ) -> Self {
151 Self {
152 id: LanguageModelId::from(id.to_string()),
153 name: LanguageModelName::from(name.to_string()),
154 provider_id: LanguageModelProviderId::from(provider_id.to_string()),
155 supports_thinking: AtomicBool::new(supports_thinking),
156 ..Default::default()
157 }
158 }
159
160 pub fn allow_requests(&self) {
161 self.forbid_requests.store(false, SeqCst);
162 }
163
164 pub fn forbid_requests(&self) {
165 self.forbid_requests.store(true, SeqCst);
166 }
167
168 pub fn set_supports_thinking(&self, supports: bool) {
169 self.supports_thinking.store(supports, SeqCst);
170 }
171
172 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
173 self.current_completion_txs
174 .lock()
175 .iter()
176 .map(|(request, _)| request.clone())
177 .collect()
178 }
179
180 pub fn completion_count(&self) -> usize {
181 self.current_completion_txs.lock().len()
182 }
183
184 pub fn send_completion_stream_text_chunk(
185 &self,
186 request: &LanguageModelRequest,
187 chunk: impl Into<String>,
188 ) {
189 self.send_completion_stream_event(
190 request,
191 LanguageModelCompletionEvent::Text(chunk.into()),
192 );
193 }
194
195 pub fn send_completion_stream_event(
196 &self,
197 request: &LanguageModelRequest,
198 event: impl Into<LanguageModelCompletionEvent>,
199 ) {
200 let current_completion_txs = self.current_completion_txs.lock();
201 let tx = current_completion_txs
202 .iter()
203 .find(|(req, _)| req == request)
204 .map(|(_, tx)| tx)
205 .unwrap();
206 tx.unbounded_send(Ok(event.into())).unwrap();
207 }
208
209 pub fn send_completion_stream_error(
210 &self,
211 request: &LanguageModelRequest,
212 error: impl Into<LanguageModelCompletionError>,
213 ) {
214 let current_completion_txs = self.current_completion_txs.lock();
215 let tx = current_completion_txs
216 .iter()
217 .find(|(req, _)| req == request)
218 .map(|(_, tx)| tx)
219 .unwrap();
220 tx.unbounded_send(Err(error.into())).unwrap();
221 }
222
223 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
224 self.current_completion_txs
225 .lock()
226 .retain(|(req, _)| req != request);
227 }
228
229 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
230 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
231 }
232
233 pub fn send_last_completion_stream_event(
234 &self,
235 event: impl Into<LanguageModelCompletionEvent>,
236 ) {
237 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
238 }
239
240 pub fn send_last_completion_stream_error(
241 &self,
242 error: impl Into<LanguageModelCompletionError>,
243 ) {
244 self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
245 }
246
247 pub fn end_last_completion_stream(&self) {
248 self.end_completion_stream(self.pending_completions().last().unwrap());
249 }
250}
251
252impl LanguageModel for FakeLanguageModel {
253 fn id(&self) -> LanguageModelId {
254 self.id.clone()
255 }
256
257 fn name(&self) -> LanguageModelName {
258 self.name.clone()
259 }
260
261 fn provider_id(&self) -> LanguageModelProviderId {
262 self.provider_id.clone()
263 }
264
265 fn provider_name(&self) -> LanguageModelProviderName {
266 self.provider_name.clone()
267 }
268
269 fn supports_tools(&self) -> bool {
270 false
271 }
272
273 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
274 false
275 }
276
277 fn supports_images(&self) -> bool {
278 false
279 }
280
281 fn supports_thinking(&self) -> bool {
282 self.supports_thinking.load(SeqCst)
283 }
284
285 fn telemetry_id(&self) -> String {
286 "fake".to_string()
287 }
288
289 fn max_token_count(&self) -> u64 {
290 1000000
291 }
292
293 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
294 futures::future::ready(Ok(0)).boxed()
295 }
296
297 fn stream_completion(
298 &self,
299 request: LanguageModelRequest,
300 _: &AsyncApp,
301 ) -> BoxFuture<
302 'static,
303 Result<
304 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
305 LanguageModelCompletionError,
306 >,
307 > {
308 if self.forbid_requests.load(SeqCst) {
309 async move {
310 Err(LanguageModelCompletionError::Other(anyhow!(
311 "requests are forbidden"
312 )))
313 }
314 .boxed()
315 } else {
316 let (tx, rx) = mpsc::unbounded();
317 self.current_completion_txs.lock().push((request, tx));
318 async move { Ok(rx.boxed()) }.boxed()
319 }
320 }
321
322 fn as_fake(&self) -> &Self {
323 self
324 }
325}