1use crate::{
2 provider::{
3 anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
4 copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
5 ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
6 },
7 LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
8};
9use client::Client;
10use collections::BTreeMap;
11use gpui::{AppContext, Global, Model, ModelContext};
12use std::sync::Arc;
13use ui::Context;
14
15pub fn init(client: Arc<Client>, cx: &mut AppContext) {
16 let registry = cx.new_model(|cx| {
17 let mut registry = LanguageModelRegistry::default();
18 register_language_model_providers(&mut registry, client, cx);
19 registry
20 });
21 cx.set_global(GlobalLanguageModelRegistry(registry));
22}
23
24fn register_language_model_providers(
25 registry: &mut LanguageModelRegistry,
26 client: Arc<Client>,
27 cx: &mut ModelContext<LanguageModelRegistry>,
28) {
29 use feature_flags::FeatureFlagAppExt;
30
31 registry.register_provider(
32 AnthropicLanguageModelProvider::new(client.http_client(), cx),
33 cx,
34 );
35 registry.register_provider(
36 OpenAiLanguageModelProvider::new(client.http_client(), cx),
37 cx,
38 );
39 registry.register_provider(
40 OllamaLanguageModelProvider::new(client.http_client(), cx),
41 cx,
42 );
43 registry.register_provider(
44 GoogleLanguageModelProvider::new(client.http_client(), cx),
45 cx,
46 );
47 registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
48
49 cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
50 let client = client.clone();
51 LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
52 if enabled {
53 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
54 } else {
55 registry.unregister_provider(
56 &LanguageModelProviderId::from(
57 crate::provider::cloud::PROVIDER_NAME.to_string(),
58 ),
59 cx,
60 );
61 }
62 });
63 })
64 .detach();
65}
66
67struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
68
69impl Global for GlobalLanguageModelRegistry {}
70
71#[derive(Default)]
72pub struct LanguageModelRegistry {
73 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
74}
75
76impl LanguageModelRegistry {
77 pub fn global(cx: &AppContext) -> Model<Self> {
78 cx.global::<GlobalLanguageModelRegistry>().0.clone()
79 }
80
81 pub fn read_global(cx: &AppContext) -> &Self {
82 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
83 }
84
85 #[cfg(any(test, feature = "test-support"))]
86 pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
87 let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
88 let registry = cx.new_model(|cx| {
89 let mut registry = Self::default();
90 registry.register_provider(fake_provider.clone(), cx);
91 registry
92 });
93 cx.set_global(GlobalLanguageModelRegistry(registry));
94 fake_provider
95 }
96
97 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
98 &mut self,
99 provider: T,
100 cx: &mut ModelContext<Self>,
101 ) {
102 let name = provider.id();
103
104 if let Some(subscription) = provider.subscribe(cx) {
105 subscription.detach();
106 }
107
108 self.providers.insert(name, Arc::new(provider));
109 cx.notify();
110 }
111
112 pub fn unregister_provider(
113 &mut self,
114 name: &LanguageModelProviderId,
115 cx: &mut ModelContext<Self>,
116 ) {
117 if self.providers.remove(name).is_some() {
118 cx.notify();
119 }
120 }
121
122 pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
123 self.providers.values()
124 }
125
126 pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
127 self.providers
128 .values()
129 .flat_map(|provider| provider.provided_models(cx))
130 .collect()
131 }
132
133 pub fn provider(
134 &self,
135 name: &LanguageModelProviderId,
136 ) -> Option<Arc<dyn LanguageModelProvider>> {
137 self.providers.get(name).cloned()
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::fake::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the registry's only provider;
    /// unregistering it by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, model_cx| {
            let fake = FakeLanguageModelProvider::default();
            this.register_provider(fake, model_cx);
        });

        let registered_ids: Vec<_> = registry
            .read(cx)
            .providers()
            .map(|provider| provider.id())
            .collect();
        assert_eq!(registered_ids.len(), 1);
        assert_eq!(registered_ids[0], crate::provider::fake::provider_id());

        registry.update(cx, |this, model_cx| {
            let fake_id = crate::provider::fake::provider_id();
            this.unregister_provider(&fake_id, model_cx);
        });

        let remaining = registry.read(cx).providers().count();
        assert!(remaining == 0);
    }
}