anthropic.rs

  1use anthropic::{stream_completion, Request, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use editor::{Editor, EditorElement, EditorStyle};
  5use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  6use gpui::{
  7    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  8    WhiteSpace,
  9};
 10use http::HttpClient;
 11use settings::{Settings, SettingsStore};
 12use std::{sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
 21    LanguageModelRequest, LanguageModelRequestMessage, Role,
 22};
 23
 24const PROVIDER_NAME: &str = "anthropic";
 25
 26#[derive(Default, Clone, Debug, PartialEq)]
 27pub struct AnthropicSettings {
 28    pub api_url: String,
 29    pub low_speed_timeout: Option<Duration>,
 30    pub available_models: Vec<anthropic::Model>,
 31}
 32
 33pub struct AnthropicLanguageModelProvider {
 34    http_client: Arc<dyn HttpClient>,
 35    state: gpui::Model<State>,
 36}
 37
 38struct State {
 39    api_key: Option<String>,
 40    settings: AnthropicSettings,
 41    _subscription: Subscription,
 42}
 43
 44impl AnthropicLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            settings: AnthropicSettings::default(),
 49            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 50                this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
 51                cx.notify();
 52            }),
 53        });
 54
 55        Self { http_client, state }
 56    }
 57}
 58impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 59    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 60        Some(cx.observe(&self.state, |_, _, cx| {
 61            cx.notify();
 62        }))
 63    }
 64}
 65
 66impl LanguageModelProvider for AnthropicLanguageModelProvider {
 67    fn name(&self) -> LanguageModelProviderName {
 68        LanguageModelProviderName(PROVIDER_NAME.into())
 69    }
 70
 71    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 72        let mut models = HashMap::default();
 73
 74        // Add base models from anthropic::Model::iter()
 75        for model in anthropic::Model::iter() {
 76            if !matches!(model, anthropic::Model::Custom { .. }) {
 77                models.insert(model.id().to_string(), model);
 78            }
 79        }
 80
 81        // Override with available models from settings
 82        for model in &self.state.read(cx).settings.available_models {
 83            models.insert(model.id().to_string(), model.clone());
 84        }
 85
 86        models
 87            .into_values()
 88            .map(|model| {
 89                Arc::new(AnthropicModel {
 90                    id: LanguageModelId::from(model.id().to_string()),
 91                    model,
 92                    state: self.state.clone(),
 93                    http_client: self.http_client.clone(),
 94                }) as Arc<dyn LanguageModel>
 95            })
 96            .collect()
 97    }
 98
 99    fn is_authenticated(&self, cx: &AppContext) -> bool {
100        self.state.read(cx).api_key.is_some()
101    }
102
103    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
104        if self.is_authenticated(cx) {
105            Task::ready(Ok(()))
106        } else {
107            let api_url = self.state.read(cx).settings.api_url.clone();
108            let state = self.state.clone();
109            cx.spawn(|mut cx| async move {
110                let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
111                    api_key
112                } else {
113                    let (_, api_key) = cx
114                        .update(|cx| cx.read_credentials(&api_url))?
115                        .await?
116                        .ok_or_else(|| anyhow!("credentials not found"))?;
117                    String::from_utf8(api_key)?
118                };
119
120                state.update(&mut cx, |this, cx| {
121                    this.api_key = Some(api_key);
122                    cx.notify();
123                })
124            })
125        }
126    }
127
128    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
129        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
130            .into()
131    }
132
133    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
134        let state = self.state.clone();
135        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
136        cx.spawn(|mut cx| async move {
137            delete_credentials.await.log_err();
138            state.update(&mut cx, |this, cx| {
139                this.api_key = None;
140                cx.notify();
141            })
142        })
143    }
144}
145
146pub struct AnthropicModel {
147    id: LanguageModelId,
148    model: anthropic::Model,
149    state: gpui::Model<State>,
150    http_client: Arc<dyn HttpClient>,
151}
152
153impl AnthropicModel {
154    fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
155        preprocess_anthropic_request(&mut request);
156
157        let mut system_message = String::new();
158        if request
159            .messages
160            .first()
161            .map_or(false, |message| message.role == Role::System)
162        {
163            system_message = request.messages.remove(0).content;
164        }
165
166        Request {
167            model: self.model.clone(),
168            messages: request
169                .messages
170                .iter()
171                .map(|msg| RequestMessage {
172                    role: match msg.role {
173                        Role::User => anthropic::Role::User,
174                        Role::Assistant => anthropic::Role::Assistant,
175                        Role::System => unreachable!("filtered out by preprocess_request"),
176                    },
177                    content: msg.content.clone(),
178                })
179                .collect(),
180            stream: true,
181            system: system_message,
182            max_tokens: 4092,
183        }
184    }
185}
186
187pub fn count_anthropic_tokens(
188    request: LanguageModelRequest,
189    cx: &AppContext,
190) -> BoxFuture<'static, Result<usize>> {
191    cx.background_executor()
192        .spawn(async move {
193            let messages = request
194                .messages
195                .into_iter()
196                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
197                    role: match message.role {
198                        Role::User => "user".into(),
199                        Role::Assistant => "assistant".into(),
200                        Role::System => "system".into(),
201                    },
202                    content: Some(message.content),
203                    name: None,
204                    function_call: None,
205                })
206                .collect::<Vec<_>>();
207
208            // Tiktoken doesn't yet support these models, so we manually use the
209            // same tokenizer as GPT-4.
210            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
211        })
212        .boxed()
213}
214
215impl LanguageModel for AnthropicModel {
216    fn id(&self) -> LanguageModelId {
217        self.id.clone()
218    }
219
220    fn name(&self) -> LanguageModelName {
221        LanguageModelName::from(self.model.display_name().to_string())
222    }
223
224    fn provider_name(&self) -> LanguageModelProviderName {
225        LanguageModelProviderName(PROVIDER_NAME.into())
226    }
227
228    fn telemetry_id(&self) -> String {
229        format!("anthropic/{}", self.model.id())
230    }
231
232    fn max_token_count(&self) -> usize {
233        self.model.max_token_count()
234    }
235
236    fn count_tokens(
237        &self,
238        request: LanguageModelRequest,
239        cx: &AppContext,
240    ) -> BoxFuture<'static, Result<usize>> {
241        count_anthropic_tokens(request, cx)
242    }
243
244    fn stream_completion(
245        &self,
246        request: LanguageModelRequest,
247        cx: &AsyncAppContext,
248    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
249        let request = self.to_anthropic_request(request);
250
251        let http_client = self.http_client.clone();
252        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
253            (
254                state.api_key.clone(),
255                state.settings.api_url.clone(),
256                state.settings.low_speed_timeout,
257            )
258        }) else {
259            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
260        };
261
262        async move {
263            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
264            let request = stream_completion(
265                http_client.as_ref(),
266                &api_url,
267                &api_key,
268                request,
269                low_speed_timeout,
270            );
271            let response = request.await?;
272            let stream = response
273                .filter_map(|response| async move {
274                    match response {
275                        Ok(response) => match response {
276                            anthropic::ResponseEvent::ContentBlockStart {
277                                content_block, ..
278                            } => match content_block {
279                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
280                            },
281                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
282                                match delta {
283                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
284                                }
285                            }
286                            _ => None,
287                        },
288                        Err(error) => Some(Err(error)),
289                    }
290                })
291                .boxed();
292            Ok(stream)
293        }
294        .boxed()
295    }
296}
297
298pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
299    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
300    let mut system_message = String::new();
301
302    for message in request.messages.drain(..) {
303        if message.content.is_empty() {
304            continue;
305        }
306
307        match message.role {
308            Role::User | Role::Assistant => {
309                if let Some(last_message) = new_messages.last_mut() {
310                    if last_message.role == message.role {
311                        last_message.content.push_str("\n\n");
312                        last_message.content.push_str(&message.content);
313                        continue;
314                    }
315                }
316
317                new_messages.push(message);
318            }
319            Role::System => {
320                if !system_message.is_empty() {
321                    system_message.push_str("\n\n");
322                }
323                system_message.push_str(&message.content);
324            }
325        }
326    }
327
328    if !system_message.is_empty() {
329        new_messages.insert(
330            0,
331            LanguageModelRequestMessage {
332                role: Role::System,
333                content: system_message,
334            },
335        );
336    }
337
338    request.messages = new_messages;
339}
340
341struct AuthenticationPrompt {
342    api_key: View<Editor>,
343    state: gpui::Model<State>,
344}
345
346impl AuthenticationPrompt {
347    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
348        Self {
349            api_key: cx.new_view(|cx| {
350                let mut editor = Editor::single_line(cx);
351                editor.set_placeholder_text(
352                    "sk-000000000000000000000000000000000000000000000000",
353                    cx,
354                );
355                editor
356            }),
357            state,
358        }
359    }
360
361    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
362        let api_key = self.api_key.read(cx).text(cx);
363        if api_key.is_empty() {
364            return;
365        }
366
367        let write_credentials = cx.write_credentials(
368            &self.state.read(cx).settings.api_url,
369            "Bearer",
370            api_key.as_bytes(),
371        );
372        let state = self.state.clone();
373        cx.spawn(|_, mut cx| async move {
374            write_credentials.await?;
375
376            state.update(&mut cx, |this, cx| {
377                this.api_key = Some(api_key);
378                cx.notify();
379            })
380        })
381        .detach_and_log_err(cx);
382    }
383
384    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
385        let settings = ThemeSettings::get_global(cx);
386        let text_style = TextStyle {
387            color: cx.theme().colors().text,
388            font_family: settings.ui_font.family.clone(),
389            font_features: settings.ui_font.features.clone(),
390            font_size: rems(0.875).into(),
391            font_weight: settings.ui_font.weight,
392            font_style: FontStyle::Normal,
393            line_height: relative(1.3),
394            background_color: None,
395            underline: None,
396            strikethrough: None,
397            white_space: WhiteSpace::Normal,
398        };
399        EditorElement::new(
400            &self.api_key,
401            EditorStyle {
402                background: cx.theme().colors().editor_background,
403                local_player: cx.theme().players().local(),
404                text: text_style,
405                ..Default::default()
406            },
407        )
408    }
409}
410
411impl Render for AuthenticationPrompt {
412    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
413        const INSTRUCTIONS: [&str; 4] = [
414            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
415            "You can create an API key at: https://console.anthropic.com/settings/keys",
416            "",
417            "Paste your Anthropic API key below and hit enter to use the assistant:",
418        ];
419
420        v_flex()
421            .p_4()
422            .size_full()
423            .on_action(cx.listener(Self::save_api_key))
424            .children(
425                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
426            )
427            .child(
428                h_flex()
429                    .w_full()
430                    .my_2()
431                    .px_2()
432                    .py_1()
433                    .bg(cx.theme().colors().editor_background)
434                    .rounded_md()
435                    .child(self.render_api_key_editor(cx)),
436            )
437            .child(
438                Label::new(
439                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
440                )
441                .size(LabelSize::Small),
442            )
443            .child(
444                h_flex()
445                    .gap_2()
446                    .child(Label::new("Click on").size(LabelSize::Small))
447                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
448                    .child(
449                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
450                    ),
451            )
452            .into_any()
453    }
454}