1use crate::{
2 AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
3 LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
4 LanguageModelProviderState, LanguageModelRequest,
5};
6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
8use http_client::Result;
9use parking_lot::Mutex;
10use serde::Serialize;
11use std::sync::Arc;
12
13pub fn language_model_id() -> LanguageModelId {
14 LanguageModelId::from("fake".to_string())
15}
16
17pub fn language_model_name() -> LanguageModelName {
18 LanguageModelName::from("Fake".to_string())
19}
20
21pub fn provider_id() -> LanguageModelProviderId {
22 LanguageModelProviderId::from("fake".to_string())
23}
24
25pub fn provider_name() -> LanguageModelProviderName {
26 LanguageModelProviderName::from("Fake".to_string())
27}
28
/// A stand-in [`LanguageModelProvider`] that always reports itself as
/// authenticated and serves [`FakeLanguageModel`] instances.
#[derive(Default, Clone)]
pub struct FakeLanguageModelProvider;
31
32impl LanguageModelProviderState for FakeLanguageModelProvider {
33 type ObservableEntity = ();
34
35 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
36 None
37 }
38}
39
40impl LanguageModelProvider for FakeLanguageModelProvider {
41 fn id(&self) -> LanguageModelProviderId {
42 provider_id()
43 }
44
45 fn name(&self) -> LanguageModelProviderName {
46 provider_name()
47 }
48
49 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
50 Some(Arc::new(FakeLanguageModel::default()))
51 }
52
53 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
54 vec![Arc::new(FakeLanguageModel::default())]
55 }
56
57 fn is_authenticated(&self, _: &App) -> bool {
58 true
59 }
60
61 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
62 Task::ready(Ok(()))
63 }
64
65 fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
66 unimplemented!()
67 }
68
69 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
70 Task::ready(Ok(()))
71 }
72}
73
74impl FakeLanguageModelProvider {
75 pub fn test_model(&self) -> FakeLanguageModel {
76 FakeLanguageModel::default()
77 }
78}
79
80#[derive(Debug, PartialEq)]
81pub struct ToolUseRequest {
82 pub request: LanguageModelRequest,
83 pub name: String,
84 pub description: String,
85 pub schema: serde_json::Value,
86}
87
88#[derive(Default)]
89pub struct FakeLanguageModel {
90 current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
91 current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
92}
93
94impl FakeLanguageModel {
95 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
96 self.current_completion_txs
97 .lock()
98 .iter()
99 .map(|(request, _)| request.clone())
100 .collect()
101 }
102
103 pub fn completion_count(&self) -> usize {
104 self.current_completion_txs.lock().len()
105 }
106
107 pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
108 let current_completion_txs = self.current_completion_txs.lock();
109 let tx = current_completion_txs
110 .iter()
111 .find(|(req, _)| req == request)
112 .map(|(_, tx)| tx)
113 .unwrap();
114 tx.unbounded_send(chunk).unwrap();
115 }
116
117 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
118 self.current_completion_txs
119 .lock()
120 .retain(|(req, _)| req != request);
121 }
122
123 pub fn stream_last_completion_response(&self, chunk: String) {
124 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
125 }
126
127 pub fn end_last_completion_stream(&self) {
128 self.end_completion_stream(self.pending_completions().last().unwrap());
129 }
130
131 pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
132 let response = serde_json::to_string(&response).unwrap();
133 let mut current_tool_call_txs = self.current_tool_use_txs.lock();
134 let (_, tx) = current_tool_call_txs.pop().unwrap();
135 tx.unbounded_send(response).unwrap();
136 }
137}
138
139impl LanguageModel for FakeLanguageModel {
140 fn id(&self) -> LanguageModelId {
141 language_model_id()
142 }
143
144 fn name(&self) -> LanguageModelName {
145 language_model_name()
146 }
147
148 fn provider_id(&self) -> LanguageModelProviderId {
149 provider_id()
150 }
151
152 fn provider_name(&self) -> LanguageModelProviderName {
153 provider_name()
154 }
155
156 fn supports_tools(&self) -> bool {
157 false
158 }
159
160 fn telemetry_id(&self) -> String {
161 "fake".to_string()
162 }
163
164 fn max_token_count(&self) -> usize {
165 1000000
166 }
167
168 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
169 futures::future::ready(Ok(0)).boxed()
170 }
171
172 fn stream_completion(
173 &self,
174 request: LanguageModelRequest,
175 _: &AsyncApp,
176 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
177 let (tx, rx) = mpsc::unbounded();
178 self.current_completion_txs.lock().push((request, tx));
179 async move {
180 Ok(rx
181 .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
182 .boxed())
183 }
184 .boxed()
185 }
186
187 fn use_any_tool(
188 &self,
189 request: LanguageModelRequest,
190 name: String,
191 description: String,
192 schema: serde_json::Value,
193 _cx: &AsyncApp,
194 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
195 let (tx, rx) = mpsc::unbounded();
196 let tool_call = ToolUseRequest {
197 request,
198 name,
199 description,
200 schema,
201 };
202 self.current_tool_use_txs.lock().push((tool_call, tx));
203 async move { Ok(rx.map(Ok).boxed()) }.boxed()
204 }
205
206 fn as_fake(&self) -> &Self {
207 self
208 }
209}