1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt,
3 open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
4 prediction::EditPredictionResult,
5};
6use anyhow::{Context as _, Result, anyhow};
7use gpui::{App, AppContext as _, Entity, Task};
8use language::{
9 Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
10 language_settings::all_language_settings,
11};
12use settings::EditPredictionPromptFormat;
13use std::{path::Path, sync::Arc, time::Instant};
14use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
15
/// Token budget used when sizing the editable region around the cursor for
/// the FIM prompt (passed to `compute_editable_and_context_ranges`).
const FIM_CONTEXT_TOKENS: usize = 512;

/// Output of the background FIM request task, handed back to the foreground
/// task to be converted into an `EditPredictionResult`.
struct FimRequestOutput {
    // Request identifier returned by the completion endpoint.
    request_id: String,
    // Proposed edits: anchor ranges in `snapshot` paired with replacement text.
    // For FIM this is at most one zero-width insertion at the cursor.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    // The snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    // When the model response arrived; used for latency logging/reporting.
    response_received_at: Instant,
    // Prompt inputs, passed through to `EditPredictionResult::new`.
    inputs: ZetaPromptInput,
    // The live buffer the prediction applies to.
    buffer: Entity<Buffer>,
    // When the buffer snapshot was taken, before the request was sent.
    buffer_snapshotted_at: Instant,
}
27
/// Requests a fill-in-the-middle (FIM) edit prediction at `position` from the
/// configured provider (Ollama or an OpenAI-compatible API).
///
/// The prompt is built from a small excerpt around the cursor, split into
/// prefix/suffix at the cursor, and formatted per `prompt_format`. The model's
/// completion is inserted as a single zero-width edit at the cursor.
///
/// Returns a task that resolves to `Ok(Some(result))` on success, or an error
/// if the provider is unsupported for FIM or the request fails.
pub fn request_prediction(
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        events,
        ..
    }: EditPredictionModelInput,
    prompt_format: EditPredictionPromptFormat,
    cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;

    // Display path for the buffer; unsaved buffers are labelled "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();

    let http_client = cx.http_client();
    let cursor_point = position.to_point(&snapshot);
    let buffer_snapshotted_at = Instant::now();

    // Only Ollama and OpenAI-compatible providers are supported on the FIM
    // path; anything else errors out immediately.
    let Some(settings) = (match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    }) else {
        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
    };

    // Needs `cx`, so it must happen here before moving onto the background.
    let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    let result = cx.background_spawn(async move {
        // Locate an excerpt around the cursor and the cursor's position
        // within it.
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
            cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
        let cursor_excerpt: Arc<str> = snapshot
            .text_for_range(excerpt_point_range.clone())
            .collect::<String>()
            .into();
        // Syntax ranges only inform how the editable range is chosen below;
        // they are not sent with the prompt inputs (`syntax_ranges: None`).
        let syntax_ranges =
            cursor_excerpt::compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);
        // Context budget is 0: the whole token budget goes to the editable
        // window that gets split into FIM prefix/suffix.
        let (editable_range, _) = compute_editable_and_context_ranges(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &syntax_ranges,
            FIM_CONTEXT_TOKENS,
            0,
        );

        let inputs = ZetaPromptInput {
            events,
            related_files: Some(Vec::new()),
            active_buffer_diagnostics: Vec::new(),
            // NOTE(review): recomputed as `cursor_offset - excerpt_offset_range.start`
            // rather than reusing `cursor_offset_in_excerpt` above — confirm
            // these are always equal.
            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
            cursor_path: full_path.clone(),
            excerpt_start_row: Some(excerpt_point_range.start.row),
            cursor_excerpt,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        // Split the editable text at the cursor into FIM prefix/suffix.
        let editable_text = &inputs.cursor_excerpt[editable_range.clone()];
        let cursor_in_editable = cursor_offset_in_excerpt.saturating_sub(editable_range.start);
        let prefix = editable_text[..cursor_in_editable].to_string();
        let suffix = editable_text[cursor_in_editable..].to_string();
        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
        let stop_tokens = get_fim_stop_tokens();

        let max_tokens = settings.max_output_tokens;

        let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
            provider,
            &settings,
            prompt,
            max_tokens,
            stop_tokens,
            api_key,
            &http_client,
        )
        .await?;

        let response_received_at = Instant::now();

        log::debug!(
            "fim: completion received ({:.2}s)",
            (response_received_at - buffer_snapshotted_at).as_secs_f64()
        );

        // Strip any FIM sentinel tokens the model echoed back; an empty
        // completion yields no edits at all.
        let completion: Arc<str> = clean_fim_completion(&response_text).into();
        let edits = if completion.is_empty() {
            vec![]
        } else {
            let cursor_offset = cursor_point.to_offset(&snapshot);
            let anchor = snapshot.anchor_after(cursor_offset);
            vec![(anchor..anchor, completion)]
        };

        anyhow::Ok(FimRequestOutput {
            request_id,
            edits,
            snapshot,
            response_received_at,
            inputs,
            buffer,
            buffer_snapshotted_at,
        })
    });

    // Back on the foreground: wrap the background output into the shared
    // `EditPredictionResult` type.
    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
        let output = result.await.context("fim edit prediction failed")?;
        anyhow::Ok(Some(
            EditPredictionResult::new(
                EditPredictionId(output.request_id.into()),
                &output.buffer,
                &output.snapshot,
                output.edits.into(),
                None,
                output.buffer_snapshotted_at,
                output.response_received_at,
                output.inputs,
                None,
                cx,
            )
            .await,
        ))
    })
}
164
165fn format_fim_prompt(
166 prompt_format: EditPredictionPromptFormat,
167 prefix: &str,
168 suffix: &str,
169) -> String {
170 match prompt_format {
171 EditPredictionPromptFormat::CodeLlama => {
172 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
173 }
174 EditPredictionPromptFormat::StarCoder => {
175 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
176 }
177 EditPredictionPromptFormat::DeepseekCoder => {
178 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
179 }
180 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
181 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
182 }
183 EditPredictionPromptFormat::Codestral => {
184 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
185 }
186 EditPredictionPromptFormat::Glm => {
187 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
188 }
189 _ => {
190 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
191 }
192 }
193}
194
/// Stop sequences sent with the FIM request so the model halts when it emits
/// any family's end-of-infill or section sentinel.
///
/// Covers every prompt family produced by `format_fim_prompt`: StarCoder/Qwen
/// style (`<|fim_*|>` / `<fim_*>`), CodeLlama (`<PRE>`/`<SUF>`/`<MID>`/`<EOT>`),
/// DeepSeek (`<|fim▁*|>`), GLM (`<|code_*|>`), and Codestral
/// (`[PREFIX]`/`[SUFFIX]`).
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        // CodeLlama's end-of-infill token.
        "<EOT>",
        "[PREFIX]",
        "[SUFFIX]",
        // DeepSeek-Coder FIM sentinels (note U+2581 "▁", not an underscore).
        "<|fim▁begin|>",
        "<|fim▁hole|>",
        "<|fim▁end|>",
        // GLM FIM sentinels.
        "<|code_prefix|>",
        "<|code_suffix|>",
        "<|code_middle|>",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
213
/// Strips everything from the first FIM sentinel token onward, so that
/// echoed control tokens never end up inserted into the buffer.
///
/// The token list mirrors `get_fim_stop_tokens` and covers every prompt
/// family produced by `format_fim_prompt`, including the CodeLlama `<EOT>`,
/// DeepSeek `<|fim▁*|>`, and GLM `<|code_*|>` sentinels that were previously
/// missing and could leak into the completion.
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 21] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "<EOT>",
        "[PREFIX]",
        "[SUFFIX]",
        // DeepSeek-Coder sentinels (U+2581 "▁", not an underscore).
        "<|fim▁begin|>",
        "<|fim▁hole|>",
        "<|fim▁end|>",
        // GLM sentinels.
        "<|code_prefix|>",
        "<|code_suffix|>",
        "<|code_middle|>",
    ];

    // Truncate at the earliest sentinel occurrence, if any.
    let cut = END_TOKENS
        .iter()
        .copied()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}