1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12mod terminal_inline_assistant;
13
14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20pub(crate) use context_store::*;
21use fs::Fs;
22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
23pub(crate) use inline_assistant::*;
24pub(crate) use model_selector::*;
25use rustdoc::RustdocStore;
26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
27use serde::{Deserialize, Serialize};
28use settings::{Settings, SettingsStore};
29use slash_command::{
30 active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
31 project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
32};
33use std::{
34 fmt::{self, Display},
35 sync::Arc,
36};
37pub(crate) use streaming_diff::*;
38
// Declares the assistant's actions under the `assistant` namespace. This matches
// `Assistant::NAMESPACE`, which `Assistant::set_enabled` shows or hides in the
// command palette filter (presumably the same namespace; both are "assistant").
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
56
57#[derive(
58 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
59)]
60struct MessageId(usize);
61
62#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
63#[serde(rename_all = "lowercase")]
64pub enum Role {
65 User,
66 Assistant,
67 System,
68}
69
70impl Role {
71 pub fn cycle(&mut self) {
72 *self = match self {
73 Role::User => Role::Assistant,
74 Role::Assistant => Role::System,
75 Role::System => Role::User,
76 }
77 }
78}
79
80impl Display for Role {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Role::User => write!(f, "user"),
84 Role::Assistant => write!(f, "assistant"),
85 Role::System => write!(f, "system"),
86 }
87 }
88}
89
90#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
91pub enum LanguageModel {
92 Cloud(CloudModel),
93 OpenAi(OpenAiModel),
94 Anthropic(AnthropicModel),
95 Ollama(OllamaModel),
96}
97
98impl Default for LanguageModel {
99 fn default() -> Self {
100 LanguageModel::Cloud(CloudModel::default())
101 }
102}
103
104impl LanguageModel {
105 pub fn telemetry_id(&self) -> String {
106 match self {
107 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
108 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
109 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
110 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
111 }
112 }
113
114 pub fn display_name(&self) -> String {
115 match self {
116 LanguageModel::OpenAi(model) => model.display_name().into(),
117 LanguageModel::Anthropic(model) => model.display_name().into(),
118 LanguageModel::Cloud(model) => model.display_name().into(),
119 LanguageModel::Ollama(model) => model.display_name().into(),
120 }
121 }
122
123 pub fn max_token_count(&self) -> usize {
124 match self {
125 LanguageModel::OpenAi(model) => model.max_token_count(),
126 LanguageModel::Anthropic(model) => model.max_token_count(),
127 LanguageModel::Cloud(model) => model.max_token_count(),
128 LanguageModel::Ollama(model) => model.max_token_count(),
129 }
130 }
131
132 pub fn id(&self) -> &str {
133 match self {
134 LanguageModel::OpenAi(model) => model.id(),
135 LanguageModel::Anthropic(model) => model.id(),
136 LanguageModel::Cloud(model) => model.id(),
137 LanguageModel::Ollama(model) => model.id(),
138 }
139 }
140}
141
142#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
143pub struct LanguageModelRequestMessage {
144 pub role: Role,
145 pub content: String,
146}
147
148impl LanguageModelRequestMessage {
149 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
150 proto::LanguageModelRequestMessage {
151 role: match self.role {
152 Role::User => proto::LanguageModelRole::LanguageModelUser,
153 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
154 Role::System => proto::LanguageModelRole::LanguageModelSystem,
155 } as i32,
156 content: self.content.clone(),
157 tool_calls: Vec::new(),
158 tool_call_id: None,
159 }
160 }
161}
162
163#[derive(Debug, Default, Serialize)]
164pub struct LanguageModelRequest {
165 pub model: LanguageModel,
166 pub messages: Vec<LanguageModelRequestMessage>,
167 pub stop: Vec<String>,
168 pub temperature: f32,
169}
170
171impl LanguageModelRequest {
172 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
173 proto::CompleteWithLanguageModel {
174 model: self.model.id().to_string(),
175 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
176 stop: self.stop.clone(),
177 temperature: self.temperature,
178 tool_choice: None,
179 tools: Vec::new(),
180 }
181 }
182
183 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
184 pub fn preprocess(&mut self) {
185 match &self.model {
186 LanguageModel::OpenAi(_) => {}
187 LanguageModel::Anthropic(_) => {}
188 LanguageModel::Ollama(_) => {}
189 LanguageModel::Cloud(model) => match model {
190 CloudModel::Claude3Opus
191 | CloudModel::Claude3Sonnet
192 | CloudModel::Claude3Haiku
193 | CloudModel::Claude3_5Sonnet => {
194 preprocess_anthropic_request(self);
195 }
196 _ => {}
197 },
198 }
199 }
200}
201
202#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
203pub struct LanguageModelResponseMessage {
204 pub role: Option<Role>,
205 pub content: Option<String>,
206}
207
208#[derive(Deserialize, Debug)]
209pub struct LanguageModelUsage {
210 pub prompt_tokens: u32,
211 pub completion_tokens: u32,
212 pub total_tokens: u32,
213}
214
215#[derive(Deserialize, Debug)]
216pub struct LanguageModelChoiceDelta {
217 pub index: u32,
218 pub delta: LanguageModelResponseMessage,
219 pub finish_reason: Option<String>,
220}
221
222#[derive(Clone, Debug, Serialize, Deserialize)]
223struct MessageMetadata {
224 role: Role,
225 status: MessageStatus,
226}
227
228#[derive(Clone, Debug, Serialize, Deserialize)]
229enum MessageStatus {
230 Pending,
231 Done,
232 Error(SharedString),
233}
234
235/// The state pertaining to the Assistant.
236#[derive(Default)]
237struct Assistant {
238 /// Whether the Assistant is enabled.
239 enabled: bool,
240}
241
242impl Global for Assistant {}
243
244impl Assistant {
245 const NAMESPACE: &'static str = "assistant";
246
247 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
248 if self.enabled == enabled {
249 return;
250 }
251
252 self.enabled = enabled;
253
254 if !enabled {
255 CommandPaletteFilter::update_global(cx, |filter, _cx| {
256 filter.hide_namespace(Self::NAMESPACE);
257 });
258
259 return;
260 }
261
262 CommandPaletteFilter::update_global(cx, |filter, _cx| {
263 filter.show_namespace(Self::NAMESPACE);
264 });
265 }
266}
267
268pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
269 cx.set_global(Assistant::default());
270 AssistantSettings::register(cx);
271
272 cx.spawn(|mut cx| {
273 let client = client.clone();
274 async move {
275 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
276 let semantic_index = SemanticIndex::new(
277 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
278 Arc::new(embedding_provider),
279 &mut cx,
280 )
281 .await?;
282 cx.update(|cx| cx.set_global(semantic_index))
283 }
284 })
285 .detach();
286
287 prompt_library::init(cx);
288 completion_provider::init(client.clone(), cx);
289 assistant_slash_command::init(cx);
290 register_slash_commands(cx);
291 assistant_panel::init(cx);
292 inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
293 terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
294 RustdocStore::init_global(cx);
295
296 CommandPaletteFilter::update_global(cx, |filter, _cx| {
297 filter.hide_namespace(Assistant::NAMESPACE);
298 });
299 Assistant::update_global(cx, |assistant, cx| {
300 let settings = AssistantSettings::get_global(cx);
301
302 assistant.set_enabled(settings.enabled, cx);
303 });
304 cx.observe_global::<SettingsStore>(|cx| {
305 Assistant::update_global(cx, |assistant, cx| {
306 let settings = AssistantSettings::get_global(cx);
307 assistant.set_enabled(settings.enabled, cx);
308 });
309 })
310 .detach();
311}
312
313fn register_slash_commands(cx: &mut AppContext) {
314 let slash_command_registry = SlashCommandRegistry::global(cx);
315 slash_command_registry.register_command(file_command::FileSlashCommand, true);
316 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
317 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
318 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
319 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
320 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
321 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
322 slash_command_registry.register_command(term_command::TermSlashCommand, true);
323 slash_command_registry.register_command(now_command::NowSlashCommand, true);
324 slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
325 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
326 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
327}
328
/// Formats a token count compactly for display.
///
/// - Below 1,000 the exact number is shown (`"999"`).
/// - From 1,000 to 9,999 the count is rounded to the nearest hundred and shown in
///   thousands with one decimal place, dropping a trailing `.0`
///   (`"1k"`, `"1.5k"`, `"10k"`).
/// - From 10,000 upward the count is rounded to the nearest thousand (`"128k"`).
///
/// In both rounded ranges, halves round up. The function never panics, even for
/// `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; can reach 10 (e.g. 1950 -> 2k).
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Round half up to the nearest thousand without computing `count + 500`,
            // which overflows for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
346
/// Test-only constructor that installs `env_logger` before tests run, but only when
/// `RUST_LOG` is set (and is valid Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}