1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use util::maybe;
9
10pub fn init(cx: &mut App) {
11 let registry = cx.new(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
15struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
16
17impl Global for GlobalLanguageModelRegistry {}
18
19#[derive(Default)]
20pub struct LanguageModelRegistry {
21 default_model: Option<ConfiguredModel>,
22 default_fast_model: Option<ConfiguredModel>,
23 inline_assistant_model: Option<ConfiguredModel>,
24 commit_message_model: Option<ConfiguredModel>,
25 thread_summary_model: Option<ConfiguredModel>,
26 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
27 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
28}
29
30#[derive(Debug)]
31pub struct SelectedModel {
32 pub provider: LanguageModelProviderId,
33 pub model: LanguageModelId,
34}
35
36impl FromStr for SelectedModel {
37 type Err = String;
38
39 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
40 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
41 let parts: Vec<&str> = id.split('/').collect();
42 let [provider_id, model_id] = parts.as_slice() else {
43 return Err(format!(
44 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
45 id
46 ));
47 };
48
49 if provider_id.is_empty() || model_id.is_empty() {
50 return Err(format!("Provider and model ids can't be empty: `{}`", id));
51 }
52
53 Ok(SelectedModel {
54 provider: LanguageModelProviderId(provider_id.to_string().into()),
55 model: LanguageModelId(model_id.to_string().into()),
56 })
57 }
58}
59
60#[derive(Clone)]
61pub struct ConfiguredModel {
62 pub provider: Arc<dyn LanguageModelProvider>,
63 pub model: Arc<dyn LanguageModel>,
64}
65
66impl ConfiguredModel {
67 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
68 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
69 }
70
71 pub fn is_provided_by_zed(&self) -> bool {
72 self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
73 }
74}
75
76pub enum Event {
77 DefaultModelChanged,
78 InlineAssistantModelChanged,
79 CommitMessageModelChanged,
80 ThreadSummaryModelChanged,
81 ProviderStateChanged,
82 AddedProvider(LanguageModelProviderId),
83 RemovedProvider(LanguageModelProviderId),
84}
85
86impl EventEmitter<Event> for LanguageModelRegistry {}
87
88impl LanguageModelRegistry {
89 pub fn global(cx: &App) -> Entity<Self> {
90 cx.global::<GlobalLanguageModelRegistry>().0.clone()
91 }
92
93 pub fn read_global(cx: &App) -> &Self {
94 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
95 }
96
97 #[cfg(any(test, feature = "test-support"))]
98 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
99 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
100 let registry = cx.new(|cx| {
101 let mut registry = Self::default();
102 registry.register_provider(fake_provider.clone(), cx);
103 let model = fake_provider.provided_models(cx)[0].clone();
104 let configured_model = ConfiguredModel {
105 provider: Arc::new(fake_provider.clone()),
106 model,
107 };
108 registry.set_default_model(Some(configured_model), cx);
109 registry
110 });
111 cx.set_global(GlobalLanguageModelRegistry(registry));
112 fake_provider
113 }
114
115 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
116 &mut self,
117 provider: T,
118 cx: &mut Context<Self>,
119 ) {
120 let id = provider.id();
121
122 let subscription = provider.subscribe(cx, |_, cx| {
123 cx.emit(Event::ProviderStateChanged);
124 });
125 if let Some(subscription) = subscription {
126 subscription.detach();
127 }
128
129 self.providers.insert(id.clone(), Arc::new(provider));
130 cx.emit(Event::AddedProvider(id));
131 }
132
133 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
134 if self.providers.remove(&id).is_some() {
135 cx.emit(Event::RemovedProvider(id));
136 }
137 }
138
139 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
140 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
141 let mut providers = Vec::with_capacity(self.providers.len());
142 if let Some(provider) = self.providers.get(&zed_provider_id) {
143 providers.push(provider.clone());
144 }
145 providers.extend(self.providers.values().filter_map(|p| {
146 if p.id() != zed_provider_id {
147 Some(p.clone())
148 } else {
149 None
150 }
151 }));
152 providers
153 }
154
155 pub fn available_models<'a>(
156 &'a self,
157 cx: &'a App,
158 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
159 self.providers
160 .values()
161 .flat_map(|provider| provider.provided_models(cx))
162 }
163
164 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
165 self.providers.get(id).cloned()
166 }
167
168 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
169 let configured_model = model.and_then(|model| self.select_model(model, cx));
170 self.set_default_model(configured_model, cx);
171 }
172
173 pub fn select_inline_assistant_model(
174 &mut self,
175 model: Option<&SelectedModel>,
176 cx: &mut Context<Self>,
177 ) {
178 let configured_model = model.and_then(|model| self.select_model(model, cx));
179 self.set_inline_assistant_model(configured_model, cx);
180 }
181
182 pub fn select_commit_message_model(
183 &mut self,
184 model: Option<&SelectedModel>,
185 cx: &mut Context<Self>,
186 ) {
187 let configured_model = model.and_then(|model| self.select_model(model, cx));
188 self.set_commit_message_model(configured_model, cx);
189 }
190
191 pub fn select_thread_summary_model(
192 &mut self,
193 model: Option<&SelectedModel>,
194 cx: &mut Context<Self>,
195 ) {
196 let configured_model = model.and_then(|model| self.select_model(model, cx));
197 self.set_thread_summary_model(configured_model, cx);
198 }
199
200 /// Selects and sets the inline alternatives for language models based on
201 /// provider name and id.
202 pub fn select_inline_alternative_models(
203 &mut self,
204 alternatives: impl IntoIterator<Item = SelectedModel>,
205 cx: &mut Context<Self>,
206 ) {
207 self.inline_alternatives = alternatives
208 .into_iter()
209 .flat_map(|alternative| {
210 self.select_model(&alternative, cx)
211 .map(|configured_model| configured_model.model)
212 })
213 .collect::<Vec<_>>();
214 }
215
216 pub fn select_model(
217 &mut self,
218 selected_model: &SelectedModel,
219 cx: &mut Context<Self>,
220 ) -> Option<ConfiguredModel> {
221 let provider = self.provider(&selected_model.provider)?;
222 let model = provider
223 .provided_models(cx)
224 .iter()
225 .find(|model| model.id() == selected_model.model)?
226 .clone();
227 Some(ConfiguredModel { provider, model })
228 }
229
230 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
231 match (self.default_model.as_ref(), model.as_ref()) {
232 (Some(old), Some(new)) if old.is_same_as(new) => {}
233 (None, None) => {}
234 _ => cx.emit(Event::DefaultModelChanged),
235 }
236 self.default_fast_model = maybe!({
237 let provider = &model.as_ref()?.provider;
238 let fast_model = provider.default_fast_model(cx)?;
239 Some(ConfiguredModel {
240 provider: provider.clone(),
241 model: fast_model,
242 })
243 });
244 self.default_model = model;
245 }
246
247 pub fn set_inline_assistant_model(
248 &mut self,
249 model: Option<ConfiguredModel>,
250 cx: &mut Context<Self>,
251 ) {
252 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
253 (Some(old), Some(new)) if old.is_same_as(new) => {}
254 (None, None) => {}
255 _ => cx.emit(Event::InlineAssistantModelChanged),
256 }
257 self.inline_assistant_model = model;
258 }
259
260 pub fn set_commit_message_model(
261 &mut self,
262 model: Option<ConfiguredModel>,
263 cx: &mut Context<Self>,
264 ) {
265 match (self.commit_message_model.as_ref(), model.as_ref()) {
266 (Some(old), Some(new)) if old.is_same_as(new) => {}
267 (None, None) => {}
268 _ => cx.emit(Event::CommitMessageModelChanged),
269 }
270 self.commit_message_model = model;
271 }
272
273 pub fn set_thread_summary_model(
274 &mut self,
275 model: Option<ConfiguredModel>,
276 cx: &mut Context<Self>,
277 ) {
278 match (self.thread_summary_model.as_ref(), model.as_ref()) {
279 (Some(old), Some(new)) if old.is_same_as(new) => {}
280 (None, None) => {}
281 _ => cx.emit(Event::ThreadSummaryModelChanged),
282 }
283 self.thread_summary_model = model;
284 }
285
286 pub fn default_model(&self) -> Option<ConfiguredModel> {
287 #[cfg(debug_assertions)]
288 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
289 return None;
290 }
291
292 self.default_model.clone()
293 }
294
295 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
296 #[cfg(debug_assertions)]
297 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
298 return None;
299 }
300
301 self.inline_assistant_model
302 .clone()
303 .or_else(|| self.default_model.clone())
304 }
305
306 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
307 #[cfg(debug_assertions)]
308 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
309 return None;
310 }
311
312 self.commit_message_model
313 .clone()
314 .or_else(|| self.default_fast_model.clone())
315 .or_else(|| self.default_model.clone())
316 }
317
318 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
319 #[cfg(debug_assertions)]
320 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
321 return None;
322 }
323
324 self.thread_summary_model
325 .clone()
326 .or_else(|| self.default_fast_model.clone())
327 .or_else(|| self.default_model.clone())
328 }
329
330 /// The models to use for inline assists. Returns the union of the active
331 /// model and all inline alternatives. When there are multiple models, the
332 /// user will be able to cycle through results.
333 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
334 &self.inline_alternatives
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering then unregistering the fake provider is reflected in `providers()`.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx);
        });

        let providers = registry.read(cx).providers();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].id(), crate::fake_provider::provider_id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(crate::fake_provider::provider_id(), cx);
        });

        let providers = registry.read(cx).providers();
        assert!(providers.is_empty());
    }

    /// `SelectedModel` parses `provider/model` and rejects malformed identifiers.
    #[test]
    fn test_parse_selected_model() {
        let selected: SelectedModel = "zed.dev/claude-sonnet".parse().unwrap();
        assert_eq!(selected.provider, LanguageModelProviderId("zed.dev".into()));
        assert_eq!(selected.model, LanguageModelId("claude-sonnet".into()));

        for invalid in ["", "zed.dev", "/claude-sonnet", "zed.dev/"] {
            assert!(
                invalid.parse::<SelectedModel>().is_err(),
                "expected `{invalid}` to be rejected"
            );
        }
    }
}