1mod anthropic;
2mod cloud;
3#[cfg(test)]
4mod fake;
5mod ollama;
6mod open_ai;
7
8pub use anthropic::*;
9pub use cloud::*;
10#[cfg(test)]
11pub use fake::*;
12pub use ollama::*;
13pub use open_ai::*;
14
15use crate::{
16 assistant_settings::{AssistantProvider, AssistantSettings},
17 LanguageModel, LanguageModelRequest,
18};
19use anyhow::Result;
20use client::Client;
21use futures::{future::BoxFuture, stream::BoxStream};
22use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
23use settings::{Settings, SettingsStore};
24use std::sync::Arc;
25use std::time::Duration;
26
27/// Choose which model to use for openai provider.
28/// If the model is not available, try to use the first available model, or fallback to the original model.
29fn choose_openai_model(
30 model: &::open_ai::Model,
31 available_models: &[::open_ai::Model],
32) -> ::open_ai::Model {
33 available_models
34 .iter()
35 .find(|&m| m == model)
36 .or_else(|| available_models.first())
37 .unwrap_or_else(|| model)
38 .clone()
39}
40
/// Register the global [`CompletionProvider`] based on the current
/// `AssistantSettings`, then keep it in sync with settings changes for the
/// lifetime of the app (the settings observer is detached).
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    // Monotonically increasing generation counter passed to every provider
    // construction/update so consumers can detect stale provider state.
    let mut settings_version = 0;
    // Build the initial provider matching the configured backend.
    let provider = match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        ),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
            // The configured model may not be in `available_models`;
            // fall back gracefully (see `choose_openai_model`).
            choose_openai_model(model, available_models),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        )),
    };
    // The global must exist before the observer below can update it.
    cx.set_global(provider);

    // On every settings change: bump the generation, then either update the
    // existing provider in place (same backend kind, preserving its cached
    // state) or replace it wholesale (backend kind changed).
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            // NOTE: arm order matters — the specific "same-kind" arms must
            // come before the `(_, ...)` replacement arms so an in-place
            // update is preferred over reconstruction.
            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
                (
                    CompletionProvider::OpenAi(provider),
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                        available_models,
                    },
                ) => {
                    provider.update(
                        choose_openai_model(model, available_models),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Anthropic(provider),
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }

                (
                    CompletionProvider::Ollama(provider),
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    );
                }

                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                    provider.update(model.clone(), settings_version);
                }
                // Backend kind changed (including from the test-only Fake):
                // construct a fresh provider of the newly selected kind.
                (_, AssistantProvider::ZedDotDev { model }) => {
                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
                        model.clone(),
                        client.clone(),
                        settings_version,
                        cx,
                    ));
                }
                (
                    _,
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                        available_models,
                    },
                ) => {
                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                        choose_openai_model(model, available_models),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    ));
                }
            }
        })
    })
    .detach();
}
204
/// The app-wide completion backend. Exactly one variant is active at a
/// time, chosen from `AssistantSettings` by `init` and swapped when the
/// configured provider kind changes.
pub enum CompletionProvider {
    /// OpenAI API backend.
    OpenAi(OpenAiCompletionProvider),
    /// Anthropic API backend.
    Anthropic(AnthropicCompletionProvider),
    /// Zed-hosted ("ZedDotDev") backend, routed through the client.
    Cloud(CloudCompletionProvider),
    /// Stub backend used only in tests.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
    /// Local Ollama server backend.
    Ollama(OllamaCompletionProvider),
}

// Marker impl so a single CompletionProvider can be stored in GPUI's
// global store (`cx.set_global` / `cx.global::<CompletionProvider>()`).
impl gpui::Global for CompletionProvider {}
215
impl CompletionProvider {
    /// Fetch the globally registered provider (set by `init`).
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

    /// List the models the active backend can serve, wrapped in the
    /// backend-appropriate `LanguageModel` variant.
    ///
    /// Panics (`unimplemented!`) on the test-only Fake provider.
    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        match self {
            CompletionProvider::OpenAi(provider) => provider
                .available_models(cx)
                .map(LanguageModel::OpenAi)
                .collect(),
            CompletionProvider::Anthropic(provider) => provider
                .available_models()
                .map(LanguageModel::Anthropic)
                .collect(),
            CompletionProvider::Cloud(provider) => provider
                .available_models()
                .map(LanguageModel::Cloud)
                .collect(),
            CompletionProvider::Ollama(provider) => provider
                .available_models()
                // NOTE(review): the explicit clone suggests Ollama's
                // iterator yields references, unlike the other backends —
                // confirm against OllamaCompletionProvider.
                .map(|model| LanguageModel::Ollama(model.clone()))
                .collect(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    /// The settings generation this provider was last built/updated with
    /// (see the counter threaded through `init`).
    pub fn settings_version(&self) -> usize {
        match self {
            CompletionProvider::OpenAi(provider) => provider.settings_version(),
            CompletionProvider::Anthropic(provider) => provider.settings_version(),
            CompletionProvider::Cloud(provider) => provider.settings_version(),
            CompletionProvider::Ollama(provider) => provider.settings_version(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    /// Whether the active backend currently holds usable credentials.
    /// The Fake provider always reports authenticated.
    pub fn is_authenticated(&self) -> bool {
        match self {
            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => true,
        }
    }

    /// Kick off authentication for the active backend; a no-op success
    /// for the Fake provider.
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

    /// Build the UI view prompting the user to authenticate with the
    /// active backend. Panics on the test-only Fake provider.
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    /// Clear stored credentials for the active backend. A no-op success
    /// for Cloud and Fake, which hold no locally stored credentials here.
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

    /// The currently selected model, wrapped in the backend-appropriate
    /// `LanguageModel` variant (default model for Fake).
    pub fn model(&self) -> LanguageModel {
        match self {
            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
            #[cfg(test)]
            CompletionProvider::Fake(_) => LanguageModel::default(),
        }
    }

    /// Estimate the token count of `request` for the active backend.
    /// The Fake provider always reports zero tokens.
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
        }
    }

    /// Stream a completion for `request` from the active backend as a
    /// stream of text chunks.
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.complete(request),
            CompletionProvider::Anthropic(provider) => provider.complete(request),
            CompletionProvider::Cloud(provider) => provider.complete(request),
            CompletionProvider::Ollama(provider) => provider.complete(request),
            #[cfg(test)]
            CompletionProvider::Fake(provider) => provider.complete(),
        }
    }
}