1use crate::{
2 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
3 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
4 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
5 LanguageModelRequest, LanguageModelToolChoice,
6};
7use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
9use http_client::Result;
10use parking_lot::Mutex;
11use smol::stream::StreamExt;
12use std::sync::Arc;
13
14#[derive(Clone)]
15pub struct FakeLanguageModelProvider {
16 id: LanguageModelProviderId,
17 name: LanguageModelProviderName,
18}
19
20impl Default for FakeLanguageModelProvider {
21 fn default() -> Self {
22 Self {
23 id: LanguageModelProviderId::from("fake".to_string()),
24 name: LanguageModelProviderName::from("Fake".to_string()),
25 }
26 }
27}
28
29impl LanguageModelProviderState for FakeLanguageModelProvider {
30 type ObservableEntity = ();
31
32 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
33 None
34 }
35}
36
37impl LanguageModelProvider for FakeLanguageModelProvider {
38 fn id(&self) -> LanguageModelProviderId {
39 self.id.clone()
40 }
41
42 fn name(&self) -> LanguageModelProviderName {
43 self.name.clone()
44 }
45
46 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
47 Some(Arc::new(FakeLanguageModel::default()))
48 }
49
50 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
51 Some(Arc::new(FakeLanguageModel::default()))
52 }
53
54 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
55 vec![Arc::new(FakeLanguageModel::default())]
56 }
57
58 fn is_authenticated(&self, _: &App) -> bool {
59 true
60 }
61
62 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
63 Task::ready(Ok(()))
64 }
65
66 fn configuration_view(
67 &self,
68 _target_agent: ConfigurationViewTargetAgent,
69 _window: &mut Window,
70 _: &mut App,
71 ) -> AnyView {
72 unimplemented!()
73 }
74
75 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
76 Task::ready(Ok(()))
77 }
78}
79
80impl FakeLanguageModelProvider {
81 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
82 Self { id, name }
83 }
84
85 pub fn test_model(&self) -> FakeLanguageModel {
86 FakeLanguageModel::default()
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub struct ToolUseRequest {
92 pub request: LanguageModelRequest,
93 pub name: String,
94 pub description: String,
95 pub schema: serde_json::Value,
96}
97
98pub struct FakeLanguageModel {
99 provider_id: LanguageModelProviderId,
100 provider_name: LanguageModelProviderName,
101 current_completion_txs: Mutex<
102 Vec<(
103 LanguageModelRequest,
104 mpsc::UnboundedSender<
105 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
106 >,
107 )>,
108 >,
109}
110
111impl Default for FakeLanguageModel {
112 fn default() -> Self {
113 Self {
114 provider_id: LanguageModelProviderId::from("fake".to_string()),
115 provider_name: LanguageModelProviderName::from("Fake".to_string()),
116 current_completion_txs: Mutex::new(Vec::new()),
117 }
118 }
119}
120
121impl FakeLanguageModel {
122 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
123 self.current_completion_txs
124 .lock()
125 .iter()
126 .map(|(request, _)| request.clone())
127 .collect()
128 }
129
130 pub fn completion_count(&self) -> usize {
131 self.current_completion_txs.lock().len()
132 }
133
134 pub fn send_completion_stream_text_chunk(
135 &self,
136 request: &LanguageModelRequest,
137 chunk: impl Into<String>,
138 ) {
139 self.send_completion_stream_event(
140 request,
141 LanguageModelCompletionEvent::Text(chunk.into()),
142 );
143 }
144
145 pub fn send_completion_stream_event(
146 &self,
147 request: &LanguageModelRequest,
148 event: impl Into<LanguageModelCompletionEvent>,
149 ) {
150 let current_completion_txs = self.current_completion_txs.lock();
151 let tx = current_completion_txs
152 .iter()
153 .find(|(req, _)| req == request)
154 .map(|(_, tx)| tx)
155 .unwrap();
156 tx.unbounded_send(Ok(event.into())).unwrap();
157 }
158
159 pub fn send_completion_stream_error(
160 &self,
161 request: &LanguageModelRequest,
162 error: impl Into<LanguageModelCompletionError>,
163 ) {
164 let current_completion_txs = self.current_completion_txs.lock();
165 let tx = current_completion_txs
166 .iter()
167 .find(|(req, _)| req == request)
168 .map(|(_, tx)| tx)
169 .unwrap();
170 tx.unbounded_send(Err(error.into())).unwrap();
171 }
172
173 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
174 self.current_completion_txs
175 .lock()
176 .retain(|(req, _)| req != request);
177 }
178
179 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
180 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
181 }
182
183 pub fn send_last_completion_stream_event(
184 &self,
185 event: impl Into<LanguageModelCompletionEvent>,
186 ) {
187 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
188 }
189
190 pub fn send_last_completion_stream_error(
191 &self,
192 error: impl Into<LanguageModelCompletionError>,
193 ) {
194 self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
195 }
196
197 pub fn end_last_completion_stream(&self) {
198 self.end_completion_stream(self.pending_completions().last().unwrap());
199 }
200}
201
202impl LanguageModel for FakeLanguageModel {
203 fn id(&self) -> LanguageModelId {
204 LanguageModelId::from("fake".to_string())
205 }
206
207 fn name(&self) -> LanguageModelName {
208 LanguageModelName::from("Fake".to_string())
209 }
210
211 fn provider_id(&self) -> LanguageModelProviderId {
212 self.provider_id.clone()
213 }
214
215 fn provider_name(&self) -> LanguageModelProviderName {
216 self.provider_name.clone()
217 }
218
219 fn supports_tools(&self) -> bool {
220 false
221 }
222
223 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
224 false
225 }
226
227 fn supports_images(&self) -> bool {
228 false
229 }
230
231 fn telemetry_id(&self) -> String {
232 "fake".to_string()
233 }
234
235 fn max_token_count(&self) -> u64 {
236 1000000
237 }
238
239 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
240 futures::future::ready(Ok(0)).boxed()
241 }
242
243 fn stream_completion(
244 &self,
245 request: LanguageModelRequest,
246 _: &AsyncApp,
247 ) -> BoxFuture<
248 'static,
249 Result<
250 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
251 LanguageModelCompletionError,
252 >,
253 > {
254 let (tx, rx) = mpsc::unbounded();
255 self.current_completion_txs.lock().push((request, tx));
256 async move { Ok(rx.boxed()) }.boxed()
257 }
258
259 fn as_fake(&self) -> &Self {
260 self
261 }
262}