Detailed changes
@@ -225,6 +225,8 @@ dependencies = [
"anyhow",
"futures 0.3.28",
"http 0.1.0",
+ "isahc",
+ "schemars",
"serde",
"serde_json",
"tokio",
@@ -332,6 +334,7 @@ dependencies = [
name = "assistant"
version = "0.1.0"
dependencies = [
+ "anthropic",
"anyhow",
"chrono",
"client",
@@ -5,6 +5,10 @@ edition = "2021"
publish = false
license = "AGPL-3.0-or-later"
+[features]
+default = []
+schemars = ["dep:schemars"]
+
[lints]
workspace = true
@@ -15,6 +19,8 @@ path = "src/anthropic.rs"
anyhow.workspace = true
futures.workspace = true
http.workspace = true
+isahc.workspace = true
+schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
@@ -1,17 +1,21 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, sync::Arc};
+use std::{convert::TryFrom, time::Duration};
+pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[default]
- #[serde(rename = "claude-3-opus-20240229")]
+ #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
Claude3Opus,
- #[serde(rename = "claude-3-sonnet-20240229")]
+ #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
Claude3Sonnet,
- #[serde(rename = "claude-3-haiku-20240307")]
+ #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
Claude3Haiku,
}
@@ -28,6 +32,14 @@ impl Model {
}
}
+ pub fn id(&self) -> &'static str {
+ match self {
+ Model::Claude3Opus => "claude-3-opus-20240229",
+ Model::Claude3Sonnet => "claude-3-sonnet-20240229",
+ Model::Claude3Haiku => "claude-3-opus-20240307",
+ }
+ }
+
pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3Opus => "Claude 3 Opus",
@@ -141,20 +153,24 @@ pub enum TextDelta {
}
pub async fn stream_completion(
- client: Arc<dyn HttpClient>,
+ client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
+ low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let uri = format!("{api_url}/v1/messages");
- let request = HttpRequest::builder()
+ let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
- .header("Anthropic-Beta", "messages-2023-12-15")
+ .header("Anthropic-Beta", "tools-2024-04-04")
.header("X-Api-Key", api_key)
- .header("Content-Type", "application/json")
- .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ .header("Content-Type", "application/json");
+ if let Some(low_speed_timeout) = low_speed_timeout {
+ request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+ }
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -11,6 +11,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
+anthropic = { workspace = true, features = ["schemars"] }
chrono.workspace = true
client.workspace = true
collections.workspace = true
@@ -7,7 +7,7 @@ mod saved_conversation;
mod streaming_diff;
pub use assistant_panel::AssistantPanel;
-use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
+use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
@@ -72,6 +72,7 @@ impl Display for Role {
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
OpenAi(OpenAiModel),
+ Anthropic(AnthropicModel),
}
impl Default for LanguageModel {
@@ -84,6 +85,7 @@ impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+ LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
}
}
@@ -91,6 +93,7 @@ impl LanguageModel {
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
+ LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::ZedDotDev(model) => model.display_name().into(),
}
}
@@ -98,6 +101,7 @@ impl LanguageModel {
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
+ LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::ZedDotDev(model) => model.max_token_count(),
}
}
@@ -105,6 +109,7 @@ impl LanguageModel {
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
+ LanguageModel::Anthropic(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
}
}
@@ -800,6 +800,11 @@ impl AssistantPanel {
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
}),
+ LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
+ anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
+ anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
+ anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
+ }),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
@@ -1,5 +1,6 @@
use std::fmt;
+pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use open_ai::Model as OpenAiModel;
use schemars::{
@@ -161,6 +162,15 @@ pub enum AssistantProvider {
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
+ #[serde(rename = "anthropic")]
+ Anthropic {
+ #[serde(default)]
+ default_model: AnthropicModel,
+ #[serde(default = "anthropic_api_url")]
+ api_url: String,
+ #[serde(default)]
+ low_speed_timeout_in_seconds: Option<u64>,
+ },
}
impl Default for AssistantProvider {
@@ -172,7 +182,11 @@ impl Default for AssistantProvider {
}
fn open_ai_url() -> String {
- "https://api.openai.com/v1".into()
+ open_ai::OPEN_AI_API_URL.to_string()
+}
+
+fn anthropic_api_url() -> String {
+ anthropic::ANTHROPIC_API_URL.to_string()
}
#[derive(Default, Debug, Deserialize, Serialize)]
@@ -1,8 +1,10 @@
+mod anthropic;
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;
+pub use anthropic::*;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
@@ -42,6 +44,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
+ AssistantProvider::Anthropic {
+ default_model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
+ default_model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ )),
};
cx.set_global(provider);
@@ -64,13 +77,28 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
);
}
+ (
+ CompletionProvider::Anthropic(provider),
+ AssistantProvider::Anthropic {
+ default_model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ ) => {
+ provider.update(
+ default_model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ );
+ }
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model },
) => {
provider.update(default_model.clone(), settings_version);
}
- (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
+ (_, AssistantProvider::ZedDotDev { default_model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
@@ -79,7 +107,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
));
}
(
- CompletionProvider::ZedDotDev(_),
+ _,
AssistantProvider::OpenAi {
default_model,
api_url,
@@ -94,8 +122,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
));
}
- #[cfg(test)]
- (CompletionProvider::Fake(_), _) => unimplemented!(),
+ (
+ _,
+ AssistantProvider::Anthropic {
+ default_model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ },
+ ) => {
+ *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
+ default_model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ ));
+ }
}
})
})
@@ -104,6 +146,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
+ Anthropic(AnthropicCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
@@ -119,6 +162,7 @@ impl CompletionProvider {
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
+ CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
@@ -128,6 +172,7 @@ impl CompletionProvider {
pub fn is_authenticated(&self) -> bool {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
+ CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
@@ -137,6 +182,7 @@ impl CompletionProvider {
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
+ CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
@@ -146,6 +192,7 @@ impl CompletionProvider {
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
+ CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
@@ -155,6 +202,7 @@ impl CompletionProvider {
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
+ CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
@@ -164,6 +212,9 @@ impl CompletionProvider {
pub fn default_model(&self) -> LanguageModel {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
+ CompletionProvider::Anthropic(provider) => {
+ LanguageModel::Anthropic(provider.default_model())
+ }
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
@@ -179,6 +230,7 @@ impl CompletionProvider {
) -> BoxFuture<'static, Result<usize>> {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
+ CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
@@ -191,6 +243,7 @@ impl CompletionProvider {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
+ CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
@@ -0,0 +1,317 @@
+use crate::count_open_ai_tokens;
+use crate::{
+ assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
+ Role,
+};
+use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
+use anyhow::{anyhow, Result};
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
+use http::HttpClient;
+use settings::Settings;
+use std::time::Duration;
+use std::{env, sync::Arc};
+use theme::ThemeSettings;
+use ui::prelude::*;
+use util::ResultExt;
+
+pub struct AnthropicCompletionProvider {
+ api_key: Option<String>,
+ api_url: String,
+ default_model: AnthropicModel,
+ http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+}
+
+impl AnthropicCompletionProvider {
+ pub fn new(
+ default_model: AnthropicModel,
+ api_url: String,
+ http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) -> Self {
+ Self {
+ api_key: None,
+ api_url,
+ default_model,
+ http_client,
+ low_speed_timeout,
+ settings_version,
+ }
+ }
+
+ pub fn update(
+ &mut self,
+ default_model: AnthropicModel,
+ api_url: String,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) {
+ self.default_model = default_model;
+ self.api_url = api_url;
+ self.low_speed_timeout = low_speed_timeout;
+ self.settings_version = settings_version;
+ }
+
+ pub fn settings_version(&self) -> usize {
+ self.settings_version
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
+ pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ if self.is_authenticated() {
+ Task::ready(Ok(()))
+ } else {
+ let api_url = self.api_url.clone();
+ cx.spawn(|mut cx| async move {
+ let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
+ api_key
+ } else {
+ let (_, api_key) = cx
+ .update(|cx| cx.read_credentials(&api_url))?
+ .await?
+ .ok_or_else(|| anyhow!("credentials not found"))?;
+ String::from_utf8(api_key)?
+ };
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::Anthropic(provider) = provider {
+ provider.api_key = Some(api_key);
+ }
+ })
+ })
+ }
+ }
+
+ pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ let delete_credentials = cx.delete_credentials(&self.api_url);
+ cx.spawn(|mut cx| async move {
+ delete_credentials.await.log_err();
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::Anthropic(provider) = provider {
+ provider.api_key = None;
+ }
+ })
+ })
+ }
+
+ pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
+ .into()
+ }
+
+ pub fn default_model(&self) -> AnthropicModel {
+ self.default_model.clone()
+ }
+
+ pub fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ count_open_ai_tokens(request, cx.background_executor())
+ }
+
+ pub fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = self.to_anthropic_request(request);
+
+ let http_client = self.http_client.clone();
+ let api_key = self.api_key.clone();
+ let api_url = self.api_url.clone();
+ let low_speed_timeout = self.low_speed_timeout;
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let request = stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ low_speed_timeout,
+ );
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(response) => match response {
+ anthropic::ResponseEvent::ContentBlockStart {
+ content_block, ..
+ } => match content_block {
+ anthropic::ContentBlock::Text { text } => Some(Ok(text)),
+ },
+ anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
+ match delta {
+ anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
+ }
+ }
+ _ => None,
+ },
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+
+ fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
+ let model = match request.model {
+ LanguageModel::Anthropic(model) => model,
+ _ => self.default_model(),
+ };
+
+ let mut system_message = String::new();
+ let messages = request
+ .messages
+ .into_iter()
+ .filter_map(|message| {
+ match message.role {
+ Role::User => Some(RequestMessage {
+ role: AnthropicRole::User,
+ content: message.content,
+ }),
+ Role::Assistant => Some(RequestMessage {
+ role: AnthropicRole::Assistant,
+ content: message.content,
+ }),
+ // Anthropic's API breaks system instructions out as a separate field rather
+ // than having a system message role.
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.content);
+
+ None
+ }
+ }
+ })
+ .collect();
+
+ Request {
+ model,
+ messages,
+ stream: true,
+ system: system_message,
+ max_tokens: 4092,
+ }
+ }
+}
+
+struct AuthenticationPrompt {
+ api_key: View<Editor>,
+ api_url: String,
+}
+
+impl AuthenticationPrompt {
+ fn new(api_url: String, cx: &mut WindowContext) -> Self {
+ Self {
+ api_key: cx.new_view(|cx| {
+ let mut editor = Editor::single_line(cx);
+ editor.set_placeholder_text(
+ "sk-000000000000000000000000000000000000000000000000",
+ cx,
+ );
+ editor
+ }),
+ api_url,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ let api_key = self.api_key.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+ let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
+ cx.spawn(|_, mut cx| async move {
+ write_credentials.await?;
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::Anthropic(provider) = provider {
+ provider.api_key = Some(api_key);
+ }
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_size: rems(0.875).into(),
+ font_weight: FontWeight::NORMAL,
+ font_style: FontStyle::Normal,
+ line_height: relative(1.3),
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ white_space: WhiteSpace::Normal,
+ };
+ EditorElement::new(
+ &self.api_key,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+}
+
+impl Render for AuthenticationPrompt {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ const INSTRUCTIONS: [&str; 4] = [
+ "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
+ "You can create an API key at: https://console.anthropic.com/settings/keys",
+ "",
+ "Paste your Anthropic API key below and hit enter to use the assistant:",
+ ];
+
+ v_flex()
+ .p_4()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .children(
+ INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .my_2()
+ .px_2()
+ .py_1()
+ .bg(cx.theme().colors().editor_background)
+ .rounded_md()
+ .child(self.render_api_key_editor(cx)),
+ )
+ .child(
+ Label::new(
+ "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
+ )
+ .size(LabelSize::Small),
+ )
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new("Click on").size(LabelSize::Small))
+ .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
+ .child(
+ Label::new("in the status bar to close this panel.").size(LabelSize::Small),
+ ),
+ )
+ .into_any()
+ }
+}
@@ -151,8 +151,8 @@ impl OpenAiCompletionProvider {
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
- LanguageModel::ZedDotDev(_) => self.default_model(),
LanguageModel::OpenAi(model) => model,
+ _ => self.default_model(),
};
Request {
@@ -205,8 +205,12 @@ pub fn count_open_ai_tokens(
match request.model {
LanguageModel::OpenAi(OpenAiModel::FourOmni)
- | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => {
- // Tiktoken doesn't yet support gpt-4o, so we manually use the
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
+ | LanguageModel::Anthropic(_)
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
+ // Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
}
@@ -78,7 +78,6 @@ impl ZedDotDevCompletionProvider {
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
- LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
@@ -108,6 +107,7 @@ impl ZedDotDevCompletionProvider {
}
.boxed()
}
+ _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
}
}
@@ -4489,8 +4489,8 @@ async fn complete_with_anthropic(
.collect();
let mut stream = anthropic::stream_completion(
- session.http_client.clone(),
- "https://api.anthropic.com",
+ session.http_client.as_ref(),
+ anthropic::ANTHROPIC_API_URL,
&api_key,
anthropic::Request {
model,
@@ -4499,6 +4499,7 @@ async fn complete_with_anthropic(
system: system_message,
max_tokens: 4092,
},
+ None,
)
.await?;