1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
3 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
4 LanguageModelRequest,
5};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
9use gpui::{AnyView, AppContext, AsyncAppContext, Task};
10use http_client::Result;
11use std::{
12 future,
13 sync::{Arc, Mutex},
14};
15use ui::WindowContext;
16
17pub fn language_model_id() -> LanguageModelId {
18 LanguageModelId::from("fake".to_string())
19}
20
21pub fn language_model_name() -> LanguageModelName {
22 LanguageModelName::from("Fake".to_string())
23}
24
25pub fn provider_id() -> LanguageModelProviderId {
26 LanguageModelProviderId::from("fake".to_string())
27}
28
29pub fn provider_name() -> LanguageModelProviderName {
30 LanguageModelProviderName::from("Fake".to_string())
31}
32
33#[derive(Clone, Default)]
34pub struct FakeLanguageModelProvider {
35 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
36}
37
38impl LanguageModelProviderState for FakeLanguageModelProvider {
39 fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
40 None
41 }
42}
43
44impl LanguageModelProvider for FakeLanguageModelProvider {
45 fn id(&self) -> LanguageModelProviderId {
46 provider_id()
47 }
48
49 fn name(&self) -> LanguageModelProviderName {
50 provider_name()
51 }
52
53 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
54 vec![Arc::new(FakeLanguageModel {
55 current_completion_txs: self.current_completion_txs.clone(),
56 })]
57 }
58
59 fn is_authenticated(&self, _: &AppContext) -> bool {
60 true
61 }
62
63 fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
64 Task::ready(Ok(()))
65 }
66
67 fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
68 unimplemented!()
69 }
70
71 fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
72 Task::ready(Ok(()))
73 }
74}
75
76impl FakeLanguageModelProvider {
77 pub fn test_model(&self) -> FakeLanguageModel {
78 FakeLanguageModel {
79 current_completion_txs: self.current_completion_txs.clone(),
80 }
81 }
82}
83
84pub struct FakeLanguageModel {
85 current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
86}
87
88impl FakeLanguageModel {
89 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
90 self.current_completion_txs
91 .lock()
92 .unwrap()
93 .keys()
94 .map(|k| serde_json::from_str(k).unwrap())
95 .collect()
96 }
97
98 pub fn completion_count(&self) -> usize {
99 self.current_completion_txs.lock().unwrap().len()
100 }
101
102 pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
103 let json = serde_json::to_string(request).unwrap();
104 self.current_completion_txs
105 .lock()
106 .unwrap()
107 .get(&json)
108 .unwrap()
109 .unbounded_send(chunk)
110 .unwrap();
111 }
112
113 pub fn send_last_completion_chunk(&self, chunk: String) {
114 self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
115 }
116
117 pub fn finish_completion(&self, request: &LanguageModelRequest) {
118 self.current_completion_txs
119 .lock()
120 .unwrap()
121 .remove(&serde_json::to_string(request).unwrap())
122 .unwrap();
123 }
124
125 pub fn finish_last_completion(&self) {
126 self.finish_completion(self.pending_completions().last().unwrap());
127 }
128}
129
130impl LanguageModel for FakeLanguageModel {
131 fn id(&self) -> LanguageModelId {
132 language_model_id()
133 }
134
135 fn name(&self) -> LanguageModelName {
136 language_model_name()
137 }
138
139 fn provider_id(&self) -> LanguageModelProviderId {
140 provider_id()
141 }
142
143 fn provider_name(&self) -> LanguageModelProviderName {
144 provider_name()
145 }
146
147 fn telemetry_id(&self) -> String {
148 "fake".to_string()
149 }
150
151 fn max_token_count(&self) -> usize {
152 1000000
153 }
154
155 fn count_tokens(
156 &self,
157 _: LanguageModelRequest,
158 _: &AppContext,
159 ) -> BoxFuture<'static, Result<usize>> {
160 futures::future::ready(Ok(0)).boxed()
161 }
162
163 fn stream_completion(
164 &self,
165 request: LanguageModelRequest,
166 _: &AsyncAppContext,
167 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
168 let (tx, rx) = mpsc::unbounded();
169 self.current_completion_txs
170 .lock()
171 .unwrap()
172 .insert(serde_json::to_string(&request).unwrap(), tx);
173 async move { Ok(rx.map(Ok).boxed()) }.boxed()
174 }
175
176 fn use_any_tool(
177 &self,
178 _request: LanguageModelRequest,
179 _name: String,
180 _description: String,
181 _schema: serde_json::Value,
182 _cx: &AsyncAppContext,
183 ) -> BoxFuture<'static, Result<serde_json::Value>> {
184 future::ready(Err(anyhow!("not implemented"))).boxed()
185 }
186}