anthropic.rs

  1use crate::count_open_ai_tokens;
  2use crate::{
  3    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
  4    Role,
  5};
  6use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
  7use anyhow::{anyhow, Result};
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
 11use http::HttpClient;
 12use settings::Settings;
 13use std::time::Duration;
 14use std::{env, sync::Arc};
 15use theme::ThemeSettings;
 16use ui::prelude::*;
 17use util::ResultExt;
 18
 19pub struct AnthropicCompletionProvider {
 20    api_key: Option<String>,
 21    api_url: String,
 22    default_model: AnthropicModel,
 23    http_client: Arc<dyn HttpClient>,
 24    low_speed_timeout: Option<Duration>,
 25    settings_version: usize,
 26}
 27
 28impl AnthropicCompletionProvider {
 29    pub fn new(
 30        default_model: AnthropicModel,
 31        api_url: String,
 32        http_client: Arc<dyn HttpClient>,
 33        low_speed_timeout: Option<Duration>,
 34        settings_version: usize,
 35    ) -> Self {
 36        Self {
 37            api_key: None,
 38            api_url,
 39            default_model,
 40            http_client,
 41            low_speed_timeout,
 42            settings_version,
 43        }
 44    }
 45
 46    pub fn update(
 47        &mut self,
 48        default_model: AnthropicModel,
 49        api_url: String,
 50        low_speed_timeout: Option<Duration>,
 51        settings_version: usize,
 52    ) {
 53        self.default_model = default_model;
 54        self.api_url = api_url;
 55        self.low_speed_timeout = low_speed_timeout;
 56        self.settings_version = settings_version;
 57    }
 58
 59    pub fn settings_version(&self) -> usize {
 60        self.settings_version
 61    }
 62
 63    pub fn is_authenticated(&self) -> bool {
 64        self.api_key.is_some()
 65    }
 66
 67    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 68        if self.is_authenticated() {
 69            Task::ready(Ok(()))
 70        } else {
 71            let api_url = self.api_url.clone();
 72            cx.spawn(|mut cx| async move {
 73                let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
 74                    api_key
 75                } else {
 76                    let (_, api_key) = cx
 77                        .update(|cx| cx.read_credentials(&api_url))?
 78                        .await?
 79                        .ok_or_else(|| anyhow!("credentials not found"))?;
 80                    String::from_utf8(api_key)?
 81                };
 82                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 83                    if let CompletionProvider::Anthropic(provider) = provider {
 84                        provider.api_key = Some(api_key);
 85                    }
 86                })
 87            })
 88        }
 89    }
 90
 91    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 92        let delete_credentials = cx.delete_credentials(&self.api_url);
 93        cx.spawn(|mut cx| async move {
 94            delete_credentials.await.log_err();
 95            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 96                if let CompletionProvider::Anthropic(provider) = provider {
 97                    provider.api_key = None;
 98                }
 99            })
100        })
101    }
102
103    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
104        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
105            .into()
106    }
107
108    pub fn default_model(&self) -> AnthropicModel {
109        self.default_model.clone()
110    }
111
112    pub fn count_tokens(
113        &self,
114        request: LanguageModelRequest,
115        cx: &AppContext,
116    ) -> BoxFuture<'static, Result<usize>> {
117        count_open_ai_tokens(request, cx.background_executor())
118    }
119
120    pub fn complete(
121        &self,
122        request: LanguageModelRequest,
123    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
124        let request = self.to_anthropic_request(request);
125
126        let http_client = self.http_client.clone();
127        let api_key = self.api_key.clone();
128        let api_url = self.api_url.clone();
129        let low_speed_timeout = self.low_speed_timeout;
130        async move {
131            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
132            let request = stream_completion(
133                http_client.as_ref(),
134                &api_url,
135                &api_key,
136                request,
137                low_speed_timeout,
138            );
139            let response = request.await?;
140            let stream = response
141                .filter_map(|response| async move {
142                    match response {
143                        Ok(response) => match response {
144                            anthropic::ResponseEvent::ContentBlockStart {
145                                content_block, ..
146                            } => match content_block {
147                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
148                            },
149                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
150                                match delta {
151                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
152                                }
153                            }
154                            _ => None,
155                        },
156                        Err(error) => Some(Err(error)),
157                    }
158                })
159                .boxed();
160            Ok(stream)
161        }
162        .boxed()
163    }
164
165    fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
166        let model = match request.model {
167            LanguageModel::Anthropic(model) => model,
168            _ => self.default_model(),
169        };
170
171        let mut system_message = String::new();
172
173        let mut messages: Vec<RequestMessage> = Vec::new();
174        for message in request.messages {
175            if message.content.is_empty() {
176                continue;
177            }
178
179            match message.role {
180                Role::User | Role::Assistant => {
181                    let role = match message.role {
182                        Role::User => AnthropicRole::User,
183                        Role::Assistant => AnthropicRole::Assistant,
184                        _ => unreachable!(),
185                    };
186
187                    if let Some(last_message) = messages.last_mut() {
188                        if last_message.role == role {
189                            last_message.content.push_str("\n\n");
190                            last_message.content.push_str(&message.content);
191                            continue;
192                        }
193                    }
194
195                    messages.push(RequestMessage {
196                        role,
197                        content: message.content,
198                    });
199                }
200                Role::System => {
201                    if !system_message.is_empty() {
202                        system_message.push_str("\n\n");
203                    }
204                    system_message.push_str(&message.content);
205                }
206            }
207        }
208
209        Request {
210            model,
211            messages,
212            stream: true,
213            system: system_message,
214            max_tokens: 4092,
215        }
216    }
217}
218
219struct AuthenticationPrompt {
220    api_key: View<Editor>,
221    api_url: String,
222}
223
224impl AuthenticationPrompt {
225    fn new(api_url: String, cx: &mut WindowContext) -> Self {
226        Self {
227            api_key: cx.new_view(|cx| {
228                let mut editor = Editor::single_line(cx);
229                editor.set_placeholder_text(
230                    "sk-000000000000000000000000000000000000000000000000",
231                    cx,
232                );
233                editor
234            }),
235            api_url,
236        }
237    }
238
239    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
240        let api_key = self.api_key.read(cx).text(cx);
241        if api_key.is_empty() {
242            return;
243        }
244
245        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
246        cx.spawn(|_, mut cx| async move {
247            write_credentials.await?;
248            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
249                if let CompletionProvider::Anthropic(provider) = provider {
250                    provider.api_key = Some(api_key);
251                }
252            })
253        })
254        .detach_and_log_err(cx);
255    }
256
257    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
258        let settings = ThemeSettings::get_global(cx);
259        let text_style = TextStyle {
260            color: cx.theme().colors().text,
261            font_family: settings.ui_font.family.clone(),
262            font_features: settings.ui_font.features.clone(),
263            font_size: rems(0.875).into(),
264            font_weight: FontWeight::NORMAL,
265            font_style: FontStyle::Normal,
266            line_height: relative(1.3),
267            background_color: None,
268            underline: None,
269            strikethrough: None,
270            white_space: WhiteSpace::Normal,
271        };
272        EditorElement::new(
273            &self.api_key,
274            EditorStyle {
275                background: cx.theme().colors().editor_background,
276                local_player: cx.theme().players().local(),
277                text: text_style,
278                ..Default::default()
279            },
280        )
281    }
282}
283
284impl Render for AuthenticationPrompt {
285    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
286        const INSTRUCTIONS: [&str; 4] = [
287            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
288            "You can create an API key at: https://console.anthropic.com/settings/keys",
289            "",
290            "Paste your Anthropic API key below and hit enter to use the assistant:",
291        ];
292
293        v_flex()
294            .p_4()
295            .size_full()
296            .on_action(cx.listener(Self::save_api_key))
297            .children(
298                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
299            )
300            .child(
301                h_flex()
302                    .w_full()
303                    .my_2()
304                    .px_2()
305                    .py_1()
306                    .bg(cx.theme().colors().editor_background)
307                    .rounded_md()
308                    .child(self.render_api_key_editor(cx)),
309            )
310            .child(
311                Label::new(
312                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
313                )
314                .size(LabelSize::Small),
315            )
316            .child(
317                h_flex()
318                    .gap_2()
319                    .child(Label::new("Click on").size(LabelSize::Small))
320                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
321                    .child(
322                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
323                    ),
324            )
325            .into_any()
326    }
327}