1use crate::{
2 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
3 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
4 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
5};
6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
8use http_client::Result;
9use parking_lot::Mutex;
10use std::sync::Arc;
11
12pub fn language_model_id() -> LanguageModelId {
13 LanguageModelId::from("fake".to_string())
14}
15
16pub fn language_model_name() -> LanguageModelName {
17 LanguageModelName::from("Fake".to_string())
18}
19
20pub fn provider_id() -> LanguageModelProviderId {
21 LanguageModelProviderId::from("fake".to_string())
22}
23
24pub fn provider_name() -> LanguageModelProviderName {
25 LanguageModelProviderName::from("Fake".to_string())
26}
27
/// Stateless provider that is always authenticated and hands out
/// [`FakeLanguageModel`] instances.
#[derive(Default, Clone)]
pub struct FakeLanguageModelProvider;
30
31impl LanguageModelProviderState for FakeLanguageModelProvider {
32 type ObservableEntity = ();
33
34 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
35 None
36 }
37}
38
39impl LanguageModelProvider for FakeLanguageModelProvider {
40 fn id(&self) -> LanguageModelProviderId {
41 provider_id()
42 }
43
44 fn name(&self) -> LanguageModelProviderName {
45 provider_name()
46 }
47
48 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
49 Some(Arc::new(FakeLanguageModel::default()))
50 }
51
52 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
53 Some(Arc::new(FakeLanguageModel::default()))
54 }
55
56 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
57 vec![Arc::new(FakeLanguageModel::default())]
58 }
59
60 fn is_authenticated(&self, _: &App) -> bool {
61 true
62 }
63
64 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
65 Task::ready(Ok(()))
66 }
67
68 fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
69 unimplemented!()
70 }
71
72 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
73 Task::ready(Ok(()))
74 }
75}
76
77impl FakeLanguageModelProvider {
78 pub fn test_model(&self) -> FakeLanguageModel {
79 FakeLanguageModel::default()
80 }
81}
82
83#[derive(Debug, PartialEq)]
84pub struct ToolUseRequest {
85 pub request: LanguageModelRequest,
86 pub name: String,
87 pub description: String,
88 pub schema: serde_json::Value,
89}
90
91#[derive(Default)]
92pub struct FakeLanguageModel {
93 current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
94}
95
96impl FakeLanguageModel {
97 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
98 self.current_completion_txs
99 .lock()
100 .iter()
101 .map(|(request, _)| request.clone())
102 .collect()
103 }
104
105 pub fn completion_count(&self) -> usize {
106 self.current_completion_txs.lock().len()
107 }
108
109 pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
110 let current_completion_txs = self.current_completion_txs.lock();
111 let tx = current_completion_txs
112 .iter()
113 .find(|(req, _)| req == request)
114 .map(|(_, tx)| tx)
115 .unwrap();
116 tx.unbounded_send(chunk).unwrap();
117 }
118
119 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
120 self.current_completion_txs
121 .lock()
122 .retain(|(req, _)| req != request);
123 }
124
125 pub fn stream_last_completion_response(&self, chunk: String) {
126 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
127 }
128
129 pub fn end_last_completion_stream(&self) {
130 self.end_completion_stream(self.pending_completions().last().unwrap());
131 }
132}
133
134impl LanguageModel for FakeLanguageModel {
135 fn id(&self) -> LanguageModelId {
136 language_model_id()
137 }
138
139 fn name(&self) -> LanguageModelName {
140 language_model_name()
141 }
142
143 fn provider_id(&self) -> LanguageModelProviderId {
144 provider_id()
145 }
146
147 fn provider_name(&self) -> LanguageModelProviderName {
148 provider_name()
149 }
150
151 fn supports_tools(&self) -> bool {
152 false
153 }
154
155 fn telemetry_id(&self) -> String {
156 "fake".to_string()
157 }
158
159 fn max_token_count(&self) -> usize {
160 1000000
161 }
162
163 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
164 futures::future::ready(Ok(0)).boxed()
165 }
166
167 fn stream_completion(
168 &self,
169 request: LanguageModelRequest,
170 _: &AsyncApp,
171 ) -> BoxFuture<
172 'static,
173 Result<
174 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
175 >,
176 > {
177 let (tx, rx) = mpsc::unbounded();
178 self.current_completion_txs.lock().push((request, tx));
179 async move {
180 Ok(rx
181 .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
182 .boxed())
183 }
184 .boxed()
185 }
186
187 fn as_fake(&self) -> &Self {
188 self
189 }
190}