1use anyhow::Result;
2use collections::HashMap;
3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
4use gpui::{AnyView, AppContext, Task};
5use std::sync::Arc;
6use ui::WindowContext;
7
8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
9
10#[derive(Clone, Default)]
11pub struct FakeCompletionProvider {
12 current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
13}
14
15impl FakeCompletionProvider {
16 #[cfg(test)]
17 pub fn setup_test(cx: &mut AppContext) -> Self {
18 use crate::CompletionProvider;
19 use parking_lot::RwLock;
20
21 let this = Self::default();
22 let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
23 cx.set_global(provider);
24 this
25 }
26
27 pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
28 self.current_completion_txs
29 .lock()
30 .keys()
31 .map(|k| serde_json::from_str(k).unwrap())
32 .collect()
33 }
34
35 pub fn completion_count(&self) -> usize {
36 self.current_completion_txs.lock().len()
37 }
38
39 pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
40 let json = serde_json::to_string(request).unwrap();
41 self.current_completion_txs
42 .lock()
43 .get(&json)
44 .unwrap()
45 .unbounded_send(chunk)
46 .unwrap();
47 }
48
49 pub fn finish_completion(&self, request: &LanguageModelRequest) {
50 self.current_completion_txs
51 .lock()
52 .remove(&serde_json::to_string(request).unwrap());
53 }
54}
55
56impl LanguageModelCompletionProvider for FakeCompletionProvider {
57 fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
58 vec![LanguageModel::default()]
59 }
60
61 fn settings_version(&self) -> usize {
62 0
63 }
64
65 fn is_authenticated(&self) -> bool {
66 true
67 }
68
69 fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
70 Task::ready(Ok(()))
71 }
72
73 fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
74 unimplemented!()
75 }
76
77 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
78 Task::ready(Ok(()))
79 }
80
81 fn model(&self) -> LanguageModel {
82 LanguageModel::default()
83 }
84
85 fn count_tokens(
86 &self,
87 _request: LanguageModelRequest,
88 _cx: &AppContext,
89 ) -> BoxFuture<'static, Result<usize>> {
90 futures::future::ready(Ok(0)).boxed()
91 }
92
93 fn complete(
94 &self,
95 _request: LanguageModelRequest,
96 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
97 let (tx, rx) = mpsc::unbounded();
98 self.current_completion_txs
99 .lock()
100 .insert(serde_json::to_string(&_request).unwrap(), tx);
101 async move { Ok(rx.map(Ok).boxed()) }.boxed()
102 }
103
104 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
105 self
106 }
107}