use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
    LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use ui::Context;

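/// Creates the [`LanguageModelCompletionProvider`] model and registers it as a
/// GPUI global so it can be retrieved from anywhere in the app.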
pub fn init(cx: &mut AppContext) {
    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}

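/// Newtype wrapper that stores the completion provider model as a GPUI global.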
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}

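/// Routes completion requests to the currently selected provider and model,
/// limiting how many requests may be in flight at once.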
pub struct LanguageModelCompletionProvider {
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    active_model: Option<Arc<dyn LanguageModel>>,
    request_limiter: Arc<Semaphore>,
}

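/// The maximum number of completion requests that may run concurrently;
/// additional requests wait on the semaphore until a slot frees up.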
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

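/// A streaming completion response. Holding this value keeps a semaphore
/// permit alive, so a concurrency slot stays occupied until the response is
/// dropped.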
pub struct LanguageModelCompletionResponse {
    pub inner: BoxStream<'static, Result<String>>,
    _lock: SemaphoreGuardArc,
}

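// The response is itself a stream: polling simply delegates to the inner
// stream while the semaphore guard stays held.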
impl futures::Stream for LanguageModelCompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}

impl LanguageModelCompletionProvider {
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .clone()
    }

    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .read(cx)
    }

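    /// Test-only helper that installs a provider whose active model is the
    /// first model in the [`LanguageModelRegistry`]. Panics if the registry
    /// has no models.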
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) {
        let provider = cx.new_model(|cx| {
            let mut this = Self::new(cx);
            let available_model = LanguageModelRegistry::read_global(cx)
                .available_models(cx)
                .first()
                .unwrap()
                .clone();
            this.set_active_model(available_model, cx);
            this
        });
        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
    }

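    /// Creates a provider with no active model, re-notifying observers
    /// whenever the global [`LanguageModelRegistry`] changes.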
    pub fn new(cx: &mut ModelContext<Self>) -> Self {
        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
            cx.notify();
        })
        .detach();

        Self {
            active_provider: None,
            active_model: None,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        self.active_provider.clone()
    }

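    /// Switches to the named provider and clears the active model.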
    pub fn set_active_provider(
        &mut self,
        provider_name: LanguageModelProviderName,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
        self.active_model = None;
        cx.notify();
    }

    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

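    /// Sets the active model, also switching to the provider that owns it.
    /// No-op if the given model is already active, avoiding a spurious notify.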
    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
        if self.active_model.as_ref().map_or(false, |m| {
            m.id() == model.id() && m.provider_name() == model.provider_name()
        }) {
            return;
        }

        self.active_provider =
            LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
        self.active_model = Some(model);
        cx.notify();
    }

    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.active_provider
            .as_ref()
            .map_or(false, |provider| provider.is_authenticated(cx))
    }

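    /// Authenticates with the active provider. Resolves to `Ok(())` when no
    /// provider is active, so callers don't need to special-case that state.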
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| {
                provider.reset_credentials(cx)
            })
    }

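    /// Counts the tokens in `request` using the active model's tokenizer,
    /// erroring if no model is set.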
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        if let Some(model) = self.active_model() {
            model.count_tokens(request, cx)
        } else {
            std::future::ready(Err(anyhow!("No active model set"))).boxed()
        }
    }

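    /// Streams a completion from the active model. The returned task resolves
    /// once a concurrency slot has been acquired and the model has started
    /// responding; the slot is released when the response is dropped.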
    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<LanguageModelCompletionResponse>> {
        if let Some(language_model) = self.active_model() {
            let rate_limiter = self.request_limiter.clone();
            cx.spawn(|cx| async move {
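                // Wait for one of the MAX_CONCURRENT_COMPLETION_REQUESTS
                // permits; the guard travels with the response, so the permit
                // is held for the response's entire lifetime.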
                let lock = rate_limiter.acquire_arc().await;
                let response = language_model.stream_completion(request, &cx).await?;
                Ok(LanguageModelCompletionResponse {
                    inner: response,
                    _lock: lock,
                })
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

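    /// Convenience wrapper around [`Self::stream_completion`] that collects
    /// the streamed chunks into a single string.
    ///
    /// A minimal usage sketch, assuming the caller has built a `request` and
    /// runs inside an async context (marked `ignore` since it needs a live
    /// app):
    ///
    /// ```ignore
    /// let task = LanguageModelCompletionProvider::read_global(cx).complete(request, cx);
    /// let text: String = task.await?;
    /// ```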
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                completion.push_str(&chunk);
            }
            Ok(completion)
        })
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue twice as many requests as the limiter allows.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Take the first in-flight completion request and mark it as completed.
        let completion = fake_model.pending_completions().into_iter().next().unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire a permit.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the remaining background tasks acquire permits.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}