1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
3 zeta,
4};
5use anyhow::{Context as _, Result, anyhow};
6use gpui::{App, AppContext as _, Entity, Task};
7use language::{
8 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
9 language_settings::all_language_settings,
10};
11use settings::EditPredictionPromptFormat;
12use std::{path::Path, sync::Arc, time::Instant};
13use zeta_prompt::ZetaPromptInput;
14
/// Token budget for the excerpt around the cursor that is sent to the model as
/// fill-in-the-middle context (split into prefix + suffix at the cursor).
const FIM_CONTEXT_TOKENS: usize = 512;
16
/// Intermediate result of a FIM request, produced by the background task in
/// [`request_prediction`] and consumed on the foreground task to build the
/// final [`EditPredictionResult`].
struct FimRequestOutput {
    /// Request identifier returned by the completion server.
    request_id: String,
    /// Insertion edits: empty anchor ranges at the cursor paired with the
    /// completion text. Empty when the model returned no usable completion.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// When the server response arrived; paired with `buffer_snapshotted_at`
    /// for latency logging and passed through to the prediction result.
    response_received_at: Instant,
    /// Prompt inputs that produced this prediction.
    inputs: ZetaPromptInput,
    /// Handle to the live buffer the prediction applies to.
    buffer: Entity<Buffer>,
    /// When the buffer snapshot was taken; start of the latency measurement.
    buffer_snapshotted_at: Instant,
}
26
27pub fn request_prediction(
28 EditPredictionModelInput {
29 buffer,
30 snapshot,
31 position,
32 events,
33 ..
34 }: EditPredictionModelInput,
35 prompt_format: EditPredictionPromptFormat,
36 cx: &mut App,
37) -> Task<Result<Option<EditPredictionResult>>> {
38 let settings = &all_language_settings(None, cx).edit_predictions;
39 let provider = settings.provider;
40
41 let full_path: Arc<Path> = snapshot
42 .file()
43 .map(|file| file.full_path(cx))
44 .unwrap_or_else(|| "untitled".into())
45 .into();
46
47 let http_client = cx.http_client();
48 let cursor_point = position.to_point(&snapshot);
49 let buffer_snapshotted_at = Instant::now();
50
51 let Some(settings) = (match provider {
52 settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
53 settings::EditPredictionProvider::OpenAiCompatibleApi => {
54 settings.open_ai_compatible_api.clone()
55 }
56 _ => None,
57 }) else {
58 return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
59 };
60
61 let result = cx.background_spawn(async move {
62 let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
63 cursor_point,
64 &snapshot,
65 FIM_CONTEXT_TOKENS,
66 0,
67 );
68 let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
69 let cursor_offset = cursor_point.to_offset(&snapshot);
70
71 let inputs = ZetaPromptInput {
72 events,
73 related_files: Vec::new(),
74 cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
75 editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
76 ..cursor_offset - excerpt_offset_range.start,
77 cursor_path: full_path.clone(),
78 excerpt_start_row: Some(excerpt_range.start.row),
79 cursor_excerpt: snapshot
80 .text_for_range(excerpt_range)
81 .collect::<String>()
82 .into(),
83 excerpt_ranges: None,
84 preferred_model: None,
85 in_open_source_repo: false,
86 can_collect_data: false,
87 };
88
89 let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
90 let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
91 let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
92 let stop_tokens = get_fim_stop_tokens();
93
94 let max_tokens = settings.max_output_tokens;
95 let (response_text, request_id) = zeta::send_custom_server_request(
96 provider,
97 &settings,
98 prompt,
99 max_tokens,
100 stop_tokens,
101 &http_client,
102 )
103 .await?;
104
105 let response_received_at = Instant::now();
106
107 log::debug!(
108 "fim: completion received ({:.2}s)",
109 (response_received_at - buffer_snapshotted_at).as_secs_f64()
110 );
111
112 let completion: Arc<str> = clean_fim_completion(&response_text).into();
113 let edits = if completion.is_empty() {
114 vec![]
115 } else {
116 let cursor_offset = cursor_point.to_offset(&snapshot);
117 let anchor = snapshot.anchor_after(cursor_offset);
118 vec![(anchor..anchor, completion)]
119 };
120
121 anyhow::Ok(FimRequestOutput {
122 request_id,
123 edits,
124 snapshot,
125 response_received_at,
126 inputs,
127 buffer,
128 buffer_snapshotted_at,
129 })
130 });
131
132 cx.spawn(async move |cx: &mut gpui::AsyncApp| {
133 let output = result.await.context("fim edit prediction failed")?;
134 anyhow::Ok(Some(
135 EditPredictionResult::new(
136 EditPredictionId(output.request_id.into()),
137 &output.buffer,
138 &output.snapshot,
139 output.edits.into(),
140 None,
141 output.buffer_snapshotted_at,
142 output.response_received_at,
143 output.inputs,
144 cx,
145 )
146 .await,
147 ))
148 })
149}
150
151fn format_fim_prompt(
152 prompt_format: EditPredictionPromptFormat,
153 prefix: &str,
154 suffix: &str,
155) -> String {
156 match prompt_format {
157 EditPredictionPromptFormat::CodeLlama => {
158 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
159 }
160 EditPredictionPromptFormat::StarCoder => {
161 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
162 }
163 EditPredictionPromptFormat::DeepseekCoder => {
164 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
165 }
166 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
167 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
168 }
169 EditPredictionPromptFormat::Codestral => {
170 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
171 }
172 EditPredictionPromptFormat::Glm => {
173 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
174 }
175 _ => {
176 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
177 }
178 }
179}
180
/// Returns the stop sequences sent with every FIM request: the sentinel
/// tokens of all supported prompt formats, so generation halts as soon as a
/// model emits one of them.
fn get_fim_stop_tokens() -> Vec<String> {
    const STOP_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];
    STOP_TOKENS.iter().map(|token| token.to_string()).collect()
}
199
/// Truncates a FIM model response at the earliest occurrence of any known
/// end-of-completion sentinel token, returning the text before it.
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // Cut at the earliest sentinel occurrence across all tokens; `find`
    // returns a match start, so the cut is always on a char boundary.
    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}