fim.rs

  1use crate::{
  2    EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
  3    zeta,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use gpui::{App, AppContext as _, Entity, Task};
  7use language::{
  8    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
  9    language_settings::all_language_settings,
 10};
 11use settings::EditPredictionPromptFormat;
 12use std::{path::Path, sync::Arc, time::Instant};
 13use zeta_prompt::ZetaPromptInput;
 14
 15const FIM_CONTEXT_TOKENS: usize = 512;
 16
/// Intermediate result of a FIM completion request, produced on the
/// background thread and consumed on the foreground task to build the
/// final `EditPredictionResult`.
struct FimRequestOutput {
    /// Request identifier returned by the completion server.
    request_id: String,
    /// Edits to apply: each anchor range pairs with its replacement text.
    /// For FIM this is at most one empty-range (pure insertion) edit at
    /// the cursor.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// Snapshot of the buffer the prediction was computed against.
    snapshot: BufferSnapshot,
    /// When the model response arrived (used for latency logging).
    response_received_at: Instant,
    /// Prompt inputs used for the request (excerpt, cursor offsets, events).
    inputs: ZetaPromptInput,
    /// The live buffer the edits target.
    buffer: Entity<Buffer>,
    /// When the buffer snapshot was taken, i.e. the start of the request.
    buffer_snapshotted_at: Instant,
}
 26
/// Requests a fill-in-the-middle (FIM) edit prediction at the cursor
/// position, using the provider configured in the user's edit-prediction
/// settings (Ollama or an OpenAI-compatible API).
///
/// The heavy work — excerpt extraction, prompt formatting, and the HTTP
/// round trip — runs on a background task; the result is then converted
/// into an `EditPredictionResult` on the foreground.
///
/// Resolves to `Some(EditPredictionResult)` on success (possibly with an
/// empty edit list if the model returned nothing usable), or an error if
/// the configured provider is unsupported for FIM or the request fails.
pub fn request_prediction(
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        events,
        ..
    }: EditPredictionModelInput,
    prompt_format: EditPredictionPromptFormat,
    cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;

    // Display path for the prompt; unsaved buffers get a placeholder.
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();

    let http_client = cx.http_client();
    let cursor_point = position.to_point(&snapshot);
    // Marks the start of the request, for latency logging below.
    let buffer_snapshotted_at = Instant::now();

    // Only Ollama and OpenAI-compatible providers carry the per-provider
    // settings FIM needs; anything else is rejected up front.
    let Some(settings) = (match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    }) else {
        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
    };

    let result = cx.background_spawn(async move {
        // Take a context window of FIM_CONTEXT_TOKENS around the cursor;
        // the editable-range budget is 0 because FIM only inserts at the
        // cursor rather than rewriting a region.
        let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
            cursor_point,
            &snapshot,
            FIM_CONTEXT_TOKENS,
            0,
        );
        let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);

        let inputs = ZetaPromptInput {
            events,
            related_files: Vec::new(),
            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
            // Empty range at the cursor: the prediction is a pure insertion.
            editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
                ..cursor_offset - excerpt_offset_range.start,
            cursor_path: full_path.clone(),
            excerpt_start_row: Some(excerpt_range.start.row),
            cursor_excerpt: snapshot
                .text_for_range(excerpt_range)
                .collect::<String>()
                .into(),
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        };

        // Split the excerpt into the text before and after the cursor.
        // NOTE(review): these are byte-offset slices — assumes
        // cursor_offset_in_excerpt always lands on a UTF-8 char boundary
        // (plausible since it derives from a Point/offset conversion, but
        // worth confirming; a mid-char offset would panic here).
        let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
        let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
        let stop_tokens = get_fim_stop_tokens();

        let max_tokens = settings.max_output_tokens;
        let (response_text, request_id) = zeta::send_custom_server_request(
            provider,
            &settings,
            prompt,
            max_tokens,
            stop_tokens,
            &http_client,
        )
        .await?;

        let response_received_at = Instant::now();

        log::debug!(
            "fim: completion received ({:.2}s)",
            (response_received_at - buffer_snapshotted_at).as_secs_f64()
        );

        // Strip any leaked FIM sentinel tokens from the model output.
        let completion: Arc<str> = clean_fim_completion(&response_text).into();
        let edits = if completion.is_empty() {
            vec![]
        } else {
            // NOTE(review): recomputes the cursor_offset already computed
            // above; harmless but redundant.
            let cursor_offset = cursor_point.to_offset(&snapshot);
            // anchor_after keeps the insertion point stable if text is
            // inserted at the same position before the edit is applied.
            let anchor = snapshot.anchor_after(cursor_offset);
            vec![(anchor..anchor, completion)]
        };

        anyhow::Ok(FimRequestOutput {
            request_id,
            edits,
            snapshot,
            response_received_at,
            inputs,
            buffer,
            buffer_snapshotted_at,
        })
    });

    // Back on the async app context: package the background output into
    // the shared EditPredictionResult type.
    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
        let output = result.await.context("fim edit prediction failed")?;
        anyhow::Ok(Some(
            EditPredictionResult::new(
                EditPredictionId(output.request_id.into()),
                &output.buffer,
                &output.snapshot,
                output.edits.into(),
                None,
                output.buffer_snapshotted_at,
                output.response_received_at,
                output.inputs,
                None,
                cx,
            )
            .await,
        ))
    })
}
151
152fn format_fim_prompt(
153    prompt_format: EditPredictionPromptFormat,
154    prefix: &str,
155    suffix: &str,
156) -> String {
157    match prompt_format {
158        EditPredictionPromptFormat::CodeLlama => {
159            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
160        }
161        EditPredictionPromptFormat::StarCoder => {
162            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
163        }
164        EditPredictionPromptFormat::DeepseekCoder => {
165            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
166        }
167        EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
168            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
169        }
170        EditPredictionPromptFormat::Codestral => {
171            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
172        }
173        EditPredictionPromptFormat::Glm => {
174            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
175        }
176        _ => {
177            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
178        }
179    }
180}
181
/// Sentinel tokens used by the FIM prompt formats this module supports.
/// Shared by `get_fim_stop_tokens` (sent to the server as stop sequences)
/// and `clean_fim_completion` (stripped from responses), so the two lists
/// cannot drift apart.
///
/// Includes the DeepSeek (`<|fim▁…|>`) and GLM (`<|code_…|>`) sentinels,
/// which `format_fim_prompt` emits but were previously missing here.
const FIM_SENTINEL_TOKENS: &[&str] = &[
    "<|endoftext|>",
    "<|file_separator|>",
    "<|fim_pad|>",
    "<|fim_prefix|>",
    "<|fim_middle|>",
    "<|fim_suffix|>",
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<|fim▁begin|>",
    "<|fim▁hole|>",
    "<|fim▁end|>",
    "<|code_prefix|>",
    "<|code_suffix|>",
    "<|code_middle|>",
    "<PRE>",
    "<SUF>",
    "<MID>",
    "[PREFIX]",
    "[SUFFIX]",
];

/// Stop sequences to send with a FIM request so the model halts when it
/// starts emitting another sentinel token instead of completion text.
fn get_fim_stop_tokens() -> Vec<String> {
    FIM_SENTINEL_TOKENS.iter().map(|t| t.to_string()).collect()
}

/// Truncates `response` at the earliest occurrence of any FIM sentinel
/// token, discarding the token and everything after it. Returns the
/// response unchanged if no sentinel is present.
fn clean_fim_completion(response: &str) -> String {
    // Cut at the minimum first-occurrence position across all tokens —
    // equivalent to the old truncate-per-token loop, in a single pass.
    let cut = FIM_SENTINEL_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}