1mod anthropic;
2mod cloud;
3#[cfg(any(test, feature = "test-support"))]
4mod fake;
5mod ollama;
6mod open_ai;
7
8pub use anthropic::*;
9use anyhow::Result;
10use client::Client;
11pub use cloud::*;
12#[cfg(any(test, feature = "test-support"))]
13pub use fake::*;
14use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
15use gpui::{AnyView, AppContext, Task, WindowContext};
16use language_model::{LanguageModel, LanguageModelRequest};
17pub use ollama::*;
18pub use open_ai::*;
19use parking_lot::RwLock;
20use smol::lock::{Semaphore, SemaphoreGuardArc};
21use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
22
/// A stream of completion text chunks that also holds a rate-limiter permit.
///
/// The semaphore guard is kept alive for the lifetime of the response, so the
/// slot in `CompletionProvider::request_limiter` is released only when this
/// response is dropped — that is how in-flight requests stay bounded.
pub struct CompletionResponse {
    // The underlying chunk stream produced by the active provider.
    inner: BoxStream<'static, Result<String>>,
    // Never read; held only so the permit is returned on drop.
    _lock: SemaphoreGuardArc,
}
27
/// `CompletionResponse` is a thin stream wrapper: polling simply forwards to
/// the inner boxed stream while the semaphore guard rides along untouched.
impl futures::Stream for CompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // BoxStream is Unpin, so re-pinning the field with Pin::new is sound.
        Pin::new(&mut self.inner).poll_next(cx)
    }
}
38
/// Backend-agnostic interface implemented by each completion backend
/// (Anthropic, cloud, Ollama, OpenAI, and the test fake — see the `mod`
/// declarations at the top of this file).
///
/// Implementations are shared behind `Arc<RwLock<dyn ..>>` by
/// `CompletionProvider`, hence the `Send + Sync` bound.
pub trait LanguageModelCompletionProvider: Send + Sync {
    /// Models this backend can currently serve.
    fn available_models(&self) -> Vec<LanguageModel>;
    /// Version of the settings this provider was configured from.
    fn settings_version(&self) -> usize;
    /// Whether the backend currently holds valid credentials.
    fn is_authenticated(&self) -> bool;
    /// Kicks off authentication; resolves when it completes or fails.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    /// View shown to the user when authentication is required.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    /// Clears any stored credentials for this backend.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    /// The currently selected model.
    fn model(&self) -> LanguageModel;
    /// Counts the tokens `request` would consume for the current model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    /// Starts a completion, yielding the response as a stream of text chunks.
    /// Note: rate limiting is applied by `CompletionProvider`, not here.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Escape hatch for downcasting to a concrete provider type
    /// (see `CompletionProvider::update_current_as`).
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
59
/// Maximum number of completion requests allowed in flight at once; enforced
/// via `CompletionProvider::request_limiter`.
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
61
/// Facade over the currently configured `LanguageModelCompletionProvider`,
/// adding request rate limiting and the ability to swap the backend at
/// runtime. Stored as a gpui global.
pub struct CompletionProvider {
    // The active backend, swappable via `update_provider`.
    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
    // Used to construct a replacement provider in `update_provider`;
    // `None` in tests (see `test_rate_limiting` below).
    client: Option<Arc<Client>>,
    // Caps concurrent completions at MAX_CONCURRENT_COMPLETION_REQUESTS.
    request_limiter: Arc<Semaphore>,
}
67
68impl CompletionProvider {
69 pub fn new(
70 provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
71 client: Option<Arc<Client>>,
72 ) -> Self {
73 Self {
74 provider,
75 client,
76 request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
77 }
78 }
79
80 pub fn available_models(&self) -> Vec<LanguageModel> {
81 self.provider.read().available_models()
82 }
83
84 pub fn settings_version(&self) -> usize {
85 self.provider.read().settings_version()
86 }
87
88 pub fn is_authenticated(&self) -> bool {
89 self.provider.read().is_authenticated()
90 }
91
92 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
93 self.provider.read().authenticate(cx)
94 }
95
96 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
97 self.provider.read().authentication_prompt(cx)
98 }
99
100 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
101 self.provider.read().reset_credentials(cx)
102 }
103
104 pub fn model(&self) -> LanguageModel {
105 self.provider.read().model()
106 }
107
108 pub fn count_tokens(
109 &self,
110 request: LanguageModelRequest,
111 cx: &AppContext,
112 ) -> BoxFuture<'static, Result<usize>> {
113 self.provider.read().count_tokens(request, cx)
114 }
115
116 pub fn stream_completion(
117 &self,
118 request: LanguageModelRequest,
119 cx: &AppContext,
120 ) -> Task<Result<CompletionResponse>> {
121 let rate_limiter = self.request_limiter.clone();
122 let provider = self.provider.clone();
123 cx.foreground_executor().spawn(async move {
124 let lock = rate_limiter.acquire_arc().await;
125 let response = provider.read().stream_completion(request);
126 let response = response.await?;
127 Ok(CompletionResponse {
128 inner: response,
129 _lock: lock,
130 })
131 })
132 }
133
134 pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
135 let response = self.stream_completion(request, cx);
136 cx.foreground_executor().spawn(async move {
137 let mut chunks = response.await?;
138 let mut completion = String::new();
139 while let Some(chunk) = chunks.next().await {
140 let chunk = chunk?;
141 completion.push_str(&chunk);
142 }
143 Ok(completion)
144 })
145 }
146
147 pub fn update_provider(
148 &mut self,
149 get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
150 ) {
151 if let Some(client) = &self.client {
152 self.provider = get_provider(Arc::clone(client));
153 } else {
154 log::warn!("completion provider cannot be updated because its client was not set");
155 }
156 }
157}
158
/// Marker impl allowing a single `CompletionProvider` to be stored in (and
/// fetched from) the gpui app context as a global.
impl gpui::Global for CompletionProvider {}
160
161impl CompletionProvider {
162 pub fn global(cx: &AppContext) -> &Self {
163 cx.global::<Self>()
164 }
165
166 pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
167 &mut self,
168 update: impl FnOnce(&mut T) -> R,
169 ) -> Option<R> {
170 let mut provider = self.provider.write();
171 if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
172 Some(update(provider))
173 } else {
174 None
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::AppContext;
    use parking_lot::RwLock;
    use settings::SettingsStore;
    use smol::stream::StreamExt;

    use crate::{
        CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
        MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    /// Verifies the semaphore-based rate limiting in
    /// `CompletionProvider::stream_completion`: at most
    /// MAX_CONCURRENT_COMPLETION_REQUESTS completions run at once, and each
    /// finished completion frees a slot for a queued one.
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = FakeCompletionProvider::setup_test(cx);

        // `client` is None: `update_provider` is not exercised here.
        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);

        // Enqueue twice as many requests as the limiter allows; the
        // temperature just makes each request distinguishable.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    // Drain the stream; dropping it afterwards releases the
                    // rate-limiter permit held by the CompletionResponse.
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();

        // Only the first batch should be in flight.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Get the first completion request that is in flight and mark it as completed.
        let completion = fake_provider
            .pending_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_provider.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all completion requests as finished that are in flight.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        assert_eq!(fake_provider.completion_count(), 0);

        // Wait until the background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        // 5 of the 8 requests have finished, so only 3 remain to start.
        assert_eq!(
            fake_provider.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_provider.pending_completions() {
            fake_provider.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_provider.completion_count(), 0);
    }
}