completion_provider.rs

  1mod anthropic;
  2#[cfg(test)]
  3mod fake;
  4mod open_ai;
  5mod zed;
  6
  7pub use anthropic::*;
  8#[cfg(test)]
  9pub use fake::*;
 10pub use open_ai::*;
 11pub use zed::*;
 12
 13use crate::{
 14    assistant_settings::{AssistantProvider, AssistantSettings},
 15    LanguageModel, LanguageModelRequest,
 16};
 17use anyhow::Result;
 18use client::Client;
 19use futures::{future::BoxFuture, stream::BoxStream};
 20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 21use settings::{Settings, SettingsStore};
 22use std::sync::Arc;
 23use std::time::Duration;
 24
 25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 26    let mut settings_version = 0;
 27    let provider = match &AssistantSettings::get_global(cx).provider {
 28        AssistantProvider::ZedDotDev { default_model } => {
 29            CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 30                default_model.clone(),
 31                client.clone(),
 32                settings_version,
 33                cx,
 34            ))
 35        }
 36        AssistantProvider::OpenAi {
 37            default_model,
 38            api_url,
 39            low_speed_timeout_in_seconds,
 40        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 41            default_model.clone(),
 42            api_url.clone(),
 43            client.http_client(),
 44            low_speed_timeout_in_seconds.map(Duration::from_secs),
 45            settings_version,
 46        )),
 47        AssistantProvider::Anthropic {
 48            default_model,
 49            api_url,
 50            low_speed_timeout_in_seconds,
 51        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
 52            default_model.clone(),
 53            api_url.clone(),
 54            client.http_client(),
 55            low_speed_timeout_in_seconds.map(Duration::from_secs),
 56            settings_version,
 57        )),
 58    };
 59    cx.set_global(provider);
 60
 61    cx.observe_global::<SettingsStore>(move |cx| {
 62        settings_version += 1;
 63        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 64            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 65                (
 66                    CompletionProvider::OpenAi(provider),
 67                    AssistantProvider::OpenAi {
 68                        default_model,
 69                        api_url,
 70                        low_speed_timeout_in_seconds,
 71                    },
 72                ) => {
 73                    provider.update(
 74                        default_model.clone(),
 75                        api_url.clone(),
 76                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 77                        settings_version,
 78                    );
 79                }
 80                (
 81                    CompletionProvider::Anthropic(provider),
 82                    AssistantProvider::Anthropic {
 83                        default_model,
 84                        api_url,
 85                        low_speed_timeout_in_seconds,
 86                    },
 87                ) => {
 88                    provider.update(
 89                        default_model.clone(),
 90                        api_url.clone(),
 91                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 92                        settings_version,
 93                    );
 94                }
 95                (
 96                    CompletionProvider::ZedDotDev(provider),
 97                    AssistantProvider::ZedDotDev { default_model },
 98                ) => {
 99                    provider.update(default_model.clone(), settings_version);
100                }
101                (_, AssistantProvider::ZedDotDev { default_model }) => {
102                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
103                        default_model.clone(),
104                        client.clone(),
105                        settings_version,
106                        cx,
107                    ));
108                }
109                (
110                    _,
111                    AssistantProvider::OpenAi {
112                        default_model,
113                        api_url,
114                        low_speed_timeout_in_seconds,
115                    },
116                ) => {
117                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
118                        default_model.clone(),
119                        api_url.clone(),
120                        client.http_client(),
121                        low_speed_timeout_in_seconds.map(Duration::from_secs),
122                        settings_version,
123                    ));
124                }
125                (
126                    _,
127                    AssistantProvider::Anthropic {
128                        default_model,
129                        api_url,
130                        low_speed_timeout_in_seconds,
131                    },
132                ) => {
133                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
134                        default_model.clone(),
135                        api_url.clone(),
136                        client.http_client(),
137                        low_speed_timeout_in_seconds.map(Duration::from_secs),
138                        settings_version,
139                    ));
140                }
141            }
142        })
143    })
144    .detach();
145}
146
147pub enum CompletionProvider {
148    OpenAi(OpenAiCompletionProvider),
149    Anthropic(AnthropicCompletionProvider),
150    ZedDotDev(ZedDotDevCompletionProvider),
151    #[cfg(test)]
152    Fake(FakeCompletionProvider),
153}
154
155impl gpui::Global for CompletionProvider {}
156
157impl CompletionProvider {
158    pub fn global(cx: &AppContext) -> &Self {
159        cx.global::<Self>()
160    }
161
162    pub fn settings_version(&self) -> usize {
163        match self {
164            CompletionProvider::OpenAi(provider) => provider.settings_version(),
165            CompletionProvider::Anthropic(provider) => provider.settings_version(),
166            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
167            #[cfg(test)]
168            CompletionProvider::Fake(_) => unimplemented!(),
169        }
170    }
171
172    pub fn is_authenticated(&self) -> bool {
173        match self {
174            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
175            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
176            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
177            #[cfg(test)]
178            CompletionProvider::Fake(_) => true,
179        }
180    }
181
182    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
183        match self {
184            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
185            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
186            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
187            #[cfg(test)]
188            CompletionProvider::Fake(_) => Task::ready(Ok(())),
189        }
190    }
191
192    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
193        match self {
194            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
195            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
196            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
197            #[cfg(test)]
198            CompletionProvider::Fake(_) => unimplemented!(),
199        }
200    }
201
202    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
203        match self {
204            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
205            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
206            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
207            #[cfg(test)]
208            CompletionProvider::Fake(_) => Task::ready(Ok(())),
209        }
210    }
211
212    pub fn default_model(&self) -> LanguageModel {
213        match self {
214            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
215            CompletionProvider::Anthropic(provider) => {
216                LanguageModel::Anthropic(provider.default_model())
217            }
218            CompletionProvider::ZedDotDev(provider) => {
219                LanguageModel::ZedDotDev(provider.default_model())
220            }
221            #[cfg(test)]
222            CompletionProvider::Fake(_) => unimplemented!(),
223        }
224    }
225
226    pub fn count_tokens(
227        &self,
228        request: LanguageModelRequest,
229        cx: &AppContext,
230    ) -> BoxFuture<'static, Result<usize>> {
231        match self {
232            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
233            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
234            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
235            #[cfg(test)]
236            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
237        }
238    }
239
240    pub fn complete(
241        &self,
242        request: LanguageModelRequest,
243    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
244        match self {
245            CompletionProvider::OpenAi(provider) => provider.complete(request),
246            CompletionProvider::Anthropic(provider) => provider.complete(request),
247            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
248            #[cfg(test)]
249            CompletionProvider::Fake(provider) => provider.complete(),
250        }
251    }
252}