Detailed changes
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, ViewContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@@ -20,9 +20,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::completion::OPENAI_API_URL;
lazy_static! {
- static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
@@ -87,6 +89,7 @@ impl Embedding {
#[derive(Clone)]
pub struct OpenAIEmbeddings {
+ pub api_key: Option<String>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
+ if self.api_key.is_none() {
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
+ if let Some(api_key) = api_key {
+ self.api_key = Some(api_key);
+ }
+ }
+ }
+ pub fn new(
+ api_key: Option<String>,
+ client: Arc<dyn HttpClient>,
+ executor: Arc<Background>,
+ ) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
+ api_key,
client,
executor,
rate_limit_count_rx,
@@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
fn is_authenticated(&self) -> bool {
- OPENAI_API_KEY.as_ref().is_some()
+ self.api_key.is_some()
}
+
fn max_tokens_per_batch(&self) -> usize {
50000
}
@@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
- let api_key = OPENAI_API_KEY
- .as_ref()
- .ok_or_else(|| anyhow!("no api key"))?;
+ let Some(api_key) = self.api_key.clone() else {
+ return Err(anyhow!("no open ai key provided"));
+ };
let mut request_number = 0;
let mut rate_limiting = false;
@@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
while request_number < MAX_RETRIES {
response = self
.send_request(
- api_key,
+ &api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
@@ -351,33 +351,32 @@ impl View for ProjectSearchView {
SemanticIndexStatus::NotAuthenticated => {
major_text = Cow::Borrowed("Not Authenticated");
show_minor_text = false;
- Some(
- "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
- .to_string(),
- )
+ Some(vec![
+ "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
+ .to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
}
- SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
+ SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
SemanticIndexStatus::Indexing {
remaining_files,
rate_limit_expiry,
} => {
if remaining_files == 0 {
- Some(format!("Indexing..."))
+ Some(vec![format!("Indexing...")])
} else {
if let Some(rate_limit_expiry) = rate_limit_expiry {
let remaining_seconds =
rate_limit_expiry.duration_since(Instant::now());
if remaining_seconds > Duration::from_secs(0) {
- Some(format!(
+ Some(vec![format!(
"Remaining files to index (rate limit resets in {}s): {}",
remaining_seconds.as_secs(),
remaining_files
- ))
+ )])
} else {
- Some(format!("Remaining files to index: {}", remaining_files))
+ Some(vec![format!("Remaining files to index: {}", remaining_files)])
}
} else {
- Some(format!("Remaining files to index: {}", remaining_files))
+ Some(vec![format!("Remaining files to index: {}", remaining_files)])
}
}
}
@@ -394,9 +393,11 @@ impl View for ProjectSearchView {
} else {
match current_mode {
SearchMode::Semantic => {
- let mut minor_text = Vec::new();
+ let mut minor_text: Vec<String> = Vec::new();
minor_text.push("".into());
- minor_text.extend(semantic_status);
+ if let Some(semantic_status) = semantic_status {
+ minor_text.extend(semantic_status);
+ }
if show_minor_text {
minor_text
.push("Simply explain the code you are looking to find.".into());
@@ -7,7 +7,10 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::{
+ completion::OPENAI_API_URL,
+ embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
+};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -55,6 +58,19 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
cx.subscribe_global::<WorkspaceCreated, _>({
move |event, cx| {
let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -88,7 +104,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -1,3 +1,4 @@
+use ai::completion::OPENAI_API_URL;
use ai::embedding::OpenAIEmbeddings;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
@@ -17,6 +18,7 @@ use std::{cmp, env, fs};
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
use util::http::{self};
use util::paths::EMBEDDINGS_DIR;
+use util::ResultExt;
use zed::languages;
#[derive(Deserialize, Clone, Serialize)]
@@ -469,12 +471,26 @@ fn main() {
.join("embeddings_db");
let languages = languages.clone();
+
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
let fs = fs.clone();
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
languages.clone(),
cx.clone(),
)