1use crate::{
2 AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
3 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
4 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
5 LanguageModelRequest, LanguageModelToolChoice,
6};
7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
9use http_client::Result;
10use parking_lot::Mutex;
11use std::sync::Arc;
12
13#[derive(Clone)]
14pub struct FakeLanguageModelProvider {
15 id: LanguageModelProviderId,
16 name: LanguageModelProviderName,
17}
18
19impl Default for FakeLanguageModelProvider {
20 fn default() -> Self {
21 Self {
22 id: LanguageModelProviderId::from("fake".to_string()),
23 name: LanguageModelProviderName::from("Fake".to_string()),
24 }
25 }
26}
27
28impl LanguageModelProviderState for FakeLanguageModelProvider {
29 type ObservableEntity = ();
30
31 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
32 None
33 }
34}
35
36impl LanguageModelProvider for FakeLanguageModelProvider {
37 fn id(&self) -> LanguageModelProviderId {
38 self.id.clone()
39 }
40
41 fn name(&self) -> LanguageModelProviderName {
42 self.name.clone()
43 }
44
45 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
46 Some(Arc::new(FakeLanguageModel::default()))
47 }
48
49 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
50 Some(Arc::new(FakeLanguageModel::default()))
51 }
52
53 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
54 vec![Arc::new(FakeLanguageModel::default())]
55 }
56
57 fn is_authenticated(&self, _: &App) -> bool {
58 true
59 }
60
61 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
62 Task::ready(Ok(()))
63 }
64
65 fn configuration_view(
66 &self,
67 _target_agent: ConfigurationViewTargetAgent,
68 _window: &mut Window,
69 _: &mut App,
70 ) -> AnyView {
71 unimplemented!()
72 }
73
74 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
75 Task::ready(Ok(()))
76 }
77}
78
79impl FakeLanguageModelProvider {
80 pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
81 Self { id, name }
82 }
83
84 pub fn test_model(&self) -> FakeLanguageModel {
85 FakeLanguageModel::default()
86 }
87}
88
89#[derive(Debug, PartialEq)]
90pub struct ToolUseRequest {
91 pub request: LanguageModelRequest,
92 pub name: String,
93 pub description: String,
94 pub schema: serde_json::Value,
95}
96
97pub struct FakeLanguageModel {
98 provider_id: LanguageModelProviderId,
99 provider_name: LanguageModelProviderName,
100 current_completion_txs: Mutex<
101 Vec<(
102 LanguageModelRequest,
103 mpsc::UnboundedSender<LanguageModelCompletionEvent>,
104 )>,
105 >,
106}
107
108impl Default for FakeLanguageModel {
109 fn default() -> Self {
110 Self {
111 provider_id: LanguageModelProviderId::from("fake".to_string()),
112 provider_name: LanguageModelProviderName::from("Fake".to_string()),
113 current_completion_txs: Mutex::new(Vec::new()),
114 }
115 }
116}
117
118impl FakeLanguageModel {
119 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
120 self.current_completion_txs
121 .lock()
122 .iter()
123 .map(|(request, _)| request.clone())
124 .collect()
125 }
126
127 pub fn completion_count(&self) -> usize {
128 self.current_completion_txs.lock().len()
129 }
130
131 pub fn send_completion_stream_text_chunk(
132 &self,
133 request: &LanguageModelRequest,
134 chunk: impl Into<String>,
135 ) {
136 self.send_completion_stream_event(
137 request,
138 LanguageModelCompletionEvent::Text(chunk.into()),
139 );
140 }
141
142 pub fn send_completion_stream_event(
143 &self,
144 request: &LanguageModelRequest,
145 event: impl Into<LanguageModelCompletionEvent>,
146 ) {
147 let current_completion_txs = self.current_completion_txs.lock();
148 let tx = current_completion_txs
149 .iter()
150 .find(|(req, _)| req == request)
151 .map(|(_, tx)| tx)
152 .unwrap();
153 tx.unbounded_send(event.into()).unwrap();
154 }
155
156 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
157 self.current_completion_txs
158 .lock()
159 .retain(|(req, _)| req != request);
160 }
161
162 pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
163 self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
164 }
165
166 pub fn send_last_completion_stream_event(
167 &self,
168 event: impl Into<LanguageModelCompletionEvent>,
169 ) {
170 self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
171 }
172
173 pub fn end_last_completion_stream(&self) {
174 self.end_completion_stream(self.pending_completions().last().unwrap());
175 }
176}
177
178impl LanguageModel for FakeLanguageModel {
179 fn id(&self) -> LanguageModelId {
180 LanguageModelId::from("fake".to_string())
181 }
182
183 fn name(&self) -> LanguageModelName {
184 LanguageModelName::from("Fake".to_string())
185 }
186
187 fn provider_id(&self) -> LanguageModelProviderId {
188 self.provider_id.clone()
189 }
190
191 fn provider_name(&self) -> LanguageModelProviderName {
192 self.provider_name.clone()
193 }
194
195 fn supports_tools(&self) -> bool {
196 false
197 }
198
199 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
200 false
201 }
202
203 fn supports_images(&self) -> bool {
204 false
205 }
206
207 fn telemetry_id(&self) -> String {
208 "fake".to_string()
209 }
210
211 fn max_token_count(&self) -> u64 {
212 1000000
213 }
214
215 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
216 futures::future::ready(Ok(0)).boxed()
217 }
218
219 fn stream_completion(
220 &self,
221 request: LanguageModelRequest,
222 _: &AsyncApp,
223 ) -> BoxFuture<
224 'static,
225 Result<
226 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
227 LanguageModelCompletionError,
228 >,
229 > {
230 let (tx, rx) = mpsc::unbounded();
231 self.current_completion_txs.lock().push((request, tx));
232 async move { Ok(rx.map(Ok).boxed()) }.boxed()
233 }
234
235 fn as_fake(&self) -> &Self {
236 self
237 }
238}