1use anyhow::{anyhow, Result};
2use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
3use gpui::{AppContext, Global, Model, ModelContext, Task};
4use language_model::{
5 LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
6 LanguageModelRequest, LanguageModelTool,
7};
8use smol::{
9 future::FutureExt,
10 lock::{Semaphore, SemaphoreGuardArc},
11};
12use std::{future, pin::Pin, sync::Arc, task::Poll};
13use ui::Context;
14
15pub fn init(cx: &mut AppContext) {
16 let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
17 cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
18}
19
/// Newtype holding the app-global handle to the completion provider model.
/// Installed once by [`init`] (or [`LanguageModelCompletionProvider::test`]).
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}
23
/// Coordinates language-model completion requests: tracks the currently
/// selected provider/model pair and throttles concurrent requests.
pub struct LanguageModelCompletionProvider {
    // Provider backing `active_model`; kept in sync by `set_active_provider`
    // and `set_active_model`.
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    // The model completion requests are routed to; `None` until selected.
    active_model: Option<Arc<dyn LanguageModel>>,
    // Hands out one permit per in-flight completion request; permits are
    // carried by `LanguageModelCompletionResponse` and released on drop.
    request_limiter: Arc<Semaphore>,
}

// Upper bound on simultaneously in-flight completion requests enforced by
// `request_limiter` (verified by `tests::test_rate_limiting`).
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
31
/// A stream of completion text chunks, bundled with the rate-limiter permit
/// acquired for the request. Dropping the response releases the permit,
/// allowing a queued request to proceed.
pub struct LanguageModelCompletionResponse {
    inner: BoxStream<'static, Result<String>>,
    // Held only for its `Drop` impl: returns a permit to `request_limiter`.
    _lock: SemaphoreGuardArc,
}
36
37impl futures::Stream for LanguageModelCompletionResponse {
38 type Item = Result<String>;
39
40 fn poll_next(
41 mut self: Pin<&mut Self>,
42 cx: &mut std::task::Context<'_>,
43 ) -> Poll<Option<Self::Item>> {
44 Pin::new(&mut self.inner).poll_next(cx)
45 }
46}
47
impl LanguageModelCompletionProvider {
    /// Returns a clone of the global provider handle.
    ///
    /// Panics if [`init`] (or [`Self::test`]) has not been called yet.
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .clone()
    }

    /// Reads the global provider's state without cloning the model handle.
    ///
    /// Panics if the global has not been set.
    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .read(cx)
    }

    /// Test-only setup: installs a global provider wired to the first model
    /// in the [`LanguageModelRegistry`]. Panics if the registry is empty.
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) {
        let provider = cx.new_model(|cx| {
            let mut this = Self::new(cx);
            let available_model = LanguageModelRegistry::read_global(cx)
                .available_models(cx)
                .first()
                .unwrap()
                .clone();
            this.set_active_model(available_model, cx);
            this
        });
        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
    }

    /// Creates a provider with no active model and a fresh request limiter.
    pub fn new(cx: &mut ModelContext<Self>) -> Self {
        // Re-notify our own observers whenever the registry changes, so UI
        // observing this model refreshes when providers/models come and go.
        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
            cx.notify();
        })
        .detach();

        Self {
            active_provider: None,
            active_model: None,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    /// The currently selected provider, if any.
    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        self.active_provider.clone()
    }

    /// Switches to the provider registered under `provider_id` (or `None` if
    /// it isn't registered) and clears the active model, since the previous
    /// model belonged to the old provider.
    pub fn set_active_provider(
        &mut self,
        provider_id: LanguageModelProviderId,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
        self.active_model = None;
        cx.notify();
    }

    /// The currently selected model, if any.
    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

    /// Selects `model` (and its owning provider) as active, asking the
    /// provider to load the model eagerly.
    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
        // No-op if the same (provider, model) pair is already active; avoids
        // a redundant `load_model` and observer notification.
        if self.active_model.as_ref().map_or(false, |m| {
            m.id() == model.id() && m.provider_id() == model.provider_id()
        }) {
            return;
        }

        self.active_provider =
            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
        self.active_model = Some(model.clone());

        if let Some(provider) = self.active_provider.as_ref() {
            provider.load_model(model, cx);
        }

        cx.notify();
    }

    /// Whether the active provider reports itself as authenticated.
    /// Returns `false` when no provider is selected.
    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.active_provider
            .as_ref()
            .map_or(false, |provider| provider.is_authenticated(cx))
    }

    /// Starts authentication with the active provider. With no provider
    /// selected this trivially succeeds.
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
    }

    /// Clears stored credentials on the active provider. With no provider
    /// selected this trivially succeeds.
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| {
                provider.reset_credentials(cx)
            })
    }

    /// Counts the tokens `request` would consume on the active model; errors
    /// immediately if no model is selected.
    ///
    /// NOTE(review): this error says "no active model" while the methods
    /// below say "No active model set" — consider unifying the wording.
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        if let Some(model) = self.active_model() {
            model.count_tokens(request, cx)
        } else {
            future::ready(Err(anyhow!("no active model"))).boxed()
        }
    }

    /// Sends `request` to the active model, yielding a chunk stream.
    ///
    /// A semaphore permit is acquired *before* the request is sent, capping
    /// in-flight requests at `MAX_CONCURRENT_COMPLETION_REQUESTS`. The permit
    /// is stored in the returned response, so it is released only when the
    /// caller drops the stream.
    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<LanguageModelCompletionResponse>> {
        if let Some(language_model) = self.active_model() {
            let rate_limiter = self.request_limiter.clone();
            cx.spawn(|cx| async move {
                let lock = rate_limiter.acquire_arc().await;
                let response = language_model.stream_completion(request, &cx).await?;
                Ok(LanguageModelCompletionResponse {
                    inner: response,
                    _lock: lock,
                })
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

    /// Convenience over [`Self::stream_completion`]: drains the stream and
    /// concatenates all chunks into a single `String`.
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                completion.push_str(&chunk);
            }
            Ok(completion)
        })
    }

    /// Asks the active model to produce a value of tool type `T`: sends `T`'s
    /// JSON schema with the request and deserializes the model's reply.
    /// Errors immediately if no model is selected.
    pub fn use_tool<T: LanguageModelTool>(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<T>> {
        if let Some(language_model) = self.active_model() {
            cx.spawn(|cx| async move {
                let schema = schemars::schema_for!(T);
                // Serializing a schemars-generated schema should not fail; a
                // failure here would indicate a bug rather than bad input.
                let schema_json = serde_json::to_value(&schema).unwrap();
                let request =
                    language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
                let response = request.await?;
                Ok(serde_json::from_value(response)?)
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

    /// Telemetry identifier of the active model, if one is selected.
    pub fn active_model_telemetry_id(&self) -> Option<String> {
        self.active_model.as_ref().map(|m| m.telemetry_id())
    }
}
213
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

    /// Verifies the semaphore in `stream_completion`: at most
    /// `MAX_CONCURRENT_COMPLETION_REQUESTS` requests reach the model at once,
    /// and finishing a request lets a queued one acquire the freed permit.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue some requests — twice the limit, so half must queue on the
        // semaphore. Each gets a distinct temperature to distinguish them.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        // Let spawned tasks run until everything is blocked; only
        // `MAX_CONCURRENT_COMPLETION_REQUESTS` permits exist.
        cx.background_executor().run_until_parked();
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_model.pending_completions().into_iter().next().unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        // (The fake model updates its count synchronously; the queued task has
        // not run yet, so the freed permit is still unclaimed here.)
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // The remaining queued requests (the second half minus the one that
        // already ran) are now in flight.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}