1use crate::{
2 LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
3 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
4 LanguageModelProviderState, LanguageModelRequest,
5};
6use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
7use gpui::{AnyView, AppContext, AsyncAppContext, Task};
8use http_client::Result;
9use parking_lot::Mutex;
10use serde::Serialize;
11use std::sync::Arc;
12use ui::WindowContext;
13
14pub fn language_model_id() -> LanguageModelId {
15 LanguageModelId::from("fake".to_string())
16}
17
18pub fn language_model_name() -> LanguageModelName {
19 LanguageModelName::from("Fake".to_string())
20}
21
22pub fn provider_id() -> LanguageModelProviderId {
23 LanguageModelProviderId::from("fake".to_string())
24}
25
26pub fn provider_name() -> LanguageModelProviderName {
27 LanguageModelProviderName::from("Fake".to_string())
28}
29
/// A stand-in [`LanguageModelProvider`] that needs no credentials and exposes a single
/// [`FakeLanguageModel`].
#[derive(Default, Clone)]
pub struct FakeLanguageModelProvider;
32
33impl LanguageModelProviderState for FakeLanguageModelProvider {
34 type ObservableEntity = ();
35
36 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
37 None
38 }
39}
40
41impl LanguageModelProvider for FakeLanguageModelProvider {
42 fn id(&self) -> LanguageModelProviderId {
43 provider_id()
44 }
45
46 fn name(&self) -> LanguageModelProviderName {
47 provider_name()
48 }
49
50 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
51 vec![Arc::new(FakeLanguageModel::default())]
52 }
53
54 fn is_authenticated(&self, _: &AppContext) -> bool {
55 true
56 }
57
58 fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
59 Task::ready(Ok(()))
60 }
61
62 fn configuration_view(&self, _: &mut WindowContext) -> AnyView {
63 unimplemented!()
64 }
65
66 fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
67 Task::ready(Ok(()))
68 }
69}
70
71impl FakeLanguageModelProvider {
72 pub fn test_model(&self) -> FakeLanguageModel {
73 FakeLanguageModel::default()
74 }
75}
76
77#[derive(Debug, PartialEq)]
78pub struct ToolUseRequest {
79 pub request: LanguageModelRequest,
80 pub name: String,
81 pub description: String,
82 pub schema: serde_json::Value,
83}
84
85#[derive(Default)]
86pub struct FakeLanguageModel {
87 current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
88 current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
89}
90
91impl FakeLanguageModel {
92 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
93 self.current_completion_txs
94 .lock()
95 .iter()
96 .map(|(request, _)| request.clone())
97 .collect()
98 }
99
100 pub fn completion_count(&self) -> usize {
101 self.current_completion_txs.lock().len()
102 }
103
104 pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
105 let current_completion_txs = self.current_completion_txs.lock();
106 let tx = current_completion_txs
107 .iter()
108 .find(|(req, _)| req == request)
109 .map(|(_, tx)| tx)
110 .unwrap();
111 tx.unbounded_send(chunk).unwrap();
112 }
113
114 pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
115 self.current_completion_txs
116 .lock()
117 .retain(|(req, _)| req != request);
118 }
119
120 pub fn stream_last_completion_response(&self, chunk: String) {
121 self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
122 }
123
124 pub fn end_last_completion_stream(&self) {
125 self.end_completion_stream(self.pending_completions().last().unwrap());
126 }
127
128 pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
129 let response = serde_json::to_string(&response).unwrap();
130 let mut current_tool_call_txs = self.current_tool_use_txs.lock();
131 let (_, tx) = current_tool_call_txs.pop().unwrap();
132 tx.unbounded_send(response).unwrap();
133 }
134}
135
136impl LanguageModel for FakeLanguageModel {
137 fn id(&self) -> LanguageModelId {
138 language_model_id()
139 }
140
141 fn name(&self) -> LanguageModelName {
142 language_model_name()
143 }
144
145 fn provider_id(&self) -> LanguageModelProviderId {
146 provider_id()
147 }
148
149 fn provider_name(&self) -> LanguageModelProviderName {
150 provider_name()
151 }
152
153 fn telemetry_id(&self) -> String {
154 "fake".to_string()
155 }
156
157 fn max_token_count(&self) -> usize {
158 1000000
159 }
160
161 fn count_tokens(
162 &self,
163 _: LanguageModelRequest,
164 _: &AppContext,
165 ) -> BoxFuture<'static, Result<usize>> {
166 futures::future::ready(Ok(0)).boxed()
167 }
168
169 fn stream_completion(
170 &self,
171 request: LanguageModelRequest,
172 _: &AsyncAppContext,
173 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
174 let (tx, rx) = mpsc::unbounded();
175 self.current_completion_txs.lock().push((request, tx));
176 async move {
177 Ok(rx
178 .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
179 .boxed())
180 }
181 .boxed()
182 }
183
184 fn use_any_tool(
185 &self,
186 request: LanguageModelRequest,
187 name: String,
188 description: String,
189 schema: serde_json::Value,
190 _cx: &AsyncAppContext,
191 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192 let (tx, rx) = mpsc::unbounded();
193 let tool_call = ToolUseRequest {
194 request,
195 name,
196 description,
197 schema,
198 };
199 self.current_tool_use_txs.lock().push((tool_call, tx));
200 async move { Ok(rx.map(Ok).boxed()) }.boxed()
201 }
202
203 fn as_fake(&self) -> &Self {
204 self
205 }
206}