1use anyhow::Result;
2use collections::HashMap;
3use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
4use gpui::{AnyView, AppContext, Task};
5use std::sync::Arc;
6use ui::WindowContext;
7
8use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
9
10#[derive(Clone, Default)]
11pub struct FakeCompletionProvider {
12 current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
13}
14
15impl FakeCompletionProvider {
16 pub fn setup_test(cx: &mut AppContext) -> Self {
17 use crate::CompletionProvider;
18 use parking_lot::RwLock;
19
20 let this = Self::default();
21 let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
22 cx.set_global(provider);
23 this
24 }
25
26 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
27 self.current_completion_txs
28 .lock()
29 .keys()
30 .map(|k| serde_json::from_str(k).unwrap())
31 .collect()
32 }
33
34 pub fn completion_count(&self) -> usize {
35 self.current_completion_txs.lock().len()
36 }
37
38 pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
39 let json = serde_json::to_string(request).unwrap();
40 self.current_completion_txs
41 .lock()
42 .get(&json)
43 .unwrap()
44 .unbounded_send(chunk)
45 .unwrap();
46 }
47
48 pub fn send_last_completion_chunk(&self, chunk: String) {
49 self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
50 }
51
52 pub fn finish_completion(&self, request: &LanguageModelRequest) {
53 self.current_completion_txs
54 .lock()
55 .remove(&serde_json::to_string(request).unwrap())
56 .unwrap();
57 }
58
59 pub fn finish_last_completion(&self) {
60 self.finish_completion(self.pending_completions().last().unwrap());
61 }
62}
63
64impl LanguageModelCompletionProvider for FakeCompletionProvider {
65 fn available_models(&self) -> Vec<LanguageModel> {
66 vec![LanguageModel::default()]
67 }
68
69 fn settings_version(&self) -> usize {
70 0
71 }
72
73 fn is_authenticated(&self) -> bool {
74 true
75 }
76
77 fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
78 Task::ready(Ok(()))
79 }
80
81 fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
82 unimplemented!()
83 }
84
85 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
86 Task::ready(Ok(()))
87 }
88
89 fn model(&self) -> LanguageModel {
90 LanguageModel::default()
91 }
92
93 fn count_tokens(
94 &self,
95 _request: LanguageModelRequest,
96 _cx: &AppContext,
97 ) -> BoxFuture<'static, Result<usize>> {
98 futures::future::ready(Ok(0)).boxed()
99 }
100
101 fn stream_completion(
102 &self,
103 _request: LanguageModelRequest,
104 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
105 let (tx, rx) = mpsc::unbounded();
106 self.current_completion_txs
107 .lock()
108 .insert(serde_json::to_string(&_request).unwrap(), tx);
109 async move { Ok(rx.map(Ok).boxed()) }.boxed()
110 }
111
112 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
113 self
114 }
115}