1use crate::{
2 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
3 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
4 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
5 LanguageModelToolChoice, LanguageModelToolUse, StopReason,
6};
7use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
8use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
9use http_client::Result;
10use parking_lot::Mutex;
11use std::sync::Arc;
12
13pub fn language_model_id() -> LanguageModelId {
14 LanguageModelId::from("fake".to_string())
15}
16
17pub fn language_model_name() -> LanguageModelName {
18 LanguageModelName::from("Fake".to_string())
19}
20
21pub fn provider_id() -> LanguageModelProviderId {
22 LanguageModelProviderId::from("fake".to_string())
23}
24
25pub fn provider_name() -> LanguageModelProviderName {
26 LanguageModelProviderName::from("Fake".to_string())
27}
28
/// Stateless stand-in for a real language model provider.
///
/// It always reports itself as authenticated and hands out
/// `FakeLanguageModel` instances.
#[derive(Default, Clone)]
pub struct FakeLanguageModelProvider;
31
32impl LanguageModelProviderState for FakeLanguageModelProvider {
33 type ObservableEntity = ();
34
35 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
36 None
37 }
38}
39
40impl LanguageModelProvider for FakeLanguageModelProvider {
41 fn id(&self) -> LanguageModelProviderId {
42 provider_id()
43 }
44
45 fn name(&self) -> LanguageModelProviderName {
46 provider_name()
47 }
48
49 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
50 Some(Arc::new(FakeLanguageModel::default()))
51 }
52
53 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
54 Some(Arc::new(FakeLanguageModel::default()))
55 }
56
57 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
58 vec![Arc::new(FakeLanguageModel::default())]
59 }
60
61 fn is_authenticated(&self, _: &App) -> bool {
62 true
63 }
64
65 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
66 Task::ready(Ok(()))
67 }
68
69 fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
70 unimplemented!()
71 }
72
73 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
74 Task::ready(Ok(()))
75 }
76}
77
78impl FakeLanguageModelProvider {
79 pub fn test_model(&self) -> FakeLanguageModel {
80 FakeLanguageModel::default()
81 }
82}
83
84#[derive(Debug, PartialEq)]
85pub struct ToolUseRequest {
86 pub request: LanguageModelRequest,
87 pub name: String,
88 pub description: String,
89 pub schema: serde_json::Value,
90}
91
92#[derive(Default)]
93pub struct FakeLanguageModel {
94 current_completion_txs: Mutex<
95 Vec<(
96 LanguageModelRequest,
97 mpsc::UnboundedSender<LanguageModelCompletionEvent>,
98 )>,
99 >,
100}
101
102impl FakeLanguageModel {
103 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
104 self.current_completion_txs
105 .lock()
106 .iter()
107 .map(|(request, _)| request.clone())
108 .collect()
109 }
110
111 pub fn completion_count(&self) -> usize {
112 self.current_completion_txs.lock().len()
113 }
114
115 pub fn stream_completion_response(
116 &self,
117 request: &LanguageModelRequest,
118 stream: impl Into<FakeLanguageModelStream>,
119 ) {
120 let current_completion_txs = self.current_completion_txs.lock();
121 let tx = current_completion_txs
122 .iter()
123 .find(|(req, _)| req == request)
124 .map(|(_, tx)| tx)
125 .unwrap();
126 for event in stream.into().events {
127 tx.unbounded_send(event).unwrap();
128 }
129 }
130
131 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
132 self.current_completion_txs
133 .lock()
134 .retain(|(req, _)| req != request);
135 }
136
137 pub fn stream_last_completion_response(&self, chunk: impl Into<FakeLanguageModelStream>) {
138 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
139 }
140
141 pub fn end_last_completion_stream(&self) {
142 self.end_completion_stream(self.pending_completions().last().unwrap());
143 }
144}
145
146pub struct FakeLanguageModelStream {
147 events: Vec<LanguageModelCompletionEvent>,
148}
149
150impl<T: Into<String>> From<T> for FakeLanguageModelStream {
151 fn from(chunk: T) -> Self {
152 Self {
153 events: vec![LanguageModelCompletionEvent::Text(chunk.into())],
154 }
155 }
156}
157
158impl From<LanguageModelToolUse> for FakeLanguageModelStream {
159 fn from(tool_use: LanguageModelToolUse) -> Self {
160 Self {
161 events: vec![
162 LanguageModelCompletionEvent::ToolUse(tool_use),
163 LanguageModelCompletionEvent::Stop(StopReason::ToolUse),
164 ],
165 }
166 }
167}
168
169impl LanguageModel for FakeLanguageModel {
170 fn id(&self) -> LanguageModelId {
171 language_model_id()
172 }
173
174 fn name(&self) -> LanguageModelName {
175 language_model_name()
176 }
177
178 fn provider_id(&self) -> LanguageModelProviderId {
179 provider_id()
180 }
181
182 fn provider_name(&self) -> LanguageModelProviderName {
183 provider_name()
184 }
185
186 fn supports_tools(&self) -> bool {
187 false
188 }
189
190 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
191 false
192 }
193
194 fn supports_images(&self) -> bool {
195 false
196 }
197
198 fn telemetry_id(&self) -> String {
199 "fake".to_string()
200 }
201
202 fn max_token_count(&self) -> u64 {
203 1000000
204 }
205
206 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
207 futures::future::ready(Ok(0)).boxed()
208 }
209
210 fn stream_completion(
211 &self,
212 request: LanguageModelRequest,
213 _: &AsyncApp,
214 ) -> BoxFuture<
215 'static,
216 Result<
217 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
218 LanguageModelCompletionError,
219 >,
220 > {
221 let (tx, rx) = mpsc::unbounded();
222 self.current_completion_txs.lock().push((request, tx));
223 async move { Ok(rx.map(Ok).boxed()) }.boxed()
224 }
225
226 fn as_fake(&self) -> &Self {
227 self
228 }
229}