use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
    LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use ui::Context;

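/// Initializes the global [`LanguageModelCompletionProvider`] and stores it in
/// GPUI's global state.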
pub fn init(cx: &mut AppContext) {
    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}

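/// Newtype wrapper so the provider's [`Model`] handle can live in GPUI's
/// global state.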
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);

impl Global for GlobalLanguageModelCompletionProvider {}

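/// Routes completion requests to the active [`LanguageModel`], gating them
/// behind a semaphore so only a bounded number run concurrently.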
pub struct LanguageModelCompletionProvider {
    active_provider: Option<Arc<dyn LanguageModelProvider>>,
    active_model: Option<Arc<dyn LanguageModel>>,
    request_limiter: Arc<Semaphore>,
}

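/// Upper bound on completion requests allowed to run concurrently; additional
/// requests wait on the semaphore until a slot frees up.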
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

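/// A streaming completion response. Holds a semaphore permit for as long as
/// the stream is alive, so dropping the response releases a slot for the next
/// queued request.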
pub struct LanguageModelCompletionResponse {
    inner: BoxStream<'static, Result<String>>,
    _lock: SemaphoreGuardArc,
}

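// The stream implementation just forwards to the boxed inner stream; the
// semaphore permit in `_lock` is held until the whole response is dropped.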
impl futures::Stream for LanguageModelCompletionResponse {
    type Item = Result<String>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}

impl LanguageModelCompletionProvider {
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .clone()
    }

    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelCompletionProvider>()
            .0
            .read(cx)
    }

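    /// Installs a global provider preconfigured with the first model
    /// registered in the [`LanguageModelRegistry`], for use in tests.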
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) {
        let provider = cx.new_model(|cx| {
            let mut this = Self::new(cx);
            let available_model = LanguageModelRegistry::read_global(cx)
                .available_models(cx)
                .first()
                .unwrap()
                .clone();
            this.set_active_model(available_model, cx);
            this
        });
        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
    }

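    /// Creates a provider with no active model. Observes the global
    /// [`LanguageModelRegistry`] and re-notifies observers whenever the set of
    /// registered providers changes.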
    pub fn new(cx: &mut ModelContext<Self>) -> Self {
        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
            cx.notify();
        })
        .detach();

        Self {
            active_provider: None,
            active_model: None,
            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
        }
    }

    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        self.active_provider.clone()
    }

    pub fn set_active_provider(
        &mut self,
        provider_id: LanguageModelProviderId,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
        self.active_model = None;
        cx.notify();
    }

    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

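    /// Activates `model` (and its owning provider), asking the provider to
    /// load it eagerly. A no-op if the same model is already active.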
    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
        if self.active_model.as_ref().map_or(false, |m| {
            m.id() == model.id() && m.provider_id() == model.provider_id()
        }) {
            return;
        }

        self.active_provider =
            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
        self.active_model = Some(model.clone());

        if let Some(provider) = self.active_provider.as_ref() {
            provider.load_model(model, cx);
        }

        cx.notify();
    }

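    // Authentication and credential management delegate to the active
    // provider; with no active provider they succeed as no-ops.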
    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.active_provider
            .as_ref()
            .map_or(false, |provider| provider.is_authenticated(cx))
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.active_provider
            .as_ref()
            .map_or(Task::ready(Ok(())), |provider| {
                provider.reset_credentials(cx)
            })
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Option<BoxFuture<'static, Result<usize>>> {
        self.active_model()
            .map(|model| model.count_tokens(request, cx))
    }

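    /// Starts a streaming completion against the active model. The returned
    /// task resolves once a semaphore permit is acquired, so no more than
    /// [`MAX_CONCURRENT_COMPLETION_REQUESTS`] streams are live at a time; the
    /// permit is released when the response stream is dropped.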
    pub fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> Task<Result<LanguageModelCompletionResponse>> {
        if let Some(language_model) = self.active_model() {
            let rate_limiter = self.request_limiter.clone();
            cx.spawn(|cx| async move {
                // Wait for a free slot before issuing the request. The guard
                // travels with the response and is released when it's dropped.
                let lock = rate_limiter.acquire_arc().await;
                let response = language_model.stream_completion(request, &cx).await?;
                Ok(LanguageModelCompletionResponse {
                    inner: response,
                    _lock: lock,
                })
            })
        } else {
            Task::ready(Err(anyhow!("No active model set")))
        }
    }

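    /// Convenience wrapper around [`Self::stream_completion`] that drains the
    /// stream and returns the concatenated completion text.
    ///
    /// A minimal usage sketch (assumes an initialized global provider with an
    /// active model and a default request; not verified against a live
    /// backend):
    ///
    /// ```ignore
    /// let completion = LanguageModelCompletionProvider::read_global(cx)
    ///     .complete(LanguageModelRequest::default(), cx);
    /// ```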
    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
        let response = self.stream_completion(request, cx);
        cx.foreground_executor().spawn(async move {
            let mut chunks = response.await?;
            let mut completion = String::new();
            while let Some(chunk) = chunks.next().await {
                completion.push_str(&chunk?);
            }
            Ok(completion)
        })
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use gpui::AppContext;
    use settings::SettingsStore;
    use ui::Context;

    use crate::{
        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
    };

    use language_model::LanguageModelRegistry;

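    // Verifies that the semaphore caps in-flight completions at
    // MAX_CONCURRENT_COMPLETION_REQUESTS and that finishing a request lets a
    // queued one proceed.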
    #[gpui::test]
    fn test_rate_limiting(cx: &mut AppContext) {
        SettingsStore::test(cx);
        let fake_provider = LanguageModelRegistry::test(cx);

        let model = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .first()
            .cloned()
            .unwrap();

        let provider = cx.new_model(|cx| {
            let mut provider = LanguageModelCompletionProvider::new(cx);
            provider.set_active_model(model.clone(), cx);
            provider
        });

        let fake_model = fake_provider.test_model();

        // Enqueue twice as many requests as the concurrency limit allows.
        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
            let response = provider.read(cx).stream_completion(
                LanguageModelRequest {
                    temperature: i as f32 / 10.0,
                    ..Default::default()
                },
                cx,
            );
            cx.background_executor()
                .spawn(async move {
                    let mut stream = response.await.unwrap();
                    while let Some(message) = stream.next().await {
                        message.unwrap();
                    }
                })
                .detach();
        }
        cx.background_executor().run_until_parked();
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Take the first in-flight completion request and mark it as completed.
        let completion = fake_model
            .pending_completions()
            .into_iter()
            .next()
            .unwrap();
        fake_model.finish_completion(&completion);

        // Ensure that the number of in-flight completion requests is reduced.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        cx.background_executor().run_until_parked();

        // Ensure that another completion request was allowed to acquire the lock.
        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS
        );

        // Mark all in-flight completion requests as finished.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        assert_eq!(fake_model.completion_count(), 0);

        // Wait until the queued background tasks acquire the lock again.
        cx.background_executor().run_until_parked();

        assert_eq!(
            fake_model.completion_count(),
            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
        );

        // Finish all remaining completion requests.
        for request in fake_model.pending_completions() {
            fake_model.finish_completion(&request);
        }

        cx.background_executor().run_until_parked();

        assert_eq!(fake_model.completion_count(), 0);
    }
}