1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt,
3 open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
4 prediction::EditPredictionResult,
5};
6use anyhow::{Context as _, Result, anyhow};
7use gpui::{App, AppContext as _, Entity, Task};
8use language::{
9 Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
10 language_settings::all_language_settings,
11};
12use settings::EditPredictionPromptFormat;
13use std::{path::Path, sync::Arc, time::Instant};
14use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
15
/// Token budget handed to `compute_editable_and_context_ranges` when sizing
/// the editable window around the cursor (context budget is passed as 0).
const FIM_CONTEXT_TOKENS: usize = 512;
17
/// Intermediate output of a FIM request, produced on the background
/// executor and consumed on the foreground task to build the final
/// `EditPredictionResult`.
struct FimRequestOutput {
    /// Request identifier returned by the completion server.
    request_id: String,
    /// Insertions to apply: each entry pairs an (empty, at-cursor) anchor
    /// range with the completion text to insert there.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// Buffer snapshot the anchors above were computed against.
    snapshot: BufferSnapshot,
    /// Prompt inputs, passed along to `EditPredictionResult::new`.
    inputs: ZetaPromptInput,
    /// The buffer the prediction applies to.
    buffer: Entity<Buffer>,
}
25
/// Requests a fill-in-the-middle (FIM) edit prediction at `position` in
/// `buffer`.
///
/// Builds a prefix/suffix prompt from an excerpt around the cursor, sends it
/// to the configured Ollama or OpenAI-compatible endpoint, and turns the
/// cleaned completion into a single insertion at the cursor. Returns a ready
/// error task when the configured provider has no FIM settings.
pub fn request_prediction(
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        events,
        ..
    }: EditPredictionModelInput,
    prompt_format: EditPredictionPromptFormat,
    cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;

    // Path recorded in the prompt inputs; unsaved buffers fall back to
    // "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();

    let http_client = cx.http_client();
    let cursor_point = position.to_point(&snapshot);
    let request_start = cx.background_executor().now();

    // Only Ollama and the OpenAI-compatible provider support FIM here; any
    // other provider (or one with no settings entry) is rejected up front.
    let Some(settings) = (match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    }) else {
        return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
    };

    // Loaded on the foreground thread (needs &mut App) before moving into the
    // background task.
    let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    let result = cx.background_spawn(async move {
        // Carve out an excerpt around the cursor, then narrow it to the
        // editable range that fits inside the FIM token budget.
        let cursor_offset = cursor_point.to_offset(&snapshot);
        let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
            cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
        let cursor_excerpt: Arc<str> = snapshot
            .text_for_range(excerpt_point_range.clone())
            .collect::<String>()
            .into();
        let syntax_ranges =
            cursor_excerpt::compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);
        let (editable_range, _) = compute_editable_and_context_ranges(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &syntax_ranges,
            FIM_CONTEXT_TOKENS,
            0,
        );

        // Prompt metadata attached to the resulting prediction. FIM sends no
        // related files, diagnostics, or syntax ranges.
        let inputs = ZetaPromptInput {
            events,
            related_files: Some(Vec::new()),
            active_buffer_diagnostics: Vec::new(),
            // NOTE(review): presumably equal to the `cursor_offset_in_excerpt`
            // returned by `compute_cursor_excerpt` above — confirm.
            cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
            cursor_path: full_path.clone(),
            excerpt_start_row: Some(excerpt_point_range.start.row),
            cursor_excerpt,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        // Split the editable text at the cursor into the FIM prefix/suffix.
        let editable_text = &inputs.cursor_excerpt[editable_range.clone()];
        let cursor_in_editable = cursor_offset_in_excerpt.saturating_sub(editable_range.start);
        let prefix = editable_text[..cursor_in_editable].to_string();
        let suffix = editable_text[cursor_in_editable..].to_string();
        let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
        let stop_tokens = get_fim_stop_tokens();

        let max_tokens = settings.max_output_tokens;

        let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
            provider,
            &settings,
            prompt,
            max_tokens,
            stop_tokens,
            api_key,
            &http_client,
        )
        .await?;

        // NOTE(review): `request_start` came from the executor clock while
        // this uses `Instant::now()`; the two may diverge under a test
        // executor — only used for this debug log.
        let response_received_at = Instant::now();

        log::debug!(
            "fim: completion received ({:.2}s)",
            (response_received_at - request_start).as_secs_f64()
        );

        // Strip stop/sentinel tokens; an empty completion yields no edits.
        let completion: Arc<str> = clean_fim_completion(&response_text).into();
        let edits = if completion.is_empty() {
            vec![]
        } else {
            // NOTE(review): recomputes the `cursor_offset` already computed at
            // the top of this task — looks redundant.
            let cursor_offset = cursor_point.to_offset(&snapshot);
            let anchor = snapshot.anchor_after(cursor_offset);
            vec![(anchor..anchor, completion)]
        };

        anyhow::Ok(FimRequestOutput {
            request_id,
            edits,
            snapshot,
            inputs,
            buffer,
        })
    });

    // Back on the foreground: wrap the background output into the final
    // prediction result, stamping total request latency.
    cx.spawn(async move |cx: &mut gpui::AsyncApp| {
        let output = result.await.context("fim edit prediction failed")?;
        anyhow::Ok(Some(
            EditPredictionResult::new(
                EditPredictionId(output.request_id.into()),
                &output.buffer,
                &output.snapshot,
                output.edits.into(),
                None,
                output.inputs,
                None,
                cx.background_executor().now() - request_start,
                cx,
            )
            .await,
        ))
    })
}
159
160fn format_fim_prompt(
161 prompt_format: EditPredictionPromptFormat,
162 prefix: &str,
163 suffix: &str,
164) -> String {
165 match prompt_format {
166 EditPredictionPromptFormat::CodeLlama => {
167 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
168 }
169 EditPredictionPromptFormat::StarCoder => {
170 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
171 }
172 EditPredictionPromptFormat::DeepseekCoder => {
173 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
174 }
175 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
176 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
177 }
178 EditPredictionPromptFormat::Codestral => {
179 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
180 }
181 EditPredictionPromptFormat::Glm => {
182 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
183 }
184 _ => {
185 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
186 }
187 }
188}
189
/// Returns the stop sequences sent with every FIM request, covering the
/// sentinel tokens of all supported prompt formats.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
208
/// Strips model sentinel/stop tokens from a raw FIM completion, returning
/// only the text that precedes the earliest stop marker.
fn clean_fim_completion(response: &str) -> String {
    const STOP_MARKERS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // No two marker occurrences can overlap (each marker starts with `<` or
    // `[` and contains neither character past position 0), so cutting at the
    // earliest match is equivalent to truncating at each match in turn.
    let cut = STOP_MARKERS
        .iter()
        .filter_map(|marker| response.find(marker))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}