1use crate::{
2 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
3 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
4 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
5 LanguageModelRequest, LanguageModelToolChoice,
6};
7use anyhow::anyhow;
8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
10use http_client::Result;
11use parking_lot::Mutex;
12use smol::stream::StreamExt;
13use std::sync::{
14 Arc,
15 atomic::{AtomicBool, Ordering::SeqCst},
16};
17
/// A test-only [`LanguageModelProvider`] that serves a configurable set of
/// fake models without contacting any real backend.
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
    // Provider identity reported through `LanguageModelProvider::id`/`name`.
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    // Models returned by `provided_models`; the first entry doubles as both
    // the default and the default "fast" model.
    models: Vec<Arc<dyn LanguageModel>>,
}
24
25impl Default for FakeLanguageModelProvider {
26 fn default() -> Self {
27 Self {
28 id: LanguageModelProviderId::from("fake".to_string()),
29 name: LanguageModelProviderName::from("Fake".to_string()),
30 models: vec![Arc::new(FakeLanguageModel::default())],
31 }
32 }
33}
34
impl LanguageModelProviderState for FakeLanguageModelProvider {
    // The fake provider has no observable state, so the associated entity
    // type is unit and no entity is ever exposed.
    type ObservableEntity = ();

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}
42
43impl LanguageModelProvider for FakeLanguageModelProvider {
44 fn id(&self) -> LanguageModelProviderId {
45 self.id.clone()
46 }
47
48 fn name(&self) -> LanguageModelProviderName {
49 self.name.clone()
50 }
51
52 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
53 self.models.first().cloned()
54 }
55
56 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
57 self.models.first().cloned()
58 }
59
60 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
61 self.models.clone()
62 }
63
64 fn is_authenticated(&self, _: &App) -> bool {
65 true
66 }
67
68 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
69 Task::ready(Ok(()))
70 }
71
72 fn configuration_view(
73 &self,
74 _target_agent: ConfigurationViewTargetAgent,
75 _window: &mut Window,
76 _: &mut App,
77 ) -> AnyView {
78 unimplemented!()
79 }
80
81 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
82 Task::ready(Ok(()))
83 }
84}
85
86impl FakeLanguageModelProvider {
87 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
88 Self {
89 id,
90 name,
91 models: vec![Arc::new(FakeLanguageModel::default())],
92 }
93 }
94
95 pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
96 self.models = models;
97 self
98 }
99
100 pub fn test_model(&self) -> FakeLanguageModel {
101 FakeLanguageModel::default()
102 }
103}
104
/// Captures the arguments of a tool-use request so tests can assert on
/// exactly what was sent to the model.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    // The originating completion request.
    pub request: LanguageModelRequest,
    // Tool name, human-readable description, and the JSON schema describing
    // the tool's input, as supplied by the caller.
    pub name: String,
    pub description: String,
    pub schema: serde_json::Value,
}
112
/// A test double implementing [`LanguageModel`]. Completions never run
/// anywhere: `stream_completion` hands back the receiving half of a channel,
/// and tests drive the stream manually through the `send_*`/`end_*` helpers.
pub struct FakeLanguageModel {
    id: LanguageModelId,
    name: LanguageModelName,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    // One (request, sender) pair per completion stream that has been started
    // but not yet ended; each sender feeds the stream handed to the caller.
    current_completion_txs: Mutex<
        Vec<(
            LanguageModelRequest,
            mpsc::UnboundedSender<
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
        )>,
    >,
    // When set, `stream_completion` fails immediately instead of queueing.
    forbid_requests: AtomicBool,
    // Capability flags surfaced through the `LanguageModel` trait; toggled
    // per test via the setters on the inherent impl.
    supports_thinking: AtomicBool,
    supports_streaming_tools: AtomicBool,
}
130
131impl Default for FakeLanguageModel {
132 fn default() -> Self {
133 Self {
134 id: LanguageModelId::from("fake".to_string()),
135 name: LanguageModelName::from("Fake".to_string()),
136 provider_id: LanguageModelProviderId::from("fake".to_string()),
137 provider_name: LanguageModelProviderName::from("Fake".to_string()),
138 current_completion_txs: Mutex::new(Vec::new()),
139 forbid_requests: AtomicBool::new(false),
140 supports_thinking: AtomicBool::new(false),
141 supports_streaming_tools: AtomicBool::new(false),
142 }
143 }
144}
145
146impl FakeLanguageModel {
147 pub fn with_id_and_thinking(
148 provider_id: &str,
149 id: &str,
150 name: &str,
151 supports_thinking: bool,
152 ) -> Self {
153 Self {
154 id: LanguageModelId::from(id.to_string()),
155 name: LanguageModelName::from(name.to_string()),
156 provider_id: LanguageModelProviderId::from(provider_id.to_string()),
157 supports_thinking: AtomicBool::new(supports_thinking),
158 ..Default::default()
159 }
160 }
161
162 pub fn allow_requests(&self) {
163 self.forbid_requests.store(false, SeqCst);
164 }
165
166 pub fn forbid_requests(&self) {
167 self.forbid_requests.store(true, SeqCst);
168 }
169
170 pub fn set_supports_thinking(&self, supports: bool) {
171 self.supports_thinking.store(supports, SeqCst);
172 }
173
174 pub fn set_supports_streaming_tools(&self, supports: bool) {
175 self.supports_streaming_tools.store(supports, SeqCst);
176 }
177
178 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
179 self.current_completion_txs
180 .lock()
181 .iter()
182 .map(|(request, _)| request.clone())
183 .collect()
184 }
185
186 pub fn completion_count(&self) -> usize {
187 self.current_completion_txs.lock().len()
188 }
189
190 pub fn send_completion_stream_text_chunk(
191 &self,
192 request: &LanguageModelRequest,
193 chunk: impl Into<String>,
194 ) {
195 self.send_completion_stream_event(
196 request,
197 LanguageModelCompletionEvent::Text(chunk.into()),
198 );
199 }
200
201 pub fn send_completion_stream_event(
202 &self,
203 request: &LanguageModelRequest,
204 event: impl Into<LanguageModelCompletionEvent>,
205 ) {
206 let current_completion_txs = self.current_completion_txs.lock();
207 let tx = current_completion_txs
208 .iter()
209 .find(|(req, _)| req == request)
210 .map(|(_, tx)| tx)
211 .unwrap();
212 tx.unbounded_send(Ok(event.into())).unwrap();
213 }
214
215 pub fn send_completion_stream_error(
216 &self,
217 request: &LanguageModelRequest,
218 error: impl Into<LanguageModelCompletionError>,
219 ) {
220 let current_completion_txs = self.current_completion_txs.lock();
221 let tx = current_completion_txs
222 .iter()
223 .find(|(req, _)| req == request)
224 .map(|(_, tx)| tx)
225 .unwrap();
226 tx.unbounded_send(Err(error.into())).unwrap();
227 }
228
229 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
230 self.current_completion_txs
231 .lock()
232 .retain(|(req, _)| req != request);
233 }
234
235 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
236 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
237 }
238
239 pub fn send_last_completion_stream_event(
240 &self,
241 event: impl Into<LanguageModelCompletionEvent>,
242 ) {
243 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
244 }
245
246 pub fn send_last_completion_stream_error(
247 &self,
248 error: impl Into<LanguageModelCompletionError>,
249 ) {
250 self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
251 }
252
253 pub fn end_last_completion_stream(&self) {
254 self.end_completion_stream(self.pending_completions().last().unwrap());
255 }
256}
257
258impl LanguageModel for FakeLanguageModel {
259 fn id(&self) -> LanguageModelId {
260 self.id.clone()
261 }
262
263 fn name(&self) -> LanguageModelName {
264 self.name.clone()
265 }
266
267 fn provider_id(&self) -> LanguageModelProviderId {
268 self.provider_id.clone()
269 }
270
271 fn provider_name(&self) -> LanguageModelProviderName {
272 self.provider_name.clone()
273 }
274
275 fn supports_tools(&self) -> bool {
276 false
277 }
278
279 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
280 false
281 }
282
283 fn supports_images(&self) -> bool {
284 false
285 }
286
287 fn supports_thinking(&self) -> bool {
288 self.supports_thinking.load(SeqCst)
289 }
290
291 fn supports_streaming_tools(&self) -> bool {
292 self.supports_streaming_tools.load(SeqCst)
293 }
294
295 fn telemetry_id(&self) -> String {
296 "fake".to_string()
297 }
298
299 fn max_token_count(&self) -> u64 {
300 1000000
301 }
302
303 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
304 futures::future::ready(Ok(0)).boxed()
305 }
306
307 fn stream_completion(
308 &self,
309 request: LanguageModelRequest,
310 _: &AsyncApp,
311 ) -> BoxFuture<
312 'static,
313 Result<
314 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
315 LanguageModelCompletionError,
316 >,
317 > {
318 if self.forbid_requests.load(SeqCst) {
319 async move {
320 Err(LanguageModelCompletionError::Other(anyhow!(
321 "requests are forbidden"
322 )))
323 }
324 .boxed()
325 } else {
326 let (tx, rx) = mpsc::unbounded();
327 self.current_completion_txs.lock().push((request, tx));
328 async move { Ok(rx.boxed()) }.boxed()
329 }
330 }
331
332 fn as_fake(&self) -> &Self {
333 self
334 }
335}