From 8be73bf187b8f71b84cd709ff889a6ce54790139 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 30 Jan 2025 22:21:40 -0500 Subject: [PATCH] collab: Remove unused `POST /predict_edits` endpoint from LLM service (#23997) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes the `POST /predict_edits` endpoint from the LLM service, as it has been superseded by the corresponding endpoint running in Cloudflare Workers. All traffic is already being routed to the Cloudflare Workers via the Workers route, so nothing is hitting this endpoint running in the LLM service anymore. You can see the drop off in requests to this endpoint on this graph when the Workers route was added: Screenshot 2025-01-30 at 9 18 04 PM We also don't use the `fireworks` crate anymore in this repo, so it has been removed. Release Notes: - N/A --- Cargo.lock | 12 -- Cargo.toml | 2 - crates/collab/Cargo.toml | 1 - crates/collab/src/llm.rs | 161 +------------------ crates/collab/src/llm/prediction_prompt.md | 13 -- crates/fireworks/Cargo.toml | 19 --- crates/fireworks/LICENSE-GPL | 1 - crates/fireworks/src/fireworks.rs | 173 --------------------- 8 files changed, 2 insertions(+), 380 deletions(-) delete mode 100644 crates/collab/src/llm/prediction_prompt.md delete mode 100644 crates/fireworks/Cargo.toml delete mode 120000 crates/fireworks/LICENSE-GPL delete mode 100644 crates/fireworks/src/fireworks.rs diff --git a/Cargo.lock b/Cargo.lock index e0ccab3a983f25421bd9f416d37b504297889fdf..35dac0833337cf65bb8fdd5486cc25bb4c8e1385 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2718,7 +2718,6 @@ dependencies = [ "envy", "extension", "file_finder", - "fireworks", "fs", "futures 0.3.31", "git", @@ -4657,17 +4656,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "fireworks" -version = "0.1.0" -dependencies = [ - "anyhow", - "futures 0.3.31", - "http_client", - "serde", - "serde_json", -] - [[package]] name = "fixedbitset" version = "0.4.2" diff --git a/Cargo.toml b/Cargo.toml index 412da3b320a5c27974940941f8d229e40da92c1c..18701a21463a3b6eedb55b2628c4b925c334b055 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,6 @@ members = [ "crates/feedback", "crates/file_finder", "crates/file_icons", - "crates/fireworks", "crates/fs", "crates/fsevent", "crates/fuzzy", @@ -240,7 +239,6 @@ feature_flags = { path = "crates/feature_flags" } feedback = { path = "crates/feedback" } file_finder = { path = "crates/file_finder" } file_icons = { path = "crates/file_icons" } -fireworks = { path = "crates/fireworks" } fs = { path = "crates/fs" } fsevent = { path = "crates/fsevent" } fuzzy = { path = "crates/fuzzy" } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 020a86c9f80522d0eefa8a3cc2ece0f4f4352fd2..db293c5173806c804c52bb7d2ddc336801c93853 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -34,7 +34,6 @@ collections.workspace = true dashmap.workspace = true derive_more.workspace = true envy = "0.4.2" -fireworks.workspace = true futures.workspace = true google_ai.workspace = true hex.workspace = true diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 6e0ca40d097a8d6ff9cbca98ed8a3d607846d055..b1ab7586613aa210627ecc47519d7254ff5b4eb5 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -21,15 +21,12 @@ use chrono::{DateTime, Duration, Utc}; use collections::HashMap; use db::TokenUsage; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; -use futures::{FutureExt, Stream, StreamExt as _}; +use futures::{Stream, StreamExt as _}; use reqwest_client::ReqwestClient; use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, }; -use rpc::{ - ListModelsResponse, PredictEditsParams, PredictEditsResponse, - MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, -}; +use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; use serde_json::json; use std::{ pin::Pin, @@ -44,9 +41,6 @@ pub use token::*; const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); -/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`. -const MAX_OUTPUT_TOKENS: u32 = 2048; - pub struct LlmState { pub config: Config, pub executor: Executor, @@ -123,7 +117,6 @@ pub fn routes() -> Router<(), Body> { Router::new() .route("/models", get(list_models)) .route("/completion", post(perform_completion)) - .route("/predict_edits", post(predict_edits)) .layer(middleware::from_fn(validate_api_token)) } @@ -437,156 +430,6 @@ fn normalize_model_name(known_models: Vec, name: String) -> String { } } -async fn predict_edits( - Extension(state): Extension>, - Extension(claims): Extension, - _country_code_header: Option>, - Json(params): Json, -) -> Result { - if !claims.is_staff && !claims.has_predict_edits_feature_flag { - return Err(Error::http( - StatusCode::FORBIDDEN, - "no access to Zed's edit prediction feature".to_string(), - )); - } - - let should_sample = claims.is_staff || params.can_collect_data; - - let api_url = state - .config - .prediction_api_url - .as_ref() - .context("no PREDICTION_API_URL configured on the server")?; - let api_key = state - .config - .prediction_api_key - .as_ref() - .context("no PREDICTION_API_KEY configured on the server")?; - let model = state - .config - .prediction_model - .as_ref() - .context("no PREDICTION_MODEL configured on the server")?; - - let outline_prefix = params - .outline - .as_ref() - .map(|outline| format!("### Outline for current file:\n{}\n", outline)) - .unwrap_or_default(); - - let prompt = include_str!("./llm/prediction_prompt.md") - .replace("", &outline_prefix) - .replace("", ¶ms.input_events) - .replace("", ¶ms.input_excerpt); - - let request_start = std::time::Instant::now(); - let timeout = state - .executor - .sleep(std::time::Duration::from_secs(2)) - .fuse(); - let response = fireworks::complete( - &state.http_client, - api_url, - api_key, - fireworks::CompletionRequest { - model: model.to_string(), - prompt: prompt.clone(), - max_tokens: MAX_OUTPUT_TOKENS, - temperature: 0., - prediction: Some(fireworks::Prediction::Content { - content: params.input_excerpt.clone(), - }), - rewrite_speculation: Some(true), - }, - ) - .fuse(); - futures::pin_mut!(timeout); - futures::pin_mut!(response); - - futures::select! { - _ = timeout => { - state.executor.spawn_detached({ - let kinesis_client = state.kinesis_client.clone(); - let kinesis_stream = state.config.kinesis_stream.clone(); - let model = model.clone(); - async move { - SnowflakeRow::new( - "Fireworks Completion Timeout", - claims.metrics_id, - claims.is_staff, - claims.system_id.clone(), - json!({ - "model": model.to_string(), - "prompt": prompt, - }), - ) - .write(&kinesis_client, &kinesis_stream) - .await - .log_err(); - } - }); - Err(anyhow!("request timed out"))? - }, - response = response => { - let duration = request_start.elapsed(); - - let mut response = response?; - let choice = response - .completion - .choices - .pop() - .context("no output from completion response")?; - - state.executor.spawn_detached({ - let kinesis_client = state.kinesis_client.clone(); - let kinesis_stream = state.config.kinesis_stream.clone(); - let model = model.clone(); - let output = choice.text.clone(); - - async move { - let properties = if should_sample { - json!({ - "model": model.to_string(), - "headers": response.headers, - "usage": response.completion.usage, - "duration": duration.as_secs_f64(), - "prompt": prompt, - "input_excerpt": params.input_excerpt, - "input_events": params.input_events, - "outline": params.outline, - "output": output, - "is_sampled": true, - }) - } else { - json!({ - "model": model.to_string(), - "headers": response.headers, - "usage": response.completion.usage, - "duration": duration.as_secs_f64(), - "is_sampled": false, - }) - }; - - SnowflakeRow::new( - "Fireworks Completion Requested", - claims.metrics_id, - claims.is_staff, - claims.system_id.clone(), - properties, - ) - .write(&kinesis_client, &kinesis_stream) - .await - .log_err(); - } - }); - - Ok(Json(PredictEditsResponse { - output_excerpt: choice.text, - })) - }, - } -} - /// The maximum monthly spending an individual user can reach on the free tier /// before they have to pay. pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10); diff --git a/crates/collab/src/llm/prediction_prompt.md b/crates/collab/src/llm/prediction_prompt.md deleted file mode 100644 index e92e3cc15cab1d5cb977d3b9ede5e0159a15c7c5..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/prediction_prompt.md +++ /dev/null @@ -1,13 +0,0 @@ -## Task -Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - -### Instruction: -You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - -### Events: - - -### Input: - - -### Response: diff --git a/crates/fireworks/Cargo.toml b/crates/fireworks/Cargo.toml deleted file mode 100644 index baf81ae29e101c0d95c821850d5c3924f0d54af8..0000000000000000000000000000000000000000 --- a/crates/fireworks/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "fireworks" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/fireworks.rs" - -[dependencies] -anyhow.workspace = true -futures.workspace = true -http_client.workspace = true -serde.workspace = true -serde_json.workspace = true diff --git a/crates/fireworks/LICENSE-GPL b/crates/fireworks/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/fireworks/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/fireworks/src/fireworks.rs b/crates/fireworks/src/fireworks.rs deleted file mode 100644 index 5772204747f58e6735429f8e01f735627bf0ca44..0000000000000000000000000000000000000000 --- a/crates/fireworks/src/fireworks.rs +++ /dev/null @@ -1,173 +0,0 @@ -use anyhow::{anyhow, Result}; -use futures::AsyncReadExt; -use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest}; -use serde::{Deserialize, Serialize}; - -pub const FIREWORKS_API_URL: &str = "https://api.openai.com/v1"; - -#[derive(Debug, Serialize, Deserialize)] -pub struct CompletionRequest { - pub model: String, - pub prompt: String, - pub max_tokens: u32, - pub temperature: f32, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub prediction: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub rewrite_speculation: Option, -} - -#[derive(Clone, Deserialize, Serialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum Prediction { - Content { content: String }, -} - -#[derive(Debug)] -pub struct Response { - pub completion: CompletionResponse, - pub headers: Headers, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct CompletionResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Usage, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct CompletionChoice { - pub text: String, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Debug, Clone, Default, Serialize)] -pub struct Headers { - pub server_processing_time: Option, - pub request_id: Option, - pub prompt_tokens: Option, - pub speculation_generated_tokens: Option, - pub cached_prompt_tokens: Option, - pub backend_host: Option, - pub num_concurrent_requests: Option, - pub deployment: Option, - pub tokenizer_queue_duration: Option, - pub tokenizer_duration: Option, - pub prefill_queue_duration: Option, - pub prefill_duration: Option, - pub generation_queue_duration: Option, -} - -impl Headers { - pub fn parse(headers: &HeaderMap) -> Self { - Headers { - request_id: headers - .get("x-request-id") - .and_then(|v| v.to_str().ok()) - .map(String::from), - server_processing_time: headers - .get("fireworks-server-processing-time") - .and_then(|v| v.to_str().ok()?.parse().ok()), - prompt_tokens: headers - .get("fireworks-prompt-tokens") - .and_then(|v| v.to_str().ok()?.parse().ok()), - speculation_generated_tokens: headers - .get("fireworks-speculation-generated-tokens") - .and_then(|v| v.to_str().ok()?.parse().ok()), - cached_prompt_tokens: headers - .get("fireworks-cached-prompt-tokens") - .and_then(|v| v.to_str().ok()?.parse().ok()), - backend_host: headers - .get("fireworks-backend-host") - .and_then(|v| v.to_str().ok()) - .map(String::from), - num_concurrent_requests: headers - .get("fireworks-num-concurrent-requests") - .and_then(|v| v.to_str().ok()?.parse().ok()), - deployment: headers - .get("fireworks-deployment") - .and_then(|v| v.to_str().ok()) - .map(String::from), - tokenizer_queue_duration: headers - .get("fireworks-tokenizer-queue-duration") - .and_then(|v| v.to_str().ok()?.parse().ok()), - tokenizer_duration: headers - .get("fireworks-tokenizer-duration") - .and_then(|v| v.to_str().ok()?.parse().ok()), - prefill_queue_duration: headers - .get("fireworks-prefill-queue-duration") - .and_then(|v| v.to_str().ok()?.parse().ok()), - prefill_duration: headers - .get("fireworks-prefill-duration") - .and_then(|v| v.to_str().ok()?.parse().ok()), - generation_queue_duration: headers - .get("fireworks-generation-queue-duration") - .and_then(|v| v.to_str().ok()?.parse().ok()), - } - } -} - -pub async fn complete( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - request: CompletionRequest, -) -> Result { - let uri = format!("{api_url}/completions"); - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)); - - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; - let mut response = client.send(request).await?; - - if response.status().is_success() { - let headers = Headers::parse(response.headers()); - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(Response { - completion: serde_json::from_str(&body)?, - headers, - }) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct FireworksResponse { - error: FireworksError, - } - - #[derive(Deserialize)] - struct FireworksError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to Fireworks API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to Fireworks API: {} {}", - response.status(), - body, - )), - } - } -}