1use std::sync::{Arc, Mutex};
2
3use collections::HashMap;
4use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
5
6use crate::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest,
10};
11use gpui::{AnyView, AppContext, AsyncAppContext, Task};
12use http_client::Result;
13use ui::WindowContext;
14
15pub fn language_model_id() -> LanguageModelId {
16 LanguageModelId::from("fake".to_string())
17}
18
19pub fn language_model_name() -> LanguageModelName {
20 LanguageModelName::from("Fake".to_string())
21}
22
23pub fn provider_id() -> LanguageModelProviderId {
24 LanguageModelProviderId::from("fake".to_string())
25}
26
27pub fn provider_name() -> LanguageModelProviderName {
28 LanguageModelProviderName::from("Fake".to_string())
29}
30
31#[derive(Clone, Default)]
32pub struct FakeLanguageModelProvider {
33 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
34}
35
36impl LanguageModelProviderState for FakeLanguageModelProvider {
37 fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
38 None
39 }
40}
41
42impl LanguageModelProvider for FakeLanguageModelProvider {
43 fn id(&self) -> LanguageModelProviderId {
44 provider_id()
45 }
46
47 fn name(&self) -> LanguageModelProviderName {
48 provider_name()
49 }
50
51 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
52 vec![Arc::new(FakeLanguageModel {
53 current_completion_txs: self.current_completion_txs.clone(),
54 })]
55 }
56
57 fn is_authenticated(&self, _: &AppContext) -> bool {
58 true
59 }
60
61 fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
62 Task::ready(Ok(()))
63 }
64
65 fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
66 unimplemented!()
67 }
68
69 fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
70 Task::ready(Ok(()))
71 }
72}
73
74impl FakeLanguageModelProvider {
75 pub fn test_model(&self) -> FakeLanguageModel {
76 FakeLanguageModel {
77 current_completion_txs: self.current_completion_txs.clone(),
78 }
79 }
80}
81
82pub struct FakeLanguageModel {
83 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
84}
85
86impl FakeLanguageModel {
87 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
88 self.current_completion_txs
89 .lock()
90 .unwrap()
91 .keys()
92 .map(|k| serde_json::from_str(k).unwrap())
93 .collect()
94 }
95
96 pub fn completion_count(&self) -> usize {
97 self.current_completion_txs.lock().unwrap().len()
98 }
99
100 pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
101 let json = serde_json::to_string(request).unwrap();
102 self.current_completion_txs
103 .lock()
104 .unwrap()
105 .get(&json)
106 .unwrap()
107 .unbounded_send(chunk)
108 .unwrap();
109 }
110
111 pub fn send_last_completion_chunk(&self, chunk: String) {
112 self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
113 }
114
115 pub fn finish_completion(&self, request: &LanguageModelRequest) {
116 self.current_completion_txs
117 .lock()
118 .unwrap()
119 .remove(&serde_json::to_string(request).unwrap())
120 .unwrap();
121 }
122
123 pub fn finish_last_completion(&self) {
124 self.finish_completion(self.pending_completions().last().unwrap());
125 }
126}
127
128impl LanguageModel for FakeLanguageModel {
129 fn id(&self) -> LanguageModelId {
130 language_model_id()
131 }
132
133 fn name(&self) -> LanguageModelName {
134 language_model_name()
135 }
136
137 fn provider_id(&self) -> LanguageModelProviderId {
138 provider_id()
139 }
140
141 fn provider_name(&self) -> LanguageModelProviderName {
142 provider_name()
143 }
144
145 fn telemetry_id(&self) -> String {
146 "fake".to_string()
147 }
148
149 fn max_token_count(&self) -> usize {
150 1000000
151 }
152
153 fn count_tokens(
154 &self,
155 _: LanguageModelRequest,
156 _: &AppContext,
157 ) -> BoxFuture<'static, Result<usize>> {
158 futures::future::ready(Ok(0)).boxed()
159 }
160
161 fn stream_completion(
162 &self,
163 request: LanguageModelRequest,
164 _: &AsyncAppContext,
165 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
166 let (tx, rx) = mpsc::unbounded();
167 self.current_completion_txs
168 .lock()
169 .unwrap()
170 .insert(serde_json::to_string(&request).unwrap(), tx);
171 async move { Ok(rx.map(Ok).boxed()) }.boxed()
172 }
173}