completion_provider.rs

  1#[cfg(test)]
  2mod fake;
  3mod open_ai;
  4mod zed;
  5
  6#[cfg(test)]
  7pub use fake::*;
  8pub use open_ai::*;
  9pub use zed::*;
 10
 11use crate::{
 12    assistant_settings::{AssistantProvider, AssistantSettings},
 13    LanguageModel, LanguageModelRequest,
 14};
 15use anyhow::Result;
 16use client::Client;
 17use futures::{future::BoxFuture, stream::BoxStream};
 18use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 19use settings::{Settings, SettingsStore};
 20use std::sync::Arc;
 21
 22pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 23    let mut settings_version = 0;
 24    let provider = match &AssistantSettings::get_global(cx).provider {
 25        AssistantProvider::ZedDotDev { default_model } => {
 26            CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 27                default_model.clone(),
 28                client.clone(),
 29                settings_version,
 30                cx,
 31            ))
 32        }
 33        AssistantProvider::OpenAi {
 34            default_model,
 35            api_url,
 36        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 37            default_model.clone(),
 38            api_url.clone(),
 39            client.http_client(),
 40            settings_version,
 41        )),
 42    };
 43    cx.set_global(provider);
 44
 45    cx.observe_global::<SettingsStore>(move |cx| {
 46        settings_version += 1;
 47        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 48            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 49                (
 50                    CompletionProvider::OpenAi(provider),
 51                    AssistantProvider::OpenAi {
 52                        default_model,
 53                        api_url,
 54                    },
 55                ) => {
 56                    provider.update(default_model.clone(), api_url.clone(), settings_version);
 57                }
 58                (
 59                    CompletionProvider::ZedDotDev(provider),
 60                    AssistantProvider::ZedDotDev { default_model },
 61                ) => {
 62                    provider.update(default_model.clone(), settings_version);
 63                }
 64                (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
 65                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 66                        default_model.clone(),
 67                        client.clone(),
 68                        settings_version,
 69                        cx,
 70                    ));
 71                }
 72                (
 73                    CompletionProvider::ZedDotDev(_),
 74                    AssistantProvider::OpenAi {
 75                        default_model,
 76                        api_url,
 77                    },
 78                ) => {
 79                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 80                        default_model.clone(),
 81                        api_url.clone(),
 82                        client.http_client(),
 83                        settings_version,
 84                    ));
 85                }
 86                #[cfg(test)]
 87                (CompletionProvider::Fake(_), _) => unimplemented!(),
 88            }
 89        })
 90    })
 91    .detach();
 92}
 93
/// The completion backend in use, stored as a gpui global (see `init`).
///
/// Each variant wraps the concrete provider for one backend; the methods on this
/// type forward to whichever variant is active.
pub enum CompletionProvider {
    /// Provider constructed when the settings select `AssistantProvider::OpenAi`.
    OpenAi(OpenAiCompletionProvider),
    /// Provider constructed when the settings select `AssistantProvider::ZedDotDev`.
    ZedDotDev(ZedDotDevCompletionProvider),
    /// Stand-in provider that only exists in test builds.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}
100
// Marker impl that allows a `CompletionProvider` to be stored with `set_global` and
// looked up with `global` / `update_global`.
impl gpui::Global for CompletionProvider {}
102
103impl CompletionProvider {
104    pub fn global(cx: &AppContext) -> &Self {
105        cx.global::<Self>()
106    }
107
108    pub fn settings_version(&self) -> usize {
109        match self {
110            CompletionProvider::OpenAi(provider) => provider.settings_version(),
111            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
112            #[cfg(test)]
113            CompletionProvider::Fake(_) => unimplemented!(),
114        }
115    }
116
117    pub fn is_authenticated(&self) -> bool {
118        match self {
119            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
120            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
121            #[cfg(test)]
122            CompletionProvider::Fake(_) => true,
123        }
124    }
125
126    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
127        match self {
128            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
129            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
130            #[cfg(test)]
131            CompletionProvider::Fake(_) => Task::ready(Ok(())),
132        }
133    }
134
135    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
136        match self {
137            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
138            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
139            #[cfg(test)]
140            CompletionProvider::Fake(_) => unimplemented!(),
141        }
142    }
143
144    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
145        match self {
146            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
147            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
148            #[cfg(test)]
149            CompletionProvider::Fake(_) => Task::ready(Ok(())),
150        }
151    }
152
153    pub fn default_model(&self) -> LanguageModel {
154        match self {
155            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
156            CompletionProvider::ZedDotDev(provider) => {
157                LanguageModel::ZedDotDev(provider.default_model())
158            }
159            #[cfg(test)]
160            CompletionProvider::Fake(_) => unimplemented!(),
161        }
162    }
163
164    pub fn count_tokens(
165        &self,
166        request: LanguageModelRequest,
167        cx: &AppContext,
168    ) -> BoxFuture<'static, Result<usize>> {
169        match self {
170            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
171            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
172            #[cfg(test)]
173            CompletionProvider::Fake(_) => unimplemented!(),
174        }
175    }
176
177    pub fn complete(
178        &self,
179        request: LanguageModelRequest,
180    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
181        match self {
182            CompletionProvider::OpenAi(provider) => provider.complete(request),
183            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
184            #[cfg(test)]
185            CompletionProvider::Fake(provider) => provider.complete(),
186        }
187    }
188}