1use crate::{
2 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
3 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
4 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
5 LanguageModelToolChoice,
6};
7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
9use http_client::Result;
10use parking_lot::Mutex;
11use std::sync::Arc;
12
13#[derive(Clone)]
14pub struct FakeLanguageModelProvider {
15 id: LanguageModelProviderId,
16 name: LanguageModelProviderName,
17}
18
19impl Default for FakeLanguageModelProvider {
20 fn default() -> Self {
21 Self {
22 id: LanguageModelProviderId::from("fake".to_string()),
23 name: LanguageModelProviderName::from("Fake".to_string()),
24 }
25 }
26}
27
28impl LanguageModelProviderState for FakeLanguageModelProvider {
29 type ObservableEntity = ();
30
31 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
32 None
33 }
34}
35
36impl LanguageModelProvider for FakeLanguageModelProvider {
37 fn id(&self) -> LanguageModelProviderId {
38 self.id.clone()
39 }
40
41 fn name(&self) -> LanguageModelProviderName {
42 self.name.clone()
43 }
44
45 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
46 Some(Arc::new(FakeLanguageModel::default()))
47 }
48
49 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
50 Some(Arc::new(FakeLanguageModel::default()))
51 }
52
53 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
54 vec![Arc::new(FakeLanguageModel::default())]
55 }
56
57 fn is_authenticated(&self, _: &App) -> bool {
58 true
59 }
60
61 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
62 Task::ready(Ok(()))
63 }
64
65 fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
66 unimplemented!()
67 }
68
69 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
70 Task::ready(Ok(()))
71 }
72}
73
74impl FakeLanguageModelProvider {
75 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
76 Self { id, name }
77 }
78
79 pub fn test_model(&self) -> FakeLanguageModel {
80 FakeLanguageModel::default()
81 }
82}
83
84#[derive(Debug, PartialEq)]
85pub struct ToolUseRequest {
86 pub request: LanguageModelRequest,
87 pub name: String,
88 pub description: String,
89 pub schema: serde_json::Value,
90}
91
92pub struct FakeLanguageModel {
93 provider_id: LanguageModelProviderId,
94 provider_name: LanguageModelProviderName,
95 current_completion_txs: Mutex<
96 Vec<(
97 LanguageModelRequest,
98 mpsc::UnboundedSender<LanguageModelCompletionEvent>,
99 )>,
100 >,
101}
102
103impl Default for FakeLanguageModel {
104 fn default() -> Self {
105 dbg!("default......");
106 eprintln!("{}", std::backtrace::Backtrace::force_capture());
107 Self {
108 provider_id: LanguageModelProviderId::from("fake".to_string()),
109 provider_name: LanguageModelProviderName::from("Fake".to_string()),
110 current_completion_txs: Mutex::new(Vec::new()),
111 }
112 }
113}
114
115impl FakeLanguageModel {
116 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
117 self.current_completion_txs
118 .lock()
119 .iter()
120 .map(|(request, _)| request.clone())
121 .collect()
122 }
123
124 pub fn completion_count(&self) -> usize {
125 self.current_completion_txs.lock().len()
126 }
127
128 pub fn send_completion_stream_text_chunk(
129 &self,
130 request: &LanguageModelRequest,
131 chunk: impl Into<String>,
132 ) {
133 self.send_completion_stream_event(
134 request,
135 LanguageModelCompletionEvent::Text(chunk.into()),
136 );
137 }
138
139 pub fn send_completion_stream_event(
140 &self,
141 request: &LanguageModelRequest,
142 event: impl Into<LanguageModelCompletionEvent>,
143 ) {
144 let current_completion_txs = self.current_completion_txs.lock();
145 let tx = current_completion_txs
146 .iter()
147 .find(|(req, _)| req == request)
148 .map(|(_, tx)| tx)
149 .unwrap();
150 tx.unbounded_send(event.into()).unwrap();
151 }
152
153 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
154 dbg!("remove...");
155 self.current_completion_txs
156 .lock()
157 .retain(|(req, _)| req != request);
158 }
159
160 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
161 dbg!("read...");
162 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
163 }
164
165 pub fn send_last_completion_stream_event(
166 &self,
167 event: impl Into<LanguageModelCompletionEvent>,
168 ) {
169 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
170 }
171
172 pub fn end_last_completion_stream(&self) {
173 self.end_completion_stream(self.pending_completions().last().unwrap());
174 }
175}
176
177impl LanguageModel for FakeLanguageModel {
178 fn id(&self) -> LanguageModelId {
179 LanguageModelId::from("fake".to_string())
180 }
181
182 fn name(&self) -> LanguageModelName {
183 LanguageModelName::from("Fake".to_string())
184 }
185
186 fn provider_id(&self) -> LanguageModelProviderId {
187 self.provider_id.clone()
188 }
189
190 fn provider_name(&self) -> LanguageModelProviderName {
191 self.provider_name.clone()
192 }
193
194 fn supports_tools(&self) -> bool {
195 false
196 }
197
198 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
199 false
200 }
201
202 fn supports_images(&self) -> bool {
203 false
204 }
205
206 fn telemetry_id(&self) -> String {
207 "fake".to_string()
208 }
209
210 fn max_token_count(&self) -> u64 {
211 1000000
212 }
213
214 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
215 futures::future::ready(Ok(0)).boxed()
216 }
217
218 fn stream_completion(
219 &self,
220 request: LanguageModelRequest,
221 _: &AsyncApp,
222 ) -> BoxFuture<
223 'static,
224 Result<
225 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
226 LanguageModelCompletionError,
227 >,
228 > {
229 let (tx, rx) = mpsc::unbounded();
230 dbg!("insert...");
231 self.current_completion_txs.lock().push((request, tx));
232 async move { Ok(rx.map(Ok).boxed()) }.boxed()
233 }
234
235 fn as_fake(&self) -> &Self {
236 self
237 }
238}