//! fim.rs — fill-in-the-middle (FIM) edit prediction requests.

  1use crate::{
  2    EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
  3    zeta,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use gpui::{App, AppContext as _, Entity, Task};
  7use language::{
  8    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
  9    language_settings::all_language_settings,
 10};
 11use settings::EditPredictionPromptFormat;
 12use std::{path::Path, sync::Arc, time::Instant};
 13use zeta_prompt::ZetaPromptInput;
 14
/// Token budget for the excerpt taken around the cursor for the FIM prompt.
const FIM_CONTEXT_TOKENS: usize = 512;

/// Output of a completed FIM request, handed from the background task back
/// to the foreground task to build an `EditPredictionResult`.
struct FimRequestOutput {
    // Request identifier returned by the completion server.
    request_id: String,
    // Proposed edits; at most one insertion anchored at the cursor.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    // Buffer snapshot the edits were computed against.
    snapshot: BufferSnapshot,
    // When the model's response arrived (used for latency logging).
    response_received_at: Instant,
    // Prompt inputs, retained for the prediction result.
    inputs: ZetaPromptInput,
    // The live buffer the prediction applies to.
    buffer: Entity<Buffer>,
    // When the snapshot was taken (start of the request timing window).
    buffer_snapshotted_at: Instant,
}
 26
 27pub fn request_prediction(
 28    EditPredictionModelInput {
 29        buffer,
 30        snapshot,
 31        position,
 32        events,
 33        ..
 34    }: EditPredictionModelInput,
 35    prompt_format: EditPredictionPromptFormat,
 36    cx: &mut App,
 37) -> Task<Result<Option<EditPredictionResult>>> {
 38    let settings = &all_language_settings(None, cx).edit_predictions;
 39    let provider = settings.provider;
 40
 41    let full_path: Arc<Path> = snapshot
 42        .file()
 43        .map(|file| file.full_path(cx))
 44        .unwrap_or_else(|| "untitled".into())
 45        .into();
 46
 47    let http_client = cx.http_client();
 48    let cursor_point = position.to_point(&snapshot);
 49    let buffer_snapshotted_at = Instant::now();
 50
 51    let Some(settings) = (match provider {
 52        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
 53        settings::EditPredictionProvider::OpenAiCompatibleApi => {
 54            settings.open_ai_compatible_api.clone()
 55        }
 56        _ => None,
 57    }) else {
 58        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
 59    };
 60
 61    let result = cx.background_spawn(async move {
 62        let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
 63            cursor_point,
 64            &snapshot,
 65            FIM_CONTEXT_TOKENS,
 66            0,
 67        );
 68        let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
 69        let cursor_offset = cursor_point.to_offset(&snapshot);
 70
 71        let inputs = ZetaPromptInput {
 72            events,
 73            related_files: Vec::new(),
 74            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
 75            cursor_path: full_path.clone(),
 76            excerpt_start_row: Some(excerpt_range.start.row),
 77            cursor_excerpt: snapshot
 78                .text_for_range(excerpt_range)
 79                .collect::<String>()
 80                .into(),
 81            excerpt_ranges: Default::default(),
 82            experiment: None,
 83            in_open_source_repo: false,
 84            can_collect_data: false,
 85        };
 86
 87        let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
 88        let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
 89        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
 90        let stop_tokens = get_fim_stop_tokens();
 91
 92        let max_tokens = settings.max_output_tokens;
 93        let (response_text, request_id) = zeta::send_custom_server_request(
 94            provider,
 95            &settings,
 96            prompt,
 97            max_tokens,
 98            stop_tokens,
 99            &http_client,
100        )
101        .await?;
102
103        let response_received_at = Instant::now();
104
105        log::debug!(
106            "fim: completion received ({:.2}s)",
107            (response_received_at - buffer_snapshotted_at).as_secs_f64()
108        );
109
110        let completion: Arc<str> = clean_fim_completion(&response_text).into();
111        let edits = if completion.is_empty() {
112            vec![]
113        } else {
114            let cursor_offset = cursor_point.to_offset(&snapshot);
115            let anchor = snapshot.anchor_after(cursor_offset);
116            vec![(anchor..anchor, completion)]
117        };
118
119        anyhow::Ok(FimRequestOutput {
120            request_id,
121            edits,
122            snapshot,
123            response_received_at,
124            inputs,
125            buffer,
126            buffer_snapshotted_at,
127        })
128    });
129
130    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
131        let output = result.await.context("fim edit prediction failed")?;
132        anyhow::Ok(Some(
133            EditPredictionResult::new(
134                EditPredictionId(output.request_id.into()),
135                &output.buffer,
136                &output.snapshot,
137                output.edits.into(),
138                None,
139                output.buffer_snapshotted_at,
140                output.response_received_at,
141                output.inputs,
142                None,
143                cx,
144            )
145            .await,
146        ))
147    })
148}
149
150fn format_fim_prompt(
151    prompt_format: EditPredictionPromptFormat,
152    prefix: &str,
153    suffix: &str,
154) -> String {
155    match prompt_format {
156        EditPredictionPromptFormat::CodeLlama => {
157            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
158        }
159        EditPredictionPromptFormat::StarCoder => {
160            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
161        }
162        EditPredictionPromptFormat::DeepseekCoder => {
163            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
164        }
165        EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
166            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
167        }
168        EditPredictionPromptFormat::Codestral => {
169            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
170        }
171        EditPredictionPromptFormat::Glm => {
172            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
173        }
174        _ => {
175            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
176        }
177    }
178}
179
/// Stop sequences covering the sentinel tokens of every supported FIM
/// prompt format, so generation halts before emitting template markers.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
198
/// Trims a raw model response at the first occurrence of any FIM sentinel
/// token, returning only the usable completion text that precedes it.
fn clean_fim_completion(response: &str) -> String {
    let sentinels = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // The earliest sentinel position (if any) marks the end of the
    // completion; with none present, the whole response is kept.
    let end = sentinels
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());

    response[..end].to_string()
}