1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt,
3 open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
4 prediction::EditPredictionResult,
5};
6use anyhow::{Context as _, Result, anyhow};
7use gpui::{App, AppContext as _, Entity, Task};
8use language::{
9 Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
10 language_settings::all_language_settings,
11};
12use settings::EditPredictionPromptFormat;
13use std::{path::Path, sync::Arc, time::Instant};
14use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
15
/// Token budget handed to `compute_editable_and_context_ranges` when sizing the
/// editable region around the cursor for FIM requests (context budget is 0).
const FIM_CONTEXT_TOKENS: usize = 512;
17
/// Output of a completed FIM request, carried from the background task back to
/// the foreground task that builds the final `EditPredictionResult`.
struct FimRequestOutput {
    /// Request identifier returned by the completion server.
    request_id: String,
    /// Proposed edits: anchor ranges in `snapshot` paired with replacement text.
    /// For FIM this is either empty or a single zero-width insertion at the cursor.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// Snapshot of the buffer the prediction was computed against.
    snapshot: BufferSnapshot,
    /// Prompt inputs used for the request, forwarded to `EditPredictionResult::new`.
    inputs: ZetaPromptInput,
    /// Handle to the live buffer the edits apply to.
    buffer: Entity<Buffer>,
}
25
26pub fn request_prediction(
27 EditPredictionModelInput {
28 buffer,
29 snapshot,
30 position,
31 events,
32 ..
33 }: EditPredictionModelInput,
34 prompt_format: EditPredictionPromptFormat,
35 cx: &mut App,
36) -> Task<Result<Option<EditPredictionResult>>> {
37 let settings = &all_language_settings(None, cx).edit_predictions;
38 let provider = settings.provider;
39
40 let full_path: Arc<Path> = snapshot
41 .file()
42 .map(|file| file.full_path(cx))
43 .unwrap_or_else(|| "untitled".into())
44 .into();
45
46 let http_client = cx.http_client();
47 let cursor_point = position.to_point(&snapshot);
48 let request_start = cx.background_executor().now();
49
50 let Some(settings) = (match provider {
51 settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
52 settings::EditPredictionProvider::OpenAiCompatibleApi => {
53 settings.open_ai_compatible_api.clone()
54 }
55 _ => None,
56 }) else {
57 return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
58 };
59
60 let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
61
62 let result = cx.background_spawn(async move {
63 let cursor_offset = cursor_point.to_offset(&snapshot);
64 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
65 cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
66 let cursor_excerpt: Arc<str> = snapshot
67 .text_for_range(excerpt_point_range.clone())
68 .collect::<String>()
69 .into();
70 let syntax_ranges =
71 cursor_excerpt::compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);
72 let (editable_range, _) = compute_editable_and_context_ranges(
73 &cursor_excerpt,
74 cursor_offset_in_excerpt,
75 &syntax_ranges,
76 FIM_CONTEXT_TOKENS,
77 0,
78 );
79
80 let inputs = ZetaPromptInput {
81 events,
82 related_files: Some(Vec::new()),
83 active_buffer_diagnostics: Vec::new(),
84 cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
85 cursor_path: full_path.clone(),
86 excerpt_start_row: Some(excerpt_point_range.start.row),
87 cursor_excerpt,
88 excerpt_ranges: Default::default(),
89 syntax_ranges: None,
90 in_open_source_repo: false,
91 can_collect_data: false,
92 repo_url: None,
93 };
94
95 let editable_text = &inputs.cursor_excerpt[editable_range.clone()];
96 let cursor_in_editable = cursor_offset_in_excerpt.saturating_sub(editable_range.start);
97 let prefix = editable_text[..cursor_in_editable].to_string();
98 let suffix = editable_text[cursor_in_editable..].to_string();
99 let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
100 let stop_tokens = get_fim_stop_tokens();
101
102 let max_tokens = settings.max_output_tokens;
103
104 let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
105 provider,
106 &settings,
107 prompt,
108 max_tokens,
109 stop_tokens,
110 api_key,
111 &http_client,
112 )
113 .await?;
114
115 let response_received_at = Instant::now();
116
117 log::debug!(
118 "fim: completion received ({:.2}s)",
119 (response_received_at - request_start).as_secs_f64()
120 );
121
122 let completion: Arc<str> = clean_fim_completion(&response_text).into();
123 let edits = if completion.is_empty() {
124 vec![]
125 } else {
126 let cursor_offset = cursor_point.to_offset(&snapshot);
127 let anchor = snapshot.anchor_after(cursor_offset);
128 vec![(anchor..anchor, completion)]
129 };
130
131 anyhow::Ok(FimRequestOutput {
132 request_id,
133 edits,
134 snapshot,
135 inputs,
136 buffer,
137 })
138 });
139
140 cx.spawn(async move |cx: &mut gpui::AsyncApp| {
141 let output = result.await.context("fim edit prediction failed")?;
142 anyhow::Ok(Some(
143 EditPredictionResult::new(
144 EditPredictionId(output.request_id.into()),
145 &output.buffer,
146 &output.snapshot,
147 output.edits.into(),
148 None,
149 output.inputs,
150 None,
151 cx.background_executor().now() - request_start,
152 cx,
153 )
154 .await,
155 ))
156 })
157}
158
159fn format_fim_prompt(
160 prompt_format: EditPredictionPromptFormat,
161 prefix: &str,
162 suffix: &str,
163) -> String {
164 match prompt_format {
165 EditPredictionPromptFormat::CodeLlama => {
166 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
167 }
168 EditPredictionPromptFormat::StarCoder => {
169 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
170 }
171 EditPredictionPromptFormat::DeepseekCoder => {
172 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
173 }
174 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
175 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
176 }
177 EditPredictionPromptFormat::Codestral => {
178 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
179 }
180 EditPredictionPromptFormat::Glm => {
181 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
182 }
183 _ => {
184 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
185 }
186 }
187}
188
/// Stop sequences sent with FIM completion requests so generation halts as
/// soon as the model starts emitting another FIM control token.
///
/// NOTE(review): the DeepSeek (`<|fim▁…|>`) and GLM (`<|code_…|>`) control
/// tokens are absent from this list — possibly deliberate (some servers cap
/// the number of stop sequences); confirm.
fn get_fim_stop_tokens() -> Vec<String> {
    const STOP_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];
    STOP_TOKENS.iter().map(|token| token.to_string()).collect()
}
207
/// Strips FIM control tokens from a raw model completion.
///
/// Returns `response` truncated at the earliest occurrence of any known FIM
/// control token; the full text is returned when none appear. Compared to the
/// previous version, this also covers the DeepSeek (`<|fim▁…|>`) and GLM
/// (`<|code_…|>`) token families, which `format_fim_prompt` can produce but
/// were previously left in the completion and would have been inserted into
/// the buffer verbatim.
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 20] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<|fim▁begin|>",
        "<|fim▁hole|>",
        "<|fim▁end|>",
        "<|code_prefix|>",
        "<|code_suffix|>",
        "<|code_middle|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // Truncate at the earliest control token, if any.
    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}