completion_provider.rs

  1mod anthropic;
  2mod cloud;
  3#[cfg(test)]
  4mod fake;
  5mod ollama;
  6mod open_ai;
  7
  8pub use anthropic::*;
  9pub use cloud::*;
 10#[cfg(test)]
 11pub use fake::*;
 12pub use ollama::*;
 13pub use open_ai::*;
 14
 15use crate::{
 16    assistant_settings::{AssistantProvider, AssistantSettings},
 17    LanguageModel, LanguageModelRequest,
 18};
 19use anyhow::Result;
 20use client::Client;
 21use futures::{future::BoxFuture, stream::BoxStream};
 22use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 23use settings::{Settings, SettingsStore};
 24use std::sync::Arc;
 25use std::time::Duration;
 26
 27pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 28    let mut settings_version = 0;
 29    let provider = match &AssistantSettings::get_global(cx).provider {
 30        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
 31            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
 32        ),
 33        AssistantProvider::OpenAi {
 34            model,
 35            api_url,
 36            low_speed_timeout_in_seconds,
 37        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 38            model.clone(),
 39            api_url.clone(),
 40            client.http_client(),
 41            low_speed_timeout_in_seconds.map(Duration::from_secs),
 42            settings_version,
 43        )),
 44        AssistantProvider::Anthropic {
 45            model,
 46            api_url,
 47            low_speed_timeout_in_seconds,
 48        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
 49            model.clone(),
 50            api_url.clone(),
 51            client.http_client(),
 52            low_speed_timeout_in_seconds.map(Duration::from_secs),
 53            settings_version,
 54        )),
 55        AssistantProvider::Ollama {
 56            model,
 57            api_url,
 58            low_speed_timeout_in_seconds,
 59        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
 60            model.clone(),
 61            api_url.clone(),
 62            client.http_client(),
 63            low_speed_timeout_in_seconds.map(Duration::from_secs),
 64            settings_version,
 65        )),
 66    };
 67    cx.set_global(provider);
 68
 69    cx.observe_global::<SettingsStore>(move |cx| {
 70        settings_version += 1;
 71        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 72            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 73                (
 74                    CompletionProvider::OpenAi(provider),
 75                    AssistantProvider::OpenAi {
 76                        model,
 77                        api_url,
 78                        low_speed_timeout_in_seconds,
 79                    },
 80                ) => {
 81                    provider.update(
 82                        model.clone(),
 83                        api_url.clone(),
 84                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 85                        settings_version,
 86                    );
 87                }
 88                (
 89                    CompletionProvider::Anthropic(provider),
 90                    AssistantProvider::Anthropic {
 91                        model,
 92                        api_url,
 93                        low_speed_timeout_in_seconds,
 94                    },
 95                ) => {
 96                    provider.update(
 97                        model.clone(),
 98                        api_url.clone(),
 99                        low_speed_timeout_in_seconds.map(Duration::from_secs),
100                        settings_version,
101                    );
102                }
103
104                (
105                    CompletionProvider::Ollama(provider),
106                    AssistantProvider::Ollama {
107                        model,
108                        api_url,
109                        low_speed_timeout_in_seconds,
110                    },
111                ) => {
112                    provider.update(
113                        model.clone(),
114                        api_url.clone(),
115                        low_speed_timeout_in_seconds.map(Duration::from_secs),
116                        settings_version,
117                    );
118                }
119
120                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
121                    provider.update(model.clone(), settings_version);
122                }
123                (_, AssistantProvider::ZedDotDev { model }) => {
124                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
125                        model.clone(),
126                        client.clone(),
127                        settings_version,
128                        cx,
129                    ));
130                }
131                (
132                    _,
133                    AssistantProvider::OpenAi {
134                        model,
135                        api_url,
136                        low_speed_timeout_in_seconds,
137                    },
138                ) => {
139                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
140                        model.clone(),
141                        api_url.clone(),
142                        client.http_client(),
143                        low_speed_timeout_in_seconds.map(Duration::from_secs),
144                        settings_version,
145                    ));
146                }
147                (
148                    _,
149                    AssistantProvider::Anthropic {
150                        model,
151                        api_url,
152                        low_speed_timeout_in_seconds,
153                    },
154                ) => {
155                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
156                        model.clone(),
157                        api_url.clone(),
158                        client.http_client(),
159                        low_speed_timeout_in_seconds.map(Duration::from_secs),
160                        settings_version,
161                    ));
162                }
163                (
164                    _,
165                    AssistantProvider::Ollama {
166                        model,
167                        api_url,
168                        low_speed_timeout_in_seconds,
169                    },
170                ) => {
171                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
172                        model.clone(),
173                        api_url.clone(),
174                        client.http_client(),
175                        low_speed_timeout_in_seconds.map(Duration::from_secs),
176                        settings_version,
177                    ));
178                }
179            }
180        })
181    })
182    .detach();
183}
184
185pub enum CompletionProvider {
186    OpenAi(OpenAiCompletionProvider),
187    Anthropic(AnthropicCompletionProvider),
188    Cloud(CloudCompletionProvider),
189    #[cfg(test)]
190    Fake(FakeCompletionProvider),
191    Ollama(OllamaCompletionProvider),
192}
193
194impl gpui::Global for CompletionProvider {}
195
196impl CompletionProvider {
197    pub fn global(cx: &AppContext) -> &Self {
198        cx.global::<Self>()
199    }
200
201    pub fn available_models(&self) -> Vec<LanguageModel> {
202        match self {
203            CompletionProvider::OpenAi(provider) => provider
204                .available_models()
205                .map(LanguageModel::OpenAi)
206                .collect(),
207            CompletionProvider::Anthropic(provider) => provider
208                .available_models()
209                .map(LanguageModel::Anthropic)
210                .collect(),
211            CompletionProvider::Cloud(provider) => provider
212                .available_models()
213                .map(LanguageModel::Cloud)
214                .collect(),
215            CompletionProvider::Ollama(provider) => provider
216                .available_models()
217                .map(|model| LanguageModel::Ollama(model.clone()))
218                .collect(),
219            #[cfg(test)]
220            CompletionProvider::Fake(_) => unimplemented!(),
221        }
222    }
223
224    pub fn settings_version(&self) -> usize {
225        match self {
226            CompletionProvider::OpenAi(provider) => provider.settings_version(),
227            CompletionProvider::Anthropic(provider) => provider.settings_version(),
228            CompletionProvider::Cloud(provider) => provider.settings_version(),
229            CompletionProvider::Ollama(provider) => provider.settings_version(),
230            #[cfg(test)]
231            CompletionProvider::Fake(_) => unimplemented!(),
232        }
233    }
234
235    pub fn is_authenticated(&self) -> bool {
236        match self {
237            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
238            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
239            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
240            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
241            #[cfg(test)]
242            CompletionProvider::Fake(_) => true,
243        }
244    }
245
246    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
247        match self {
248            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
249            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
250            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
251            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
252            #[cfg(test)]
253            CompletionProvider::Fake(_) => Task::ready(Ok(())),
254        }
255    }
256
257    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
258        match self {
259            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
260            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
261            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
262            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
263            #[cfg(test)]
264            CompletionProvider::Fake(_) => unimplemented!(),
265        }
266    }
267
268    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
269        match self {
270            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
271            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
272            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
273            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
274            #[cfg(test)]
275            CompletionProvider::Fake(_) => Task::ready(Ok(())),
276        }
277    }
278
279    pub fn model(&self) -> LanguageModel {
280        match self {
281            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
282            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
283            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
284            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
285            #[cfg(test)]
286            CompletionProvider::Fake(_) => LanguageModel::default(),
287        }
288    }
289
290    pub fn count_tokens(
291        &self,
292        request: LanguageModelRequest,
293        cx: &AppContext,
294    ) -> BoxFuture<'static, Result<usize>> {
295        match self {
296            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
297            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
298            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
299            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
300            #[cfg(test)]
301            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
302        }
303    }
304
305    pub fn complete(
306        &self,
307        request: LanguageModelRequest,
308    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
309        match self {
310            CompletionProvider::OpenAi(provider) => provider.complete(request),
311            CompletionProvider::Anthropic(provider) => provider.complete(request),
312            CompletionProvider::Cloud(provider) => provider.complete(request),
313            CompletionProvider::Ollama(provider) => provider.complete(request),
314            #[cfg(test)]
315            CompletionProvider::Fake(provider) => provider.complete(),
316        }
317    }
318}