1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt,
3 open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
4 prediction::EditPredictionResult,
5};
6use anyhow::{Context as _, Result, anyhow};
7use gpui::{App, AppContext as _, Entity, Task};
8use language::{
9 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
10 language_settings::all_language_settings,
11};
12use settings::EditPredictionPromptFormat;
13use std::{path::Path, sync::Arc, time::Instant};
14use zeta_prompt::ZetaPromptInput;
15
/// Token budget for the excerpt built around the cursor when constructing
/// the fill-in-the-middle (FIM) prompt.
const FIM_CONTEXT_TOKENS: usize = 512;

/// Everything produced on the background task for one FIM request, handed
/// back to the foreground task to build an `EditPredictionResult`.
struct FimRequestOutput {
    // Request identifier returned by the completion endpoint.
    request_id: String,
    // Proposed edits: anchor ranges into `snapshot` paired with replacement text.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    // Snapshot of the buffer the prediction was computed against.
    snapshot: BufferSnapshot,
    // When the model response arrived; logged for latency and forwarded to
    // `EditPredictionResult::new`.
    response_received_at: Instant,
    // Prompt inputs that were sent to the model.
    inputs: ZetaPromptInput,
    // The live buffer the prediction applies to.
    buffer: Entity<Buffer>,
    // When `snapshot` was taken, i.e. when the request started.
    buffer_snapshotted_at: Instant,
}
27
28pub fn request_prediction(
29 EditPredictionModelInput {
30 buffer,
31 snapshot,
32 position,
33 events,
34 ..
35 }: EditPredictionModelInput,
36 prompt_format: EditPredictionPromptFormat,
37 cx: &mut App,
38) -> Task<Result<Option<EditPredictionResult>>> {
39 let settings = &all_language_settings(None, cx).edit_predictions;
40 let provider = settings.provider;
41
42 let full_path: Arc<Path> = snapshot
43 .file()
44 .map(|file| file.full_path(cx))
45 .unwrap_or_else(|| "untitled".into())
46 .into();
47
48 let http_client = cx.http_client();
49 let cursor_point = position.to_point(&snapshot);
50 let buffer_snapshotted_at = Instant::now();
51
52 let Some(settings) = (match provider {
53 settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
54 settings::EditPredictionProvider::OpenAiCompatibleApi => {
55 settings.open_ai_compatible_api.clone()
56 }
57 _ => None,
58 }) else {
59 return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
60 };
61
62 let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
63
64 let result = cx.background_spawn(async move {
65 let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
66 cursor_point,
67 &snapshot,
68 FIM_CONTEXT_TOKENS,
69 0,
70 );
71 let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
72 let cursor_offset = cursor_point.to_offset(&snapshot);
73
74 let inputs = ZetaPromptInput {
75 events,
76 related_files: Vec::new(),
77 cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
78 cursor_path: full_path.clone(),
79 excerpt_start_row: Some(excerpt_range.start.row),
80 cursor_excerpt: snapshot
81 .text_for_range(excerpt_range)
82 .collect::<String>()
83 .into(),
84 excerpt_ranges: Default::default(),
85 experiment: None,
86 in_open_source_repo: false,
87 can_collect_data: false,
88 repo_url: None,
89 };
90
91 let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
92 let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
93 let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
94 let stop_tokens = get_fim_stop_tokens();
95
96 let max_tokens = settings.max_output_tokens;
97
98 let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
99 provider,
100 &settings,
101 prompt,
102 max_tokens,
103 stop_tokens,
104 api_key,
105 &http_client,
106 )
107 .await?;
108
109 let response_received_at = Instant::now();
110
111 log::debug!(
112 "fim: completion received ({:.2}s)",
113 (response_received_at - buffer_snapshotted_at).as_secs_f64()
114 );
115
116 let completion: Arc<str> = clean_fim_completion(&response_text).into();
117 let edits = if completion.is_empty() {
118 vec![]
119 } else {
120 let cursor_offset = cursor_point.to_offset(&snapshot);
121 let anchor = snapshot.anchor_after(cursor_offset);
122 vec![(anchor..anchor, completion)]
123 };
124
125 anyhow::Ok(FimRequestOutput {
126 request_id,
127 edits,
128 snapshot,
129 response_received_at,
130 inputs,
131 buffer,
132 buffer_snapshotted_at,
133 })
134 });
135
136 cx.spawn(async move |cx: &mut gpui::AsyncApp| {
137 let output = result.await.context("fim edit prediction failed")?;
138 anyhow::Ok(Some(
139 EditPredictionResult::new(
140 EditPredictionId(output.request_id.into()),
141 &output.buffer,
142 &output.snapshot,
143 output.edits.into(),
144 None,
145 output.buffer_snapshotted_at,
146 output.response_received_at,
147 output.inputs,
148 None,
149 cx,
150 )
151 .await,
152 ))
153 })
154}
155
156fn format_fim_prompt(
157 prompt_format: EditPredictionPromptFormat,
158 prefix: &str,
159 suffix: &str,
160) -> String {
161 match prompt_format {
162 EditPredictionPromptFormat::CodeLlama => {
163 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
164 }
165 EditPredictionPromptFormat::StarCoder => {
166 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
167 }
168 EditPredictionPromptFormat::DeepseekCoder => {
169 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
170 }
171 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
172 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
173 }
174 EditPredictionPromptFormat::Codestral => {
175 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
176 }
177 EditPredictionPromptFormat::Glm => {
178 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
179 }
180 _ => {
181 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
182 }
183 }
184}
185
/// FIM control tokens across all supported prompt formats. Shared by
/// `get_fim_stop_tokens` (sent as stop sequences) and `clean_fim_completion`
/// (stripped from responses) so the two lists cannot drift apart.
const FIM_STOP_SEQUENCES: &[&str] = &[
    "<|endoftext|>",
    "<|file_separator|>",
    "<|fim_pad|>",
    "<|fim_prefix|>",
    "<|fim_middle|>",
    "<|fim_suffix|>",
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<PRE>",
    "<SUF>",
    "<MID>",
    "[PREFIX]",
    "[SUFFIX]",
];

/// Stop sequences to send with a FIM completion request, covering the
/// control tokens of every supported prompt format.
fn get_fim_stop_tokens() -> Vec<String> {
    FIM_STOP_SEQUENCES.iter().map(|s| s.to_string()).collect()
}

/// Truncates `response` at the earliest occurrence of any FIM control token,
/// returning only the usable completion text.
fn clean_fim_completion(response: &str) -> String {
    // A single pass for the minimum match position is equivalent to the
    // previous repeated find-and-truncate loop (which also converged on the
    // earliest token), without rescanning the shrinking string per token.
    let cut = FIM_STOP_SEQUENCES
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}