//! Fill-in-the-middle (FIM) edit prediction against OpenAI-compatible
//! completion endpoints.

  1use crate::{
  2    EditPredictionId, EditPredictionModelInput, cursor_excerpt,
  3    open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
  4    prediction::EditPredictionResult,
  5};
  6use anyhow::{Context as _, Result, anyhow};
  7use gpui::{App, AppContext as _, Entity, Task};
  8use language::{
  9    Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
 10    language_settings::all_language_settings,
 11};
 12use settings::EditPredictionPromptFormat;
 13use std::{path::Path, sync::Arc, time::Instant};
 14use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
 15
/// Token budget passed to `compute_editable_and_context_ranges` when sizing
/// the editable window around the cursor for the FIM prompt.
const FIM_CONTEXT_TOKENS: usize = 512;
 17
/// Everything produced on the background task that the foreground task
/// needs to assemble the final `EditPredictionResult`.
struct FimRequestOutput {
    // Request id returned by the completion endpoint; becomes the
    // prediction's `EditPredictionId`.
    request_id: String,
    // Proposed insertions: anchor ranges in `snapshot` paired with the text
    // to insert. Empty when the model returned nothing usable.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    // Snapshot the edit anchors were resolved against.
    snapshot: BufferSnapshot,
    // Prompt inputs, passed through to `EditPredictionResult::new`.
    inputs: ZetaPromptInput,
    // The live buffer the prediction applies to.
    buffer: Entity<Buffer>,
}
 25
/// Produces an edit prediction by sending a fill-in-the-middle (FIM) prompt
/// to an OpenAI-compatible completion endpoint.
///
/// The buffer text around the cursor is split into a prefix/suffix pair,
/// wrapped in the sentinel tokens for `prompt_format`, and sent to the
/// provider configured in settings (Ollama or a custom OpenAI-compatible
/// server). The cleaned completion becomes a single insertion at the
/// cursor. Resolves to `Some(EditPredictionResult)` on success; errors are
/// propagated with context.
pub fn request_prediction(
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        events,
        ..
    }: EditPredictionModelInput,
    prompt_format: EditPredictionPromptFormat,
    cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;

    // Path used in the prompt metadata; unsaved buffers get "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();

    let http_client = cx.http_client();
    let cursor_point = position.to_point(&snapshot);
    let request_start = cx.background_executor().now();

    // Only providers with FIM-capable settings are supported here; anything
    // else fails fast, before any prompt or network work.
    let Some(settings) = (match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    }) else {
        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
    };

    let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Prompt construction and the HTTP round trip happen off the main thread.
    let result = cx.background_spawn(async move {
        let cursor_offset = cursor_point.to_offset(&snapshot);
        // Bounded excerpt of the buffer around the cursor, plus the cursor's
        // offset relative to the excerpt start.
        let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
            cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
        let cursor_excerpt: Arc<str> = snapshot
            .text_for_range(excerpt_point_range.clone())
            .collect::<String>()
            .into();
        let syntax_ranges =
            cursor_excerpt::compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);
        // The trailing `0` presumably disables the separate context budget,
        // since FIM only uses the editable window — TODO confirm arg order.
        let (editable_range, _) = compute_editable_and_context_ranges(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &syntax_ranges,
            FIM_CONTEXT_TOKENS,
            0,
        );

        let inputs = ZetaPromptInput {
            events,
            related_files: Some(Vec::new()),
            active_buffer_diagnostics: Vec::new(),
            // NOTE(review): presumably equals the `cursor_offset_in_excerpt`
            // returned by `compute_cursor_excerpt` above — confirm the two
            // cannot diverge.
            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
            cursor_path: full_path.clone(),
            excerpt_start_row: Some(excerpt_point_range.start.row),
            cursor_excerpt,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        // Split the editable window at the cursor into the FIM prefix/suffix.
        let editable_text = &inputs.cursor_excerpt[editable_range.clone()];
        let cursor_in_editable = cursor_offset_in_excerpt.saturating_sub(editable_range.start);
        let prefix = editable_text[..cursor_in_editable].to_string();
        let suffix = editable_text[cursor_in_editable..].to_string();
        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
        let stop_tokens = get_fim_stop_tokens();

        let max_tokens = settings.max_output_tokens;

        let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
            provider,
            &settings,
            prompt,
            max_tokens,
            stop_tokens,
            api_key,
            &http_client,
        )
        .await?;

        // NOTE(review): `request_start` comes from the executor's clock while
        // this uses `Instant::now()`; under a test executor these may not be
        // the same clock. Log-only today, but worth confirming.
        let response_received_at = Instant::now();

        log::debug!(
            "fim: completion received ({:.2}s)",
            (response_received_at - request_start).as_secs_f64()
        );

        // Strip any sentinel tokens from the response; an empty completion
        // yields no edits rather than an empty insertion.
        let completion: Arc<str> = clean_fim_completion(&response_text).into();
        let edits = if completion.is_empty() {
            vec![]
        } else {
            let cursor_offset = cursor_point.to_offset(&snapshot);
            // Right-biased anchor at the cursor; the completion is inserted
            // rather than replacing any existing text.
            let anchor = snapshot.anchor_after(cursor_offset);
            vec![(anchor..anchor, completion)]
        };

        anyhow::Ok(FimRequestOutput {
            request_id,
            edits,
            snapshot,
            inputs,
            buffer,
        })
    });

    // Back on the app context: wrap the raw output into the shared
    // `EditPredictionResult` type, recording total elapsed time.
    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
        let output = result.await.context("fim edit prediction failed")?;
        anyhow::Ok(Some(
            EditPredictionResult::new(
                EditPredictionId(output.request_id.into()),
                &output.buffer,
                &output.snapshot,
                output.edits.into(),
                None,
                output.inputs,
                None,
                cx.background_executor().now() - request_start,
                cx,
            )
            .await,
        ))
    })
}
159
160fn format_fim_prompt(
161    prompt_format: EditPredictionPromptFormat,
162    prefix: &str,
163    suffix: &str,
164) -> String {
165    match prompt_format {
166        EditPredictionPromptFormat::CodeLlama => {
167            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
168        }
169        EditPredictionPromptFormat::StarCoder => {
170            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
171        }
172        EditPredictionPromptFormat::DeepseekCoder => {
173            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
174        }
175        EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
176            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
177        }
178        EditPredictionPromptFormat::Codestral => {
179            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
180        }
181        EditPredictionPromptFormat::Glm => {
182            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
183        }
184        _ => {
185            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
186        }
187    }
188}
189
/// Returns the union of sentinel tokens used by every FIM prompt format we
/// emit; any of these in a completion marks the end of useful text.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
208
/// Truncates `response` at the earliest occurrence of any FIM sentinel
/// token, returning the completion text that precedes it (the whole input
/// when no sentinel appears).
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // Cut at the smallest match position across all tokens; `find` returns a
    // byte index at a char boundary, so the slice below is always valid.
    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());

    response[..cut].to_string()
}