1use crate::{
2 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
3 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
4 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
5 LanguageModelRequest, LanguageModelToolChoice,
6};
7use anyhow::anyhow;
8use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
9use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
10use http_client::Result;
11use parking_lot::Mutex;
12use std::sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering::SeqCst},
15};
16
/// A stub `LanguageModelProvider` for tests: serves a configurable set of
/// fake models and always reports itself as authenticated.
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    // Models returned by `provided_models`; the first entry doubles as both
    // the default and the default "fast" model.
    models: Vec<Arc<dyn LanguageModel>>,
}
23
24impl Default for FakeLanguageModelProvider {
25 fn default() -> Self {
26 Self {
27 id: LanguageModelProviderId::from("fake".to_string()),
28 name: LanguageModelProviderName::from("Fake".to_string()),
29 models: vec![Arc::new(FakeLanguageModel::default())],
30 }
31 }
32}
33
impl LanguageModelProviderState for FakeLanguageModelProvider {
    type ObservableEntity = ();

    // The fake provider has no observable state, so there is no entity for
    // subscribers to watch.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}
41
42impl LanguageModelProvider for FakeLanguageModelProvider {
43 fn id(&self) -> LanguageModelProviderId {
44 self.id.clone()
45 }
46
47 fn name(&self) -> LanguageModelProviderName {
48 self.name.clone()
49 }
50
51 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
52 self.models.first().cloned()
53 }
54
55 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
56 self.models.first().cloned()
57 }
58
59 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
60 self.models.clone()
61 }
62
63 fn is_authenticated(&self, _: &App) -> bool {
64 true
65 }
66
67 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
68 Task::ready(Ok(()))
69 }
70
71 fn configuration_view(
72 &self,
73 _target_agent: ConfigurationViewTargetAgent,
74 _window: &mut Window,
75 _: &mut App,
76 ) -> AnyView {
77 unimplemented!()
78 }
79
80 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
81 Task::ready(Ok(()))
82 }
83}
84
85impl FakeLanguageModelProvider {
86 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
87 Self {
88 id,
89 name,
90 models: vec![Arc::new(FakeLanguageModel::default())],
91 }
92 }
93
94 pub fn with_models(mut self, models: Vec<Arc<dyn LanguageModel>>) -> Self {
95 self.models = models;
96 self
97 }
98
99 pub fn test_model(&self) -> FakeLanguageModel {
100 FakeLanguageModel::default()
101 }
102}
103
/// Captures the arguments of a tool-use request so tests can assert on them.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    pub request: LanguageModelRequest,
    pub name: String,
    pub description: String,
    pub schema: serde_json::Value,
}
111
/// A scriptable `LanguageModel` for tests: completions never resolve on
/// their own. Each `stream_completion` call registers a channel sender here,
/// and the test drives the stream by sending events/errors through it.
pub struct FakeLanguageModel {
    id: LanguageModelId,
    name: LanguageModelName,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    // One (request, sender) pair per completion stream that has been started
    // via `stream_completion` and not yet ended by the test.
    current_completion_txs: Mutex<
        Vec<(
            LanguageModelRequest,
            mpsc::UnboundedSender<
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
        )>,
    >,
    // When set, `stream_completion` yields an immediate error instead of
    // registering a stream.
    forbid_requests: AtomicBool,
    supports_thinking: AtomicBool,
    supports_streaming_tools: AtomicBool,
}
129
130impl Default for FakeLanguageModel {
131 fn default() -> Self {
132 Self {
133 id: LanguageModelId::from("fake".to_string()),
134 name: LanguageModelName::from("Fake".to_string()),
135 provider_id: LanguageModelProviderId::from("fake".to_string()),
136 provider_name: LanguageModelProviderName::from("Fake".to_string()),
137 current_completion_txs: Mutex::new(Vec::new()),
138 forbid_requests: AtomicBool::new(false),
139 supports_thinking: AtomicBool::new(false),
140 supports_streaming_tools: AtomicBool::new(false),
141 }
142 }
143}
144
145impl FakeLanguageModel {
146 pub fn with_id_and_thinking(
147 provider_id: &str,
148 id: &str,
149 name: &str,
150 supports_thinking: bool,
151 ) -> Self {
152 Self {
153 id: LanguageModelId::from(id.to_string()),
154 name: LanguageModelName::from(name.to_string()),
155 provider_id: LanguageModelProviderId::from(provider_id.to_string()),
156 supports_thinking: AtomicBool::new(supports_thinking),
157 ..Default::default()
158 }
159 }
160
161 pub fn allow_requests(&self) {
162 self.forbid_requests.store(false, SeqCst);
163 }
164
165 pub fn forbid_requests(&self) {
166 self.forbid_requests.store(true, SeqCst);
167 }
168
169 pub fn set_supports_thinking(&self, supports: bool) {
170 self.supports_thinking.store(supports, SeqCst);
171 }
172
173 pub fn set_supports_streaming_tools(&self, supports: bool) {
174 self.supports_streaming_tools.store(supports, SeqCst);
175 }
176
177 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
178 self.current_completion_txs
179 .lock()
180 .iter()
181 .map(|(request, _)| request.clone())
182 .collect()
183 }
184
185 pub fn completion_count(&self) -> usize {
186 self.current_completion_txs.lock().len()
187 }
188
189 pub fn send_completion_stream_text_chunk(
190 &self,
191 request: &LanguageModelRequest,
192 chunk: impl Into<String>,
193 ) {
194 self.send_completion_stream_event(
195 request,
196 LanguageModelCompletionEvent::Text(chunk.into()),
197 );
198 }
199
200 pub fn send_completion_stream_event(
201 &self,
202 request: &LanguageModelRequest,
203 event: impl Into<LanguageModelCompletionEvent>,
204 ) {
205 let current_completion_txs = self.current_completion_txs.lock();
206 let tx = current_completion_txs
207 .iter()
208 .find(|(req, _)| req == request)
209 .map(|(_, tx)| tx)
210 .unwrap();
211 tx.unbounded_send(Ok(event.into())).unwrap();
212 }
213
214 pub fn send_completion_stream_error(
215 &self,
216 request: &LanguageModelRequest,
217 error: impl Into<LanguageModelCompletionError>,
218 ) {
219 let current_completion_txs = self.current_completion_txs.lock();
220 let tx = current_completion_txs
221 .iter()
222 .find(|(req, _)| req == request)
223 .map(|(_, tx)| tx)
224 .unwrap();
225 tx.unbounded_send(Err(error.into())).unwrap();
226 }
227
228 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
229 self.current_completion_txs
230 .lock()
231 .retain(|(req, _)| req != request);
232 }
233
234 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
235 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
236 }
237
238 pub fn send_last_completion_stream_event(
239 &self,
240 event: impl Into<LanguageModelCompletionEvent>,
241 ) {
242 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
243 }
244
245 pub fn send_last_completion_stream_error(
246 &self,
247 error: impl Into<LanguageModelCompletionError>,
248 ) {
249 self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
250 }
251
252 pub fn end_last_completion_stream(&self) {
253 self.end_completion_stream(self.pending_completions().last().unwrap());
254 }
255}
256
257impl LanguageModel for FakeLanguageModel {
258 fn id(&self) -> LanguageModelId {
259 self.id.clone()
260 }
261
262 fn name(&self) -> LanguageModelName {
263 self.name.clone()
264 }
265
266 fn provider_id(&self) -> LanguageModelProviderId {
267 self.provider_id.clone()
268 }
269
270 fn provider_name(&self) -> LanguageModelProviderName {
271 self.provider_name.clone()
272 }
273
274 fn supports_tools(&self) -> bool {
275 false
276 }
277
278 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
279 false
280 }
281
282 fn supports_images(&self) -> bool {
283 false
284 }
285
286 fn supports_thinking(&self) -> bool {
287 self.supports_thinking.load(SeqCst)
288 }
289
290 fn supports_streaming_tools(&self) -> bool {
291 self.supports_streaming_tools.load(SeqCst)
292 }
293
294 fn telemetry_id(&self) -> String {
295 "fake".to_string()
296 }
297
298 fn max_token_count(&self) -> u64 {
299 1000000
300 }
301
302 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
303 futures::future::ready(Ok(0)).boxed()
304 }
305
306 fn stream_completion(
307 &self,
308 request: LanguageModelRequest,
309 _: &AsyncApp,
310 ) -> BoxFuture<
311 'static,
312 Result<
313 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
314 LanguageModelCompletionError,
315 >,
316 > {
317 if self.forbid_requests.load(SeqCst) {
318 async move {
319 Err(LanguageModelCompletionError::Other(anyhow!(
320 "requests are forbidden"
321 )))
322 }
323 .boxed()
324 } else {
325 let (tx, rx) = mpsc::unbounded();
326 self.current_completion_txs.lock().push((request, tx));
327 async move { Ok(rx.boxed()) }.boxed()
328 }
329 }
330
331 fn as_fake(&self) -> &Self {
332 self
333 }
334}