1#[cfg(test)]
2mod fake;
3mod open_ai;
4mod zed;
5
6#[cfg(test)]
7pub use fake::*;
8pub use open_ai::*;
9pub use zed::*;
10
11use crate::{
12 assistant_settings::{AssistantProvider, AssistantSettings},
13 LanguageModel, LanguageModelRequest,
14};
15use anyhow::Result;
16use client::Client;
17use futures::{future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
19use settings::{Settings, SettingsStore};
20use std::sync::Arc;
21
22pub fn init(client: Arc<Client>, cx: &mut AppContext) {
23 let mut settings_version = 0;
24 let provider = match &AssistantSettings::get_global(cx).provider {
25 AssistantProvider::ZedDotDev { default_model } => {
26 CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
27 default_model.clone(),
28 client.clone(),
29 settings_version,
30 cx,
31 ))
32 }
33 AssistantProvider::OpenAi {
34 default_model,
35 api_url,
36 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
37 default_model.clone(),
38 api_url.clone(),
39 client.http_client(),
40 settings_version,
41 )),
42 };
43 cx.set_global(provider);
44
45 cx.observe_global::<SettingsStore>(move |cx| {
46 settings_version += 1;
47 cx.update_global::<CompletionProvider, _>(|provider, cx| {
48 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
49 (
50 CompletionProvider::OpenAi(provider),
51 AssistantProvider::OpenAi {
52 default_model,
53 api_url,
54 },
55 ) => {
56 provider.update(default_model.clone(), api_url.clone(), settings_version);
57 }
58 (
59 CompletionProvider::ZedDotDev(provider),
60 AssistantProvider::ZedDotDev { default_model },
61 ) => {
62 provider.update(default_model.clone(), settings_version);
63 }
64 (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
65 *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
66 default_model.clone(),
67 client.clone(),
68 settings_version,
69 cx,
70 ));
71 }
72 (
73 CompletionProvider::ZedDotDev(_),
74 AssistantProvider::OpenAi {
75 default_model,
76 api_url,
77 },
78 ) => {
79 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
80 default_model.clone(),
81 api_url.clone(),
82 client.http_client(),
83 settings_version,
84 ));
85 }
86 #[cfg(test)]
87 (CompletionProvider::Fake(_), _) => unimplemented!(),
88 }
89 })
90 })
91 .detach();
92}
93
/// The completion backend used by the assistant.
///
/// A single instance is stored as an app-wide gpui global (see the
/// `gpui::Global` impl that follows). Each non-test variant wraps the concrete
/// provider for one [`AssistantProvider`] setting.
pub enum CompletionProvider {
    /// Requests are handled by [`OpenAiCompletionProvider`].
    OpenAi(OpenAiCompletionProvider),
    /// Requests are handled by [`ZedDotDevCompletionProvider`].
    ZedDotDev(ZedDotDevCompletionProvider),
    /// Test-only provider; this variant exists only under `cfg(test)`.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}

// Marker impl that allows the provider to be stored and read back as a gpui
// global (`cx.set_global(...)` / `cx.global::<CompletionProvider>()`).
impl gpui::Global for CompletionProvider {}
102
103impl CompletionProvider {
104 pub fn global(cx: &AppContext) -> &Self {
105 cx.global::<Self>()
106 }
107
108 pub fn settings_version(&self) -> usize {
109 match self {
110 CompletionProvider::OpenAi(provider) => provider.settings_version(),
111 CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
112 #[cfg(test)]
113 CompletionProvider::Fake(_) => unimplemented!(),
114 }
115 }
116
117 pub fn is_authenticated(&self) -> bool {
118 match self {
119 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
120 CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
121 #[cfg(test)]
122 CompletionProvider::Fake(_) => true,
123 }
124 }
125
126 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
127 match self {
128 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
129 CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
130 #[cfg(test)]
131 CompletionProvider::Fake(_) => Task::ready(Ok(())),
132 }
133 }
134
135 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
136 match self {
137 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
138 CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
139 #[cfg(test)]
140 CompletionProvider::Fake(_) => unimplemented!(),
141 }
142 }
143
144 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
145 match self {
146 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
147 CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
148 #[cfg(test)]
149 CompletionProvider::Fake(_) => Task::ready(Ok(())),
150 }
151 }
152
153 pub fn default_model(&self) -> LanguageModel {
154 match self {
155 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
156 CompletionProvider::ZedDotDev(provider) => {
157 LanguageModel::ZedDotDev(provider.default_model())
158 }
159 #[cfg(test)]
160 CompletionProvider::Fake(_) => unimplemented!(),
161 }
162 }
163
164 pub fn count_tokens(
165 &self,
166 request: LanguageModelRequest,
167 cx: &AppContext,
168 ) -> BoxFuture<'static, Result<usize>> {
169 match self {
170 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
171 CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
172 #[cfg(test)]
173 CompletionProvider::Fake(_) => unimplemented!(),
174 }
175 }
176
177 pub fn complete(
178 &self,
179 request: LanguageModelRequest,
180 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
181 match self {
182 CompletionProvider::OpenAi(provider) => provider.complete(request),
183 CompletionProvider::ZedDotDev(provider) => provider.complete(request),
184 #[cfg(test)]
185 CompletionProvider::Fake(provider) => provider.complete(),
186 }
187 }
188}