1use crate::{
2 AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
3 LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
4 LanguageModelProviderState, LanguageModelRequest,
5};
6use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
7use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
8use http_client::Result;
9use parking_lot::Mutex;
10use std::sync::Arc;
11
12pub fn language_model_id() -> LanguageModelId {
13 LanguageModelId::from("fake".to_string())
14}
15
16pub fn language_model_name() -> LanguageModelName {
17 LanguageModelName::from("Fake".to_string())
18}
19
20pub fn provider_id() -> LanguageModelProviderId {
21 LanguageModelProviderId::from("fake".to_string())
22}
23
24pub fn provider_name() -> LanguageModelProviderName {
25 LanguageModelProviderName::from("Fake".to_string())
26}
27
/// Stateless stand-in language-model provider (presumably for tests — confirm
/// with callers).
///
/// It carries no data. Every model it returns is a freshly constructed
/// `FakeLanguageModel::default()`, so models obtained from separate calls do not
/// share pending completions.
#[derive(Clone, Default)]
pub struct FakeLanguageModelProvider;
30
31impl LanguageModelProviderState for FakeLanguageModelProvider {
32 type ObservableEntity = ();
33
34 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
35 None
36 }
37}
38
39impl LanguageModelProvider for FakeLanguageModelProvider {
40 fn id(&self) -> LanguageModelProviderId {
41 provider_id()
42 }
43
44 fn name(&self) -> LanguageModelProviderName {
45 provider_name()
46 }
47
48 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
49 Some(Arc::new(FakeLanguageModel::default()))
50 }
51
52 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
53 vec![Arc::new(FakeLanguageModel::default())]
54 }
55
56 fn is_authenticated(&self, _: &App) -> bool {
57 true
58 }
59
60 fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
61 Task::ready(Ok(()))
62 }
63
64 fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
65 unimplemented!()
66 }
67
68 fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
69 Task::ready(Ok(()))
70 }
71}
72
73impl FakeLanguageModelProvider {
74 pub fn test_model(&self) -> FakeLanguageModel {
75 FakeLanguageModel::default()
76 }
77}
78
/// A captured request paired with a tool's name, description and JSON schema.
///
/// NOTE(review): nothing in this chunk constructs or reads this type; presumably
/// it records tool-use calls made against the fake model — confirm against callers.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    /// The request that accompanied the tool use.
    pub request: LanguageModelRequest,
    /// Tool name.
    pub name: String,
    /// Tool description text.
    pub description: String,
    /// Tool input schema, kept as an untyped JSON value.
    pub schema: serde_json::Value,
}
86
/// A language model whose completions are driven by hand.
///
/// Every `stream_completion` call records its request here together with the
/// sender half of an unbounded channel. Helper methods then push text chunks
/// into those streams or end them.
#[derive(Default)]
pub struct FakeLanguageModel {
    /// Pending completions, in the order they were started. Each entry holds the
    /// request and the only sender feeding that completion's stream; dropping the
    /// sender ends the stream.
    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
}
91
92impl FakeLanguageModel {
93 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
94 self.current_completion_txs
95 .lock()
96 .iter()
97 .map(|(request, _)| request.clone())
98 .collect()
99 }
100
101 pub fn completion_count(&self) -> usize {
102 self.current_completion_txs.lock().len()
103 }
104
105 pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
106 let current_completion_txs = self.current_completion_txs.lock();
107 let tx = current_completion_txs
108 .iter()
109 .find(|(req, _)| req == request)
110 .map(|(_, tx)| tx)
111 .unwrap();
112 tx.unbounded_send(chunk).unwrap();
113 }
114
115 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
116 self.current_completion_txs
117 .lock()
118 .retain(|(req, _)| req != request);
119 }
120
121 pub fn stream_last_completion_response(&self, chunk: String) {
122 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
123 }
124
125 pub fn end_last_completion_stream(&self) {
126 self.end_completion_stream(self.pending_completions().last().unwrap());
127 }
128}
129
130impl LanguageModel for FakeLanguageModel {
131 fn id(&self) -> LanguageModelId {
132 language_model_id()
133 }
134
135 fn name(&self) -> LanguageModelName {
136 language_model_name()
137 }
138
139 fn provider_id(&self) -> LanguageModelProviderId {
140 provider_id()
141 }
142
143 fn provider_name(&self) -> LanguageModelProviderName {
144 provider_name()
145 }
146
147 fn supports_tools(&self) -> bool {
148 false
149 }
150
151 fn telemetry_id(&self) -> String {
152 "fake".to_string()
153 }
154
155 fn max_token_count(&self) -> usize {
156 1000000
157 }
158
159 fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
160 futures::future::ready(Ok(0)).boxed()
161 }
162
163 fn stream_completion(
164 &self,
165 request: LanguageModelRequest,
166 _: &AsyncApp,
167 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
168 let (tx, rx) = mpsc::unbounded();
169 self.current_completion_txs.lock().push((request, tx));
170 async move {
171 Ok(rx
172 .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
173 .boxed())
174 }
175 .boxed()
176 }
177
178 fn as_fake(&self) -> &Self {
179 self
180 }
181}