// fim.rs — fill-in-the-middle (FIM) edit prediction request path.

  1use crate::{
  2    EditPredictionId, EditPredictionModelInput, cursor_excerpt,
  3    open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
  4    prediction::EditPredictionResult,
  5};
  6use anyhow::{Context as _, Result, anyhow};
  7use gpui::{App, AppContext as _, Entity, Task};
  8use language::{
  9    Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
 10    language_settings::all_language_settings,
 11};
 12use settings::EditPredictionPromptFormat;
 13use std::{path::Path, sync::Arc, time::Instant};
 14use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
 15
/// Token budget handed to `compute_editable_and_context_ranges` for the
/// editable window around the cursor; FIM requests use no additional
/// context budget (the second argument is 0 at the call site).
const FIM_CONTEXT_TOKENS: usize = 512;
 17
/// Everything the background FIM request produces, handed back to the
/// foreground task so it can assemble the final `EditPredictionResult`.
struct FimRequestOutput {
    /// Identifier for this completion request, returned by the server call.
    request_id: String,
    /// Proposed edits as (anchor range, replacement text); for FIM this is
    /// at most a single insertion at the cursor.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// The buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// When the model response arrived (used for latency logging/reporting).
    response_received_at: Instant,
    /// The prompt inputs used for the request, retained for the result.
    inputs: ZetaPromptInput,
    /// The live buffer entity the prediction applies to.
    buffer: Entity<Buffer>,
    /// When the snapshot was taken, before the request was dispatched.
    buffer_snapshotted_at: Instant,
}
 27
 28pub fn request_prediction(
 29    EditPredictionModelInput {
 30        buffer,
 31        snapshot,
 32        position,
 33        events,
 34        ..
 35    }: EditPredictionModelInput,
 36    prompt_format: EditPredictionPromptFormat,
 37    cx: &mut App,
 38) -> Task<Result<Option<EditPredictionResult>>> {
 39    let settings = &all_language_settings(None, cx).edit_predictions;
 40    let provider = settings.provider;
 41
 42    let full_path: Arc<Path> = snapshot
 43        .file()
 44        .map(|file| file.full_path(cx))
 45        .unwrap_or_else(|| "untitled".into())
 46        .into();
 47
 48    let http_client = cx.http_client();
 49    let cursor_point = position.to_point(&snapshot);
 50    let buffer_snapshotted_at = Instant::now();
 51
 52    let Some(settings) = (match provider {
 53        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
 54        settings::EditPredictionProvider::OpenAiCompatibleApi => {
 55            settings.open_ai_compatible_api.clone()
 56        }
 57        _ => None,
 58    }) else {
 59        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
 60    };
 61
 62    let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
 63
 64    let result = cx.background_spawn(async move {
 65        let cursor_offset = cursor_point.to_offset(&snapshot);
 66        let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
 67            cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
 68        let cursor_excerpt: Arc<str> = snapshot
 69            .text_for_range(excerpt_point_range.clone())
 70            .collect::<String>()
 71            .into();
 72        let syntax_ranges =
 73            cursor_excerpt::compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);
 74        let (editable_range, _) = compute_editable_and_context_ranges(
 75            &cursor_excerpt,
 76            cursor_offset_in_excerpt,
 77            &syntax_ranges,
 78            FIM_CONTEXT_TOKENS,
 79            0,
 80        );
 81
 82        let inputs = ZetaPromptInput {
 83            events,
 84            related_files: Some(Vec::new()),
 85            active_buffer_diagnostics: Vec::new(),
 86            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
 87            cursor_path: full_path.clone(),
 88            excerpt_start_row: Some(excerpt_point_range.start.row),
 89            cursor_excerpt,
 90            excerpt_ranges: Default::default(),
 91            syntax_ranges: None,
 92            experiment: None,
 93            in_open_source_repo: false,
 94            can_collect_data: false,
 95            repo_url: None,
 96        };
 97
 98        let editable_text = &inputs.cursor_excerpt[editable_range.clone()];
 99        let cursor_in_editable = cursor_offset_in_excerpt.saturating_sub(editable_range.start);
100        let prefix = editable_text[..cursor_in_editable].to_string();
101        let suffix = editable_text[cursor_in_editable..].to_string();
102        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
103        let stop_tokens = get_fim_stop_tokens();
104
105        let max_tokens = settings.max_output_tokens;
106
107        let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
108            provider,
109            &settings,
110            prompt,
111            max_tokens,
112            stop_tokens,
113            api_key,
114            &http_client,
115        )
116        .await?;
117
118        let response_received_at = Instant::now();
119
120        log::debug!(
121            "fim: completion received ({:.2}s)",
122            (response_received_at - buffer_snapshotted_at).as_secs_f64()
123        );
124
125        let completion: Arc<str> = clean_fim_completion(&response_text).into();
126        let edits = if completion.is_empty() {
127            vec![]
128        } else {
129            let cursor_offset = cursor_point.to_offset(&snapshot);
130            let anchor = snapshot.anchor_after(cursor_offset);
131            vec![(anchor..anchor, completion)]
132        };
133
134        anyhow::Ok(FimRequestOutput {
135            request_id,
136            edits,
137            snapshot,
138            response_received_at,
139            inputs,
140            buffer,
141            buffer_snapshotted_at,
142        })
143    });
144
145    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
146        let output = result.await.context("fim edit prediction failed")?;
147        anyhow::Ok(Some(
148            EditPredictionResult::new(
149                EditPredictionId(output.request_id.into()),
150                &output.buffer,
151                &output.snapshot,
152                output.edits.into(),
153                None,
154                output.buffer_snapshotted_at,
155                output.response_received_at,
156                output.inputs,
157                None,
158                cx,
159            )
160            .await,
161        ))
162    })
163}
164
165fn format_fim_prompt(
166    prompt_format: EditPredictionPromptFormat,
167    prefix: &str,
168    suffix: &str,
169) -> String {
170    match prompt_format {
171        EditPredictionPromptFormat::CodeLlama => {
172            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
173        }
174        EditPredictionPromptFormat::StarCoder => {
175            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
176        }
177        EditPredictionPromptFormat::DeepseekCoder => {
178            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
179        }
180        EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
181            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
182        }
183        EditPredictionPromptFormat::Codestral => {
184            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
185        }
186        EditPredictionPromptFormat::Glm => {
187            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
188        }
189        _ => {
190            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
191        }
192    }
193}
194
/// Returns the sentinel/control tokens that should terminate FIM generation,
/// covering every prompt format produced by `format_fim_prompt`.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
213
/// Truncates a raw model response at the earliest occurrence of any known
/// FIM sentinel token, returning only the usable completion text.
///
/// (Repeatedly truncating at each token, as servers sometimes emit several,
/// is equivalent to cutting once at the minimum match position.)
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}