1mod anthropic;
2mod cloud;
3#[cfg(test)]
4mod fake;
5mod ollama;
6mod open_ai;
7
8pub use anthropic::*;
9pub use cloud::*;
10#[cfg(test)]
11pub use fake::*;
12pub use ollama::*;
13pub use open_ai::*;
14
15use crate::{
16 assistant_settings::{AssistantProvider, AssistantSettings},
17 LanguageModel, LanguageModelRequest,
18};
19use anyhow::Result;
20use client::Client;
21use futures::{future::BoxFuture, stream::BoxStream};
22use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
23use settings::{Settings, SettingsStore};
24use std::sync::Arc;
25use std::time::Duration;
26
27pub fn init(client: Arc<Client>, cx: &mut AppContext) {
28 let mut settings_version = 0;
29 let provider = match &AssistantSettings::get_global(cx).provider {
30 AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
31 CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
32 ),
33 AssistantProvider::OpenAi {
34 model,
35 api_url,
36 low_speed_timeout_in_seconds,
37 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
38 model.clone(),
39 api_url.clone(),
40 client.http_client(),
41 low_speed_timeout_in_seconds.map(Duration::from_secs),
42 settings_version,
43 )),
44 AssistantProvider::Anthropic {
45 model,
46 api_url,
47 low_speed_timeout_in_seconds,
48 } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
49 model.clone(),
50 api_url.clone(),
51 client.http_client(),
52 low_speed_timeout_in_seconds.map(Duration::from_secs),
53 settings_version,
54 )),
55 AssistantProvider::Ollama {
56 model,
57 api_url,
58 low_speed_timeout_in_seconds,
59 } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
60 model.clone(),
61 api_url.clone(),
62 client.http_client(),
63 low_speed_timeout_in_seconds.map(Duration::from_secs),
64 settings_version,
65 cx,
66 )),
67 };
68 cx.set_global(provider);
69
70 cx.observe_global::<SettingsStore>(move |cx| {
71 settings_version += 1;
72 cx.update_global::<CompletionProvider, _>(|provider, cx| {
73 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
74 (
75 CompletionProvider::OpenAi(provider),
76 AssistantProvider::OpenAi {
77 model,
78 api_url,
79 low_speed_timeout_in_seconds,
80 },
81 ) => {
82 provider.update(
83 model.clone(),
84 api_url.clone(),
85 low_speed_timeout_in_seconds.map(Duration::from_secs),
86 settings_version,
87 );
88 }
89 (
90 CompletionProvider::Anthropic(provider),
91 AssistantProvider::Anthropic {
92 model,
93 api_url,
94 low_speed_timeout_in_seconds,
95 },
96 ) => {
97 provider.update(
98 model.clone(),
99 api_url.clone(),
100 low_speed_timeout_in_seconds.map(Duration::from_secs),
101 settings_version,
102 );
103 }
104
105 (
106 CompletionProvider::Ollama(provider),
107 AssistantProvider::Ollama {
108 model,
109 api_url,
110 low_speed_timeout_in_seconds,
111 },
112 ) => {
113 provider.update(
114 model.clone(),
115 api_url.clone(),
116 low_speed_timeout_in_seconds.map(Duration::from_secs),
117 settings_version,
118 cx,
119 );
120 }
121
122 (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
123 provider.update(model.clone(), settings_version);
124 }
125 (_, AssistantProvider::ZedDotDev { model }) => {
126 *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
127 model.clone(),
128 client.clone(),
129 settings_version,
130 cx,
131 ));
132 }
133 (
134 _,
135 AssistantProvider::OpenAi {
136 model,
137 api_url,
138 low_speed_timeout_in_seconds,
139 },
140 ) => {
141 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
142 model.clone(),
143 api_url.clone(),
144 client.http_client(),
145 low_speed_timeout_in_seconds.map(Duration::from_secs),
146 settings_version,
147 ));
148 }
149 (
150 _,
151 AssistantProvider::Anthropic {
152 model,
153 api_url,
154 low_speed_timeout_in_seconds,
155 },
156 ) => {
157 *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
158 model.clone(),
159 api_url.clone(),
160 client.http_client(),
161 low_speed_timeout_in_seconds.map(Duration::from_secs),
162 settings_version,
163 ));
164 }
165 (
166 _,
167 AssistantProvider::Ollama {
168 model,
169 api_url,
170 low_speed_timeout_in_seconds,
171 },
172 ) => {
173 *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
174 model.clone(),
175 api_url.clone(),
176 client.http_client(),
177 low_speed_timeout_in_seconds.map(Duration::from_secs),
178 settings_version,
179 cx,
180 ));
181 }
182 }
183 })
184 })
185 .detach();
186}
187
188pub enum CompletionProvider {
189 OpenAi(OpenAiCompletionProvider),
190 Anthropic(AnthropicCompletionProvider),
191 Cloud(CloudCompletionProvider),
192 #[cfg(test)]
193 Fake(FakeCompletionProvider),
194 Ollama(OllamaCompletionProvider),
195}
196
197impl gpui::Global for CompletionProvider {}
198
199impl CompletionProvider {
200 pub fn global(cx: &AppContext) -> &Self {
201 cx.global::<Self>()
202 }
203
204 pub fn available_models(&self) -> Vec<LanguageModel> {
205 match self {
206 CompletionProvider::OpenAi(provider) => provider
207 .available_models()
208 .map(LanguageModel::OpenAi)
209 .collect(),
210 CompletionProvider::Anthropic(provider) => provider
211 .available_models()
212 .map(LanguageModel::Anthropic)
213 .collect(),
214 CompletionProvider::Cloud(provider) => provider
215 .available_models()
216 .map(LanguageModel::Cloud)
217 .collect(),
218 CompletionProvider::Ollama(provider) => provider
219 .available_models()
220 .map(|model| LanguageModel::Ollama(model.clone()))
221 .collect(),
222 #[cfg(test)]
223 CompletionProvider::Fake(_) => unimplemented!(),
224 }
225 }
226
227 pub fn settings_version(&self) -> usize {
228 match self {
229 CompletionProvider::OpenAi(provider) => provider.settings_version(),
230 CompletionProvider::Anthropic(provider) => provider.settings_version(),
231 CompletionProvider::Cloud(provider) => provider.settings_version(),
232 CompletionProvider::Ollama(provider) => provider.settings_version(),
233 #[cfg(test)]
234 CompletionProvider::Fake(_) => unimplemented!(),
235 }
236 }
237
238 pub fn is_authenticated(&self) -> bool {
239 match self {
240 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
241 CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
242 CompletionProvider::Cloud(provider) => provider.is_authenticated(),
243 CompletionProvider::Ollama(provider) => provider.is_authenticated(),
244 #[cfg(test)]
245 CompletionProvider::Fake(_) => true,
246 }
247 }
248
249 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
250 match self {
251 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
252 CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
253 CompletionProvider::Cloud(provider) => provider.authenticate(cx),
254 CompletionProvider::Ollama(provider) => provider.authenticate(cx),
255 #[cfg(test)]
256 CompletionProvider::Fake(_) => Task::ready(Ok(())),
257 }
258 }
259
260 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
261 match self {
262 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
263 CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
264 CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
265 CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
266 #[cfg(test)]
267 CompletionProvider::Fake(_) => unimplemented!(),
268 }
269 }
270
271 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
272 match self {
273 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
274 CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
275 CompletionProvider::Cloud(_) => Task::ready(Ok(())),
276 CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
277 #[cfg(test)]
278 CompletionProvider::Fake(_) => Task::ready(Ok(())),
279 }
280 }
281
282 pub fn model(&self) -> LanguageModel {
283 match self {
284 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
285 CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
286 CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
287 CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
288 #[cfg(test)]
289 CompletionProvider::Fake(_) => LanguageModel::default(),
290 }
291 }
292
293 pub fn count_tokens(
294 &self,
295 request: LanguageModelRequest,
296 cx: &AppContext,
297 ) -> BoxFuture<'static, Result<usize>> {
298 match self {
299 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
300 CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
301 CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
302 CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
303 #[cfg(test)]
304 CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
305 }
306 }
307
308 pub fn complete(
309 &self,
310 request: LanguageModelRequest,
311 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
312 match self {
313 CompletionProvider::OpenAi(provider) => provider.complete(request),
314 CompletionProvider::Anthropic(provider) => provider.complete(request),
315 CompletionProvider::Cloud(provider) => provider.complete(request),
316 CompletionProvider::Ollama(provider) => provider.complete(request),
317 #[cfg(test)]
318 CompletionProvider::Fake(provider) => provider.complete(),
319 }
320 }
321}