1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
3 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
4 LanguageModelRequest,
5};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
9use gpui::{AnyView, AppContext, AsyncAppContext, Task};
10use http_client::Result;
11use std::{
12 future,
13 sync::{Arc, Mutex},
14};
15use ui::WindowContext;
16
17pub fn language_model_id() -> LanguageModelId {
18 LanguageModelId::from("fake".to_string())
19}
20
21pub fn language_model_name() -> LanguageModelName {
22 LanguageModelName::from("Fake".to_string())
23}
24
25pub fn provider_id() -> LanguageModelProviderId {
26 LanguageModelProviderId::from("fake".to_string())
27}
28
29pub fn provider_name() -> LanguageModelProviderName {
30 LanguageModelProviderName::from("Fake".to_string())
31}
32
33#[derive(Clone, Default)]
34pub struct FakeLanguageModelProvider {
35 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
36}
37
38impl LanguageModelProviderState for FakeLanguageModelProvider {
39 type ObservableEntity = ();
40
41 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
42 None
43 }
44}
45
46impl LanguageModelProvider for FakeLanguageModelProvider {
47 fn id(&self) -> LanguageModelProviderId {
48 provider_id()
49 }
50
51 fn name(&self) -> LanguageModelProviderName {
52 provider_name()
53 }
54
55 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
56 vec![Arc::new(FakeLanguageModel {
57 current_completion_txs: self.current_completion_txs.clone(),
58 })]
59 }
60
61 fn is_authenticated(&self, _: &AppContext) -> bool {
62 true
63 }
64
65 fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
66 Task::ready(Ok(()))
67 }
68
69 fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
70 unimplemented!()
71 }
72
73 fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
74 Task::ready(Ok(()))
75 }
76}
77
78impl FakeLanguageModelProvider {
79 pub fn test_model(&self) -> FakeLanguageModel {
80 FakeLanguageModel {
81 current_completion_txs: self.current_completion_txs.clone(),
82 }
83 }
84}
85
86pub struct FakeLanguageModel {
87 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
88}
89
90impl FakeLanguageModel {
91 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
92 self.current_completion_txs
93 .lock()
94 .unwrap()
95 .keys()
96 .map(|k| serde_json::from_str(k).unwrap())
97 .collect()
98 }
99
100 pub fn completion_count(&self) -> usize {
101 self.current_completion_txs.lock().unwrap().len()
102 }
103
104 pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
105 let json = serde_json::to_string(request).unwrap();
106 self.current_completion_txs
107 .lock()
108 .unwrap()
109 .get(&json)
110 .unwrap()
111 .unbounded_send(chunk)
112 .unwrap();
113 }
114
115 pub fn send_last_completion_chunk(&self, chunk: String) {
116 self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
117 }
118
119 pub fn finish_completion(&self, request: &LanguageModelRequest) {
120 self.current_completion_txs
121 .lock()
122 .unwrap()
123 .remove(&serde_json::to_string(request).unwrap())
124 .unwrap();
125 }
126
127 pub fn finish_last_completion(&self) {
128 self.finish_completion(self.pending_completions().last().unwrap());
129 }
130}
131
132impl LanguageModel for FakeLanguageModel {
133 fn id(&self) -> LanguageModelId {
134 language_model_id()
135 }
136
137 fn name(&self) -> LanguageModelName {
138 language_model_name()
139 }
140
141 fn provider_id(&self) -> LanguageModelProviderId {
142 provider_id()
143 }
144
145 fn provider_name(&self) -> LanguageModelProviderName {
146 provider_name()
147 }
148
149 fn telemetry_id(&self) -> String {
150 "fake".to_string()
151 }
152
153 fn max_token_count(&self) -> usize {
154 1000000
155 }
156
157 fn count_tokens(
158 &self,
159 _: LanguageModelRequest,
160 _: &AppContext,
161 ) -> BoxFuture<'static, Result<usize>> {
162 futures::future::ready(Ok(0)).boxed()
163 }
164
165 fn stream_completion(
166 &self,
167 request: LanguageModelRequest,
168 _: &AsyncAppContext,
169 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
170 let (tx, rx) = mpsc::unbounded();
171 self.current_completion_txs
172 .lock()
173 .unwrap()
174 .insert(serde_json::to_string(&request).unwrap(), tx);
175 async move { Ok(rx.map(Ok).boxed()) }.boxed()
176 }
177
178 fn use_any_tool(
179 &self,
180 _request: LanguageModelRequest,
181 _name: String,
182 _description: String,
183 _schema: serde_json::Value,
184 _cx: &AsyncAppContext,
185 ) -> BoxFuture<'static, Result<serde_json::Value>> {
186 future::ready(Err(anyhow!("not implemented"))).boxed()
187 }
188}