1use anyhow::Result;
2use collections::HashMap;
3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
4use gpui::{AnyView, AppContext, Task};
5use std::sync::Arc;
6use ui::WindowContext;
7
8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
9
10#[derive(Clone, Default)]
11pub struct FakeCompletionProvider {
12 current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
13}
14
15impl FakeCompletionProvider {
16 pub fn setup_test(cx: &mut AppContext) -> Self {
17 use crate::CompletionProvider;
18 use parking_lot::RwLock;
19
20 let this = Self::default();
21 let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
22 cx.set_global(provider);
23 this
24 }
25
26 pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
27 self.current_completion_txs
28 .lock()
29 .keys()
30 .map(|k| serde_json::from_str(k).unwrap())
31 .collect()
32 }
33
34 pub fn completion_count(&self) -> usize {
35 self.current_completion_txs.lock().len()
36 }
37
38 pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
39 let json = serde_json::to_string(request).unwrap();
40 self.current_completion_txs
41 .lock()
42 .get(&json)
43 .unwrap()
44 .unbounded_send(chunk)
45 .unwrap();
46 }
47
48 pub fn finish_completion(&self, request: &LanguageModelRequest) {
49 self.current_completion_txs
50 .lock()
51 .remove(&serde_json::to_string(request).unwrap());
52 }
53}
54
55impl LanguageModelCompletionProvider for FakeCompletionProvider {
56 fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
57 vec![LanguageModel::default()]
58 }
59
60 fn settings_version(&self) -> usize {
61 0
62 }
63
64 fn is_authenticated(&self) -> bool {
65 true
66 }
67
68 fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
69 Task::ready(Ok(()))
70 }
71
72 fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
73 unimplemented!()
74 }
75
76 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
77 Task::ready(Ok(()))
78 }
79
80 fn model(&self) -> LanguageModel {
81 LanguageModel::default()
82 }
83
84 fn count_tokens(
85 &self,
86 _request: LanguageModelRequest,
87 _cx: &AppContext,
88 ) -> BoxFuture<'static, Result<usize>> {
89 futures::future::ready(Ok(0)).boxed()
90 }
91
92 fn complete(
93 &self,
94 _request: LanguageModelRequest,
95 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
96 let (tx, rx) = mpsc::unbounded();
97 self.current_completion_txs
98 .lock()
99 .insert(serde_json::to_string(&_request).unwrap(), tx);
100 async move { Ok(rx.map(Ok).boxed()) }.boxed()
101 }
102
103 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
104 self
105 }
106}