1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10pub mod provider;
11mod settings;
12
13use crate::provider::anthropic::AnthropicLanguageModelProvider;
14use crate::provider::bedrock::BedrockLanguageModelProvider;
15use crate::provider::cloud::CloudLanguageModelProvider;
16use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
17use crate::provider::google::GoogleLanguageModelProvider;
18use crate::provider::lmstudio::LmStudioLanguageModelProvider;
19pub use crate::provider::mistral::MistralLanguageModelProvider;
20use crate::provider::ollama::OllamaLanguageModelProvider;
21use crate::provider::open_ai::OpenAiLanguageModelProvider;
22use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
23use crate::provider::open_router::OpenRouterLanguageModelProvider;
24use crate::provider::vercel::VercelLanguageModelProvider;
25use crate::provider::x_ai::XAiLanguageModelProvider;
26pub use crate::settings::*;
27
28pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
29 let registry = LanguageModelRegistry::global(cx);
30 registry.update(cx, |registry, cx| {
31 register_language_model_providers(registry, user_store, client.clone(), cx);
32 });
33
34 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
35 .openai_compatible
36 .keys()
37 .cloned()
38 .collect::<HashSet<_>>();
39
40 registry.update(cx, |registry, cx| {
41 register_openai_compatible_providers(
42 registry,
43 &HashSet::default(),
44 &openai_compatible_providers,
45 client.clone(),
46 cx,
47 );
48 });
49 cx.observe_global::<SettingsStore>(move |cx| {
50 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
51 .openai_compatible
52 .keys()
53 .cloned()
54 .collect::<HashSet<_>>();
55 if openai_compatible_providers_new != openai_compatible_providers {
56 registry.update(cx, |registry, cx| {
57 register_openai_compatible_providers(
58 registry,
59 &openai_compatible_providers,
60 &openai_compatible_providers_new,
61 client.clone(),
62 cx,
63 );
64 });
65 openai_compatible_providers = openai_compatible_providers_new;
66 }
67 })
68 .detach();
69}
70
71fn register_openai_compatible_providers(
72 registry: &mut LanguageModelRegistry,
73 old: &HashSet<Arc<str>>,
74 new: &HashSet<Arc<str>>,
75 client: Arc<Client>,
76 cx: &mut Context<LanguageModelRegistry>,
77) {
78 for provider_id in old {
79 if !new.contains(provider_id) {
80 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
81 }
82 }
83
84 for provider_id in new {
85 if !old.contains(provider_id) {
86 registry.register_provider(
87 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
88 provider_id.clone(),
89 client.http_client(),
90 cx,
91 )),
92 cx,
93 );
94 }
95 }
96}
97
98fn register_language_model_providers(
99 registry: &mut LanguageModelRegistry,
100 user_store: Entity<UserStore>,
101 client: Arc<Client>,
102 cx: &mut Context<LanguageModelRegistry>,
103) {
104 registry.register_provider(
105 Arc::new(CloudLanguageModelProvider::new(
106 user_store,
107 client.clone(),
108 cx,
109 )),
110 cx,
111 );
112 registry.register_provider(
113 Arc::new(AnthropicLanguageModelProvider::new(
114 client.http_client(),
115 cx,
116 )),
117 cx,
118 );
119 registry.register_provider(
120 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
121 cx,
122 );
123 registry.register_provider(
124 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
125 cx,
126 );
127 registry.register_provider(
128 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
129 cx,
130 );
131 registry.register_provider(
132 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
133 cx,
134 );
135 registry.register_provider(
136 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
137 cx,
138 );
139 registry.register_provider(
140 MistralLanguageModelProvider::global(client.http_client(), cx),
141 cx,
142 );
143 registry.register_provider(
144 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
145 cx,
146 );
147 registry.register_provider(
148 Arc::new(OpenRouterLanguageModelProvider::new(
149 client.http_client(),
150 cx,
151 )),
152 cx,
153 );
154 registry.register_provider(
155 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
156 cx,
157 );
158 registry.register_provider(
159 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
160 cx,
161 );
162 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
163}