1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
11pub fn init(cx: &mut App) {
12 let registry = cx.new(|_cx| LanguageModelRegistry::default());
13 cx.set_global(GlobalLanguageModelRegistry(registry));
14}
15
/// Newtype around the registry entity so it can be stored with
/// `set_global` and looked up again with `global::<GlobalLanguageModelRegistry>()`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl: opts the wrapper into gpui's global storage.
impl Global for GlobalLanguageModelRegistry {}
19
/// Reasons why a language model cannot be used right now, as produced by
/// `LanguageModelRegistry::configuration_error`.
///
/// The `#[error]` strings are the user-facing messages. `Debug` is
/// implemented by hand further down in this file rather than derived.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model was supplied and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model was supplied even though at least one provider is
    /// authenticated.
    #[error("LLM Provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The supplied model's provider reports that it is not authenticated.
    /// The message interpolates the provider's name.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
    /// The supplied model's provider is authenticated but reports that its
    /// terms must still be accepted. The message interpolates the provider's
    /// name.
    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
        .0.name().0)]
    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
}
32
33impl std::fmt::Debug for ConfigurationError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::NoProvider => write!(f, "NoProvider"),
37 Self::ModelNotFound => write!(f, "ModelNotFound"),
38 Self::ProviderNotAuthenticated(provider) => {
39 write!(f, "ProviderNotAuthenticated({})", provider.id())
40 }
41 Self::ProviderPendingTermsAcceptance(provider) => {
42 write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
43 }
44 }
45 }
46}
47
/// Registry of language model providers plus the models currently selected
/// for each purpose (default, inline assistant, commit message, thread
/// summary).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model returned by `default_model()` and used as the last fallback by
    /// the other model getters.
    default_model: Option<ConfiguredModel>,
    /// Derived in `set_default_model` from the default model's provider
    /// (`default_fast_model`); `None` when there is no default model or the
    /// provider offers no fast model.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline assistant model; the getter falls back to
    /// `default_model` when unset.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit message model; the getter falls back to
    /// `default_fast_model`, then `default_model`.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread summary model; the getter falls back to
    /// `default_fast_model`, then `default_model`.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models resolved by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
58
/// A model referenced purely by identifiers (provider id + model id), not yet
/// resolved against the registered providers. `LanguageModelRegistry::select_model`
/// turns it into a `ConfiguredModel` when both ids can be found.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
64
65impl FromStr for SelectedModel {
66 type Err = String;
67
68 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
69 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
70 let parts: Vec<&str> = id.split('/').collect();
71 let [provider_id, model_id] = parts.as_slice() else {
72 return Err(format!(
73 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
74 id
75 ));
76 };
77
78 if provider_id.is_empty() || model_id.is_empty() {
79 return Err(format!("Provider and model ids can't be empty: `{}`", id));
80 }
81
82 Ok(SelectedModel {
83 provider: LanguageModelProviderId(provider_id.to_string().into()),
84 model: LanguageModelId(model_id.to_string().into()),
85 })
86 }
87}
88
/// A model resolved against the registry: the model handle together with the
/// provider it came from. Cloning copies the two `Arc` handles.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The selected model.
    pub model: Arc<dyn LanguageModel>,
}
94
95impl ConfiguredModel {
96 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
97 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
98 }
99
100 pub fn is_provided_by_zed(&self) -> bool {
101 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102 }
103}
104
/// Events emitted by `LanguageModelRegistry`.
pub enum Event {
    /// Emitted by `set_default_model` when the default model actually changes.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when that model actually changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when that model actually changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when that model actually changes.
    ThreadSummaryModelChanged,
    /// Emitted from the callback that `register_provider` installs through the
    /// provider's `subscribe` method.
    ProviderStateChanged,
    /// NOTE(review): never emitted in this file; presumably emitted by code
    /// elsewhere when a provider's authentication changes — confirm.
    ProviderAuthUpdated,
    /// Emitted by `register_provider` with the id of the registered provider.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider` when a provider with the given id was
    /// actually removed.
    RemovedProvider(LanguageModelProviderId),
}

// Lets the registry entity emit `Event`s through its gpui context.
impl EventEmitter<Event> for LanguageModelRegistry {}
117
118impl LanguageModelRegistry {
119 pub fn global(cx: &App) -> Entity<Self> {
120 cx.global::<GlobalLanguageModelRegistry>().0.clone()
121 }
122
123 pub fn read_global(cx: &App) -> &Self {
124 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
125 }
126
127 #[cfg(any(test, feature = "test-support"))]
128 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
129 let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
130 let registry = cx.new(|cx| {
131 let mut registry = Self::default();
132 registry.register_provider(fake_provider.clone(), cx);
133 let model = fake_provider.provided_models(cx)[0].clone();
134 let configured_model = ConfiguredModel {
135 provider: Arc::new(fake_provider.clone()),
136 model,
137 };
138 registry.set_default_model(Some(configured_model), cx);
139 registry
140 });
141 cx.set_global(GlobalLanguageModelRegistry(registry));
142 fake_provider
143 }
144
145 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
146 &mut self,
147 provider: T,
148 cx: &mut Context<Self>,
149 ) {
150 let id = provider.id();
151
152 let subscription = provider.subscribe(cx, |_, cx| {
153 cx.emit(Event::ProviderStateChanged);
154 });
155 if let Some(subscription) = subscription {
156 subscription.detach();
157 }
158
159 self.providers.insert(id.clone(), Arc::new(provider));
160 cx.emit(Event::AddedProvider(id));
161 }
162
163 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
164 if self.providers.remove(&id).is_some() {
165 cx.emit(Event::RemovedProvider(id));
166 }
167 }
168
169 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
170 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
171 let mut providers = Vec::with_capacity(self.providers.len());
172 if let Some(provider) = self.providers.get(&zed_provider_id) {
173 providers.push(provider.clone());
174 }
175 providers.extend(self.providers.values().filter_map(|p| {
176 if p.id() != zed_provider_id {
177 Some(p.clone())
178 } else {
179 None
180 }
181 }));
182 providers
183 }
184
185 pub fn configuration_error(
186 &self,
187 model: Option<ConfiguredModel>,
188 cx: &App,
189 ) -> Option<ConfigurationError> {
190 let Some(model) = model else {
191 if !self.has_authenticated_provider(cx) {
192 return Some(ConfigurationError::NoProvider);
193 }
194 return Some(ConfigurationError::ModelNotFound);
195 };
196
197 if !model.provider.is_authenticated(cx) {
198 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
199 }
200
201 if model.provider.must_accept_terms(cx) {
202 return Some(ConfigurationError::ProviderPendingTermsAcceptance(
203 model.provider,
204 ));
205 }
206
207 None
208 }
209
210 /// Returns `true` if at least one provider that is authenticated.
211 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
212 self.providers.values().any(|p| p.is_authenticated(cx))
213 }
214
215 pub fn available_models<'a>(
216 &'a self,
217 cx: &'a App,
218 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
219 self.providers
220 .values()
221 .flat_map(|provider| provider.provided_models(cx))
222 }
223
224 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
225 self.providers.get(id).cloned()
226 }
227
228 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
229 let configured_model = model.and_then(|model| self.select_model(model, cx));
230 self.set_default_model(configured_model, cx);
231 }
232
233 pub fn select_inline_assistant_model(
234 &mut self,
235 model: Option<&SelectedModel>,
236 cx: &mut Context<Self>,
237 ) {
238 let configured_model = model.and_then(|model| self.select_model(model, cx));
239 self.set_inline_assistant_model(configured_model, cx);
240 }
241
242 pub fn select_commit_message_model(
243 &mut self,
244 model: Option<&SelectedModel>,
245 cx: &mut Context<Self>,
246 ) {
247 let configured_model = model.and_then(|model| self.select_model(model, cx));
248 self.set_commit_message_model(configured_model, cx);
249 }
250
251 pub fn select_thread_summary_model(
252 &mut self,
253 model: Option<&SelectedModel>,
254 cx: &mut Context<Self>,
255 ) {
256 let configured_model = model.and_then(|model| self.select_model(model, cx));
257 self.set_thread_summary_model(configured_model, cx);
258 }
259
260 /// Selects and sets the inline alternatives for language models based on
261 /// provider name and id.
262 pub fn select_inline_alternative_models(
263 &mut self,
264 alternatives: impl IntoIterator<Item = SelectedModel>,
265 cx: &mut Context<Self>,
266 ) {
267 self.inline_alternatives = alternatives
268 .into_iter()
269 .flat_map(|alternative| {
270 self.select_model(&alternative, cx)
271 .map(|configured_model| configured_model.model)
272 })
273 .collect::<Vec<_>>();
274 }
275
276 pub fn select_model(
277 &mut self,
278 selected_model: &SelectedModel,
279 cx: &mut Context<Self>,
280 ) -> Option<ConfiguredModel> {
281 let provider = self.provider(&selected_model.provider)?;
282 let model = provider
283 .provided_models(cx)
284 .iter()
285 .find(|model| model.id() == selected_model.model)?
286 .clone();
287 Some(ConfiguredModel { provider, model })
288 }
289
290 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
291 match (self.default_model.as_ref(), model.as_ref()) {
292 (Some(old), Some(new)) if old.is_same_as(new) => {}
293 (None, None) => {}
294 _ => cx.emit(Event::DefaultModelChanged),
295 }
296 self.default_fast_model = maybe!({
297 let provider = &model.as_ref()?.provider;
298 let fast_model = provider.default_fast_model(cx)?;
299 Some(ConfiguredModel {
300 provider: provider.clone(),
301 model: fast_model,
302 })
303 });
304 self.default_model = model;
305 }
306
307 pub fn set_inline_assistant_model(
308 &mut self,
309 model: Option<ConfiguredModel>,
310 cx: &mut Context<Self>,
311 ) {
312 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
313 (Some(old), Some(new)) if old.is_same_as(new) => {}
314 (None, None) => {}
315 _ => cx.emit(Event::InlineAssistantModelChanged),
316 }
317 self.inline_assistant_model = model;
318 }
319
320 pub fn set_commit_message_model(
321 &mut self,
322 model: Option<ConfiguredModel>,
323 cx: &mut Context<Self>,
324 ) {
325 match (self.commit_message_model.as_ref(), model.as_ref()) {
326 (Some(old), Some(new)) if old.is_same_as(new) => {}
327 (None, None) => {}
328 _ => cx.emit(Event::CommitMessageModelChanged),
329 }
330 self.commit_message_model = model;
331 }
332
333 pub fn set_thread_summary_model(
334 &mut self,
335 model: Option<ConfiguredModel>,
336 cx: &mut Context<Self>,
337 ) {
338 match (self.thread_summary_model.as_ref(), model.as_ref()) {
339 (Some(old), Some(new)) if old.is_same_as(new) => {}
340 (None, None) => {}
341 _ => cx.emit(Event::ThreadSummaryModelChanged),
342 }
343 self.thread_summary_model = model;
344 }
345
346 pub fn default_model(&self) -> Option<ConfiguredModel> {
347 #[cfg(debug_assertions)]
348 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
349 return None;
350 }
351
352 self.default_model.clone()
353 }
354
355 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
356 #[cfg(debug_assertions)]
357 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
358 return None;
359 }
360
361 self.inline_assistant_model
362 .clone()
363 .or_else(|| self.default_model.clone())
364 }
365
366 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
367 #[cfg(debug_assertions)]
368 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
369 return None;
370 }
371
372 self.commit_message_model
373 .clone()
374 .or_else(|| self.default_fast_model.clone())
375 .or_else(|| self.default_model.clone())
376 }
377
378 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
379 #[cfg(debug_assertions)]
380 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
381 return None;
382 }
383
384 self.thread_summary_model
385 .clone()
386 .or_else(|| self.default_fast_model.clone())
387 .or_else(|| self.default_model.clone())
388 }
389
390 /// The models to use for inline assists. Returns the union of the active
391 /// model and all inline alternatives. When there are multiple models, the
392 /// user will be able to cycle through results.
393 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
394 &self.inline_alternatives
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`, and
    /// unregistering it by id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| {
            registry.register_provider(fake.clone(), cx)
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(fake.id(), cx)
        });

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}