1use std::sync::{Arc, Mutex};
2
3use collections::HashMap;
4use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
5
6use crate::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
9};
10use gpui::{AnyView, AppContext, AsyncAppContext, Task};
11use http::Result;
12use ui::WindowContext;
13
14pub fn language_model_id() -> LanguageModelId {
15 LanguageModelId::from("fake".to_string())
16}
17
18pub fn language_model_name() -> LanguageModelName {
19 LanguageModelName::from("Fake".to_string())
20}
21
22pub fn provider_name() -> LanguageModelProviderName {
23 LanguageModelProviderName::from("fake".to_string())
24}
25
/// Fake [`LanguageModelProvider`] with no backing service.
///
/// The provider always reports itself as authenticated. It hands out a
/// [`FakeLanguageModel`] whose completions are pushed by hand.
#[derive(Clone, Default)]
pub struct FakeLanguageModelProvider {
    // Map from a JSON-serialized `LanguageModelRequest` to the sender feeding that
    // request's completion stream. The `Arc` is shared by every clone of the
    // provider and by every model it creates.
    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
30
31impl LanguageModelProviderState for FakeLanguageModelProvider {
32 fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
33 None
34 }
35}
36
37impl LanguageModelProvider for FakeLanguageModelProvider {
38 fn name(&self) -> LanguageModelProviderName {
39 provider_name()
40 }
41
42 fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
43 vec![Arc::new(FakeLanguageModel {
44 current_completion_txs: self.current_completion_txs.clone(),
45 })]
46 }
47
48 fn is_authenticated(&self, _: &AppContext) -> bool {
49 true
50 }
51
52 fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
53 Task::ready(Ok(()))
54 }
55
56 fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
57 unimplemented!()
58 }
59
60 fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
61 Task::ready(Ok(()))
62 }
63}
64
65impl FakeLanguageModelProvider {
66 pub fn test_model(&self) -> FakeLanguageModel {
67 FakeLanguageModel {
68 current_completion_txs: self.current_completion_txs.clone(),
69 }
70 }
71}
72
/// Fake [`LanguageModel`] whose completion streams are fed by hand through its
/// inherent helper methods instead of by a real backend.
pub struct FakeLanguageModel {
    // Map from a JSON-serialized `LanguageModelRequest` to the sender feeding that
    // request's completion stream. It is shared with the provider that created this model.
    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
76
77impl FakeLanguageModel {
78 pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
79 self.current_completion_txs
80 .lock()
81 .unwrap()
82 .keys()
83 .map(|k| serde_json::from_str(k).unwrap())
84 .collect()
85 }
86
87 pub fn completion_count(&self) -> usize {
88 self.current_completion_txs.lock().unwrap().len()
89 }
90
91 pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
92 let json = serde_json::to_string(request).unwrap();
93 self.current_completion_txs
94 .lock()
95 .unwrap()
96 .get(&json)
97 .unwrap()
98 .unbounded_send(chunk)
99 .unwrap();
100 }
101
102 pub fn send_last_completion_chunk(&self, chunk: String) {
103 self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
104 }
105
106 pub fn finish_completion(&self, request: &LanguageModelRequest) {
107 self.current_completion_txs
108 .lock()
109 .unwrap()
110 .remove(&serde_json::to_string(request).unwrap())
111 .unwrap();
112 }
113
114 pub fn finish_last_completion(&self) {
115 self.finish_completion(self.pending_completions().last().unwrap());
116 }
117}
118
119impl LanguageModel for FakeLanguageModel {
120 fn id(&self) -> LanguageModelId {
121 language_model_id()
122 }
123
124 fn name(&self) -> LanguageModelName {
125 language_model_name()
126 }
127
128 fn provider_name(&self) -> LanguageModelProviderName {
129 provider_name()
130 }
131
132 fn telemetry_id(&self) -> String {
133 "fake".to_string()
134 }
135
136 fn max_token_count(&self) -> usize {
137 1000000
138 }
139
140 fn count_tokens(
141 &self,
142 _: LanguageModelRequest,
143 _: &AppContext,
144 ) -> BoxFuture<'static, Result<usize>> {
145 futures::future::ready(Ok(0)).boxed()
146 }
147
148 fn stream_completion(
149 &self,
150 request: LanguageModelRequest,
151 _: &AsyncAppContext,
152 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
153 let (tx, rx) = mpsc::unbounded();
154 self.current_completion_txs
155 .lock()
156 .unwrap()
157 .insert(serde_json::to_string(&request).unwrap(), tx);
158 async move { Ok(rx.map(Ok).boxed()) }.boxed()
159 }
160}