1use crate::{
2 EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
3 zeta,
4};
5use anyhow::{Context as _, Result, anyhow};
6use gpui::{App, AppContext as _, Entity, Task};
7use language::{
8 Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
9 language_settings::all_language_settings,
10};
11use settings::EditPredictionPromptFormat;
12use std::{path::Path, sync::Arc, time::Instant};
13use zeta_prompt::ZetaPromptInput;
14
/// Token budget for the excerpt extracted around the cursor and used as FIM context.
const FIM_CONTEXT_TOKENS: usize = 512;

/// Everything produced on the background task for one FIM request, handed back
/// to the foreground task to build an `EditPredictionResult`.
struct FimRequestOutput {
    /// Request id returned by the completion server; wrapped into the
    /// `EditPredictionId`.
    request_id: String,
    /// Insertion edits at the cursor position; empty when the cleaned
    /// completion text was empty.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// When the server response arrived.
    response_received_at: Instant,
    /// Prompt inputs, passed through to `EditPredictionResult::new`.
    inputs: ZetaPromptInput,
    /// Buffer the prediction applies to.
    buffer: Entity<Buffer>,
    /// When the buffer snapshot was taken (start of the request).
    buffer_snapshotted_at: Instant,
}
26
27pub fn request_prediction(
28 EditPredictionModelInput {
29 buffer,
30 snapshot,
31 position,
32 events,
33 ..
34 }: EditPredictionModelInput,
35 prompt_format: EditPredictionPromptFormat,
36 cx: &mut App,
37) -> Task<Result<Option<EditPredictionResult>>> {
38 let settings = &all_language_settings(None, cx).edit_predictions;
39 let provider = settings.provider;
40
41 let full_path: Arc<Path> = snapshot
42 .file()
43 .map(|file| file.full_path(cx))
44 .unwrap_or_else(|| "untitled".into())
45 .into();
46
47 let http_client = cx.http_client();
48 let cursor_point = position.to_point(&snapshot);
49 let buffer_snapshotted_at = Instant::now();
50
51 let Some(settings) = (match provider {
52 settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
53 settings::EditPredictionProvider::OpenAiCompatibleApi => {
54 settings.open_ai_compatible_api.clone()
55 }
56 _ => None,
57 }) else {
58 return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
59 };
60
61 let result = cx.background_spawn(async move {
62 let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
63 cursor_point,
64 &snapshot,
65 FIM_CONTEXT_TOKENS,
66 0,
67 );
68 let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
69 let cursor_offset = cursor_point.to_offset(&snapshot);
70
71 let inputs = ZetaPromptInput {
72 events,
73 related_files: Vec::new(),
74 cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
75 cursor_path: full_path.clone(),
76 excerpt_start_row: Some(excerpt_range.start.row),
77 cursor_excerpt: snapshot
78 .text_for_range(excerpt_range)
79 .collect::<String>()
80 .into(),
81 excerpt_ranges: Default::default(),
82 experiment: None,
83 in_open_source_repo: false,
84 can_collect_data: false,
85 };
86
87 let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
88 let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
89 let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
90 let stop_tokens = get_fim_stop_tokens();
91
92 let max_tokens = settings.max_output_tokens;
93 let (response_text, request_id) = zeta::send_custom_server_request(
94 provider,
95 &settings,
96 prompt,
97 max_tokens,
98 stop_tokens,
99 &http_client,
100 )
101 .await?;
102
103 let response_received_at = Instant::now();
104
105 log::debug!(
106 "fim: completion received ({:.2}s)",
107 (response_received_at - buffer_snapshotted_at).as_secs_f64()
108 );
109
110 let completion: Arc<str> = clean_fim_completion(&response_text).into();
111 let edits = if completion.is_empty() {
112 vec![]
113 } else {
114 let cursor_offset = cursor_point.to_offset(&snapshot);
115 let anchor = snapshot.anchor_after(cursor_offset);
116 vec![(anchor..anchor, completion)]
117 };
118
119 anyhow::Ok(FimRequestOutput {
120 request_id,
121 edits,
122 snapshot,
123 response_received_at,
124 inputs,
125 buffer,
126 buffer_snapshotted_at,
127 })
128 });
129
130 cx.spawn(async move |cx: &mut gpui::AsyncApp| {
131 let output = result.await.context("fim edit prediction failed")?;
132 anyhow::Ok(Some(
133 EditPredictionResult::new(
134 EditPredictionId(output.request_id.into()),
135 &output.buffer,
136 &output.snapshot,
137 output.edits.into(),
138 None,
139 output.buffer_snapshotted_at,
140 output.response_received_at,
141 output.inputs,
142 None,
143 cx,
144 )
145 .await,
146 ))
147 })
148}
149
150fn format_fim_prompt(
151 prompt_format: EditPredictionPromptFormat,
152 prefix: &str,
153 suffix: &str,
154) -> String {
155 match prompt_format {
156 EditPredictionPromptFormat::CodeLlama => {
157 format!("<PRE> {prefix} <SUF>{suffix} <MID>")
158 }
159 EditPredictionPromptFormat::StarCoder => {
160 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
161 }
162 EditPredictionPromptFormat::DeepseekCoder => {
163 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
164 }
165 EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
166 format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
167 }
168 EditPredictionPromptFormat::Codestral => {
169 format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
170 }
171 EditPredictionPromptFormat::Glm => {
172 format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
173 }
174 _ => {
175 format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
176 }
177 }
178}
179
/// Stop sequences covering the FIM control tokens of every supported model
/// family, sent with the completion request so generation halts at them.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}
198
/// Truncates a model response at the earliest occurrence of any FIM control
/// token, returning only the generated text before it.
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // Cut at the first control token; keep the whole response when none occur.
    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..cut].to_string()
}