1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
3 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
4 LanguageModelRequest,
5};
6use anyhow::Context as _;
7use futures::{
8 channel::{mpsc, oneshot},
9 future::BoxFuture,
10 stream::BoxStream,
11 FutureExt, StreamExt,
12};
13use gpui::{AnyView, AppContext, AsyncAppContext, Task};
14use http_client::Result;
15use parking_lot::Mutex;
16use std::sync::Arc;
17use ui::WindowContext;
18
19pub fn language_model_id() -> LanguageModelId {
20 LanguageModelId::from("fake".to_string())
21}
22
23pub fn language_model_name() -> LanguageModelName {
24 LanguageModelName::from("Fake".to_string())
25}
26
27pub fn provider_id() -> LanguageModelProviderId {
28 LanguageModelProviderId::from("fake".to_string())
29}
30
31pub fn provider_name() -> LanguageModelProviderName {
32 LanguageModelProviderName::from("Fake".to_string())
33}
34
/// A stub [`LanguageModelProvider`] for tests: it is always authenticated
/// and vends [`FakeLanguageModel`] instances.
#[derive(Clone, Default)]
pub struct FakeLanguageModelProvider;
37
impl LanguageModelProviderState for FakeLanguageModelProvider {
    type ObservableEntity = ();

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        // The fake provider has no state for the UI to observe.
        None
    }
}
45
46impl LanguageModelProvider for FakeLanguageModelProvider {
47 fn id(&self) -> LanguageModelProviderId {
48 provider_id()
49 }
50
51 fn name(&self) -> LanguageModelProviderName {
52 provider_name()
53 }
54
55 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
56 vec![Arc::new(FakeLanguageModel::default())]
57 }
58
59 fn is_authenticated(&self, _: &AppContext) -> bool {
60 true
61 }
62
63 fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
64 Task::ready(Ok(()))
65 }
66
67 fn configuration_view(&self, _: &mut WindowContext) -> AnyView {
68 unimplemented!()
69 }
70
71 fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
72 Task::ready(Ok(()))
73 }
74}
75
76impl FakeLanguageModelProvider {
77 pub fn test_model(&self) -> FakeLanguageModel {
78 FakeLanguageModel::default()
79 }
80}
81
/// A recorded call to [`LanguageModel::use_any_tool`], captured so tests can
/// inspect it and reply via [`FakeLanguageModel::respond_to_tool_use`].
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    /// The full completion request that accompanied the tool call.
    pub request: LanguageModelRequest,
    /// Name of the requested tool.
    pub name: String,
    /// Description of the requested tool.
    pub description: String,
    /// JSON schema describing the tool's expected input.
    pub schema: serde_json::Value,
}
89
/// An in-memory [`LanguageModel`] for tests. Completions and tool uses are
/// never answered automatically; tests push responses through the helper
/// methods on this type.
#[derive(Default)]
pub struct FakeLanguageModel {
    // One entry per completion stream that has been started via
    // `stream_completion` and not yet ended.
    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
    // One entry per tool use that has been issued via `use_any_tool`
    // and not yet answered.
    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
}
95
96impl FakeLanguageModel {
97 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
98 self.current_completion_txs
99 .lock()
100 .iter()
101 .map(|(request, _)| request.clone())
102 .collect()
103 }
104
105 pub fn completion_count(&self) -> usize {
106 self.current_completion_txs.lock().len()
107 }
108
109 pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
110 let current_completion_txs = self.current_completion_txs.lock();
111 let tx = current_completion_txs
112 .iter()
113 .find(|(req, _)| req == request)
114 .map(|(_, tx)| tx)
115 .unwrap();
116 tx.unbounded_send(chunk).unwrap();
117 }
118
119 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
120 self.current_completion_txs
121 .lock()
122 .retain(|(req, _)| req != request);
123 }
124
125 pub fn stream_last_completion_response(&self, chunk: String) {
126 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
127 }
128
129 pub fn end_last_completion_stream(&self) {
130 self.end_completion_stream(self.pending_completions().last().unwrap());
131 }
132
133 pub fn respond_to_tool_use(
134 &self,
135 tool_call: &ToolUseRequest,
136 response: Result<serde_json::Value>,
137 ) {
138 let mut current_tool_call_txs = self.current_tool_use_txs.lock();
139 if let Some(index) = current_tool_call_txs
140 .iter()
141 .position(|(call, _)| call == tool_call)
142 {
143 let (_, tx) = current_tool_call_txs.remove(index);
144 tx.send(response).unwrap();
145 }
146 }
147
148 pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
149 let mut current_tool_call_txs = self.current_tool_use_txs.lock();
150 let (_, tx) = current_tool_call_txs.pop().unwrap();
151 tx.send(response).unwrap();
152 }
153}
154
155impl LanguageModel for FakeLanguageModel {
156 fn id(&self) -> LanguageModelId {
157 language_model_id()
158 }
159
160 fn name(&self) -> LanguageModelName {
161 language_model_name()
162 }
163
164 fn provider_id(&self) -> LanguageModelProviderId {
165 provider_id()
166 }
167
168 fn provider_name(&self) -> LanguageModelProviderName {
169 provider_name()
170 }
171
172 fn telemetry_id(&self) -> String {
173 "fake".to_string()
174 }
175
176 fn max_token_count(&self) -> usize {
177 1000000
178 }
179
180 fn count_tokens(
181 &self,
182 _: LanguageModelRequest,
183 _: &AppContext,
184 ) -> BoxFuture<'static, Result<usize>> {
185 futures::future::ready(Ok(0)).boxed()
186 }
187
188 fn stream_completion(
189 &self,
190 request: LanguageModelRequest,
191 _: &AsyncAppContext,
192 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
193 let (tx, rx) = mpsc::unbounded();
194 self.current_completion_txs.lock().push((request, tx));
195 async move { Ok(rx.map(Ok).boxed()) }.boxed()
196 }
197
198 fn use_any_tool(
199 &self,
200 request: LanguageModelRequest,
201 name: String,
202 description: String,
203 schema: serde_json::Value,
204 _cx: &AsyncAppContext,
205 ) -> BoxFuture<'static, Result<serde_json::Value>> {
206 let (tx, rx) = oneshot::channel();
207 let tool_call = ToolUseRequest {
208 request,
209 name,
210 description,
211 schema,
212 };
213 self.current_tool_use_txs.lock().push((tool_call, tx));
214 async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
215 }
216
217 fn as_fake(&self) -> &Self {
218 self
219 }
220}