1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod model_selector;
6mod prompt_library;
7mod prompts;
8mod saved_conversation;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::AssistantPanel;
14
15use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
21pub(crate) use model_selector::*;
22use prompt_library::PromptStore;
23pub(crate) use saved_conversation::*;
24use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, SettingsStore};
27use slash_command::{
28 active_command, file_command, project_command, prompt_command, rustdoc_command, search_command,
29 tabs_command,
30};
31use std::{
32 fmt::{self, Display},
33 sync::Arc,
34};
35use util::paths::EMBEDDINGS_DIR;
36
// Declares the assistant's actions through gpui's `actions!` macro. The first argument,
// `assistant`, is the same name as `Assistant::NAMESPACE`, which `Assistant::set_enabled`
// shows or hides in the command palette filter.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
54
/// Identifier for a message, implemented as a newtype around `usize`.
///
/// The derives make ids copyable, orderable, hashable and (de)serializable with serde;
/// the default id wraps `0`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
59
/// The role attached to a message.
///
/// Serde (de)serializes the variants by their lowercase names (`"user"`, `"assistant"`,
/// `"system"`) because of `rename_all = "lowercase"`; the `Display` impl prints the same
/// strings.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// The user role.
    User,
    /// The assistant role.
    Assistant,
    /// The system role.
    System,
}
67
68impl Role {
69 pub fn cycle(&mut self) {
70 *self = match self {
71 Role::User => Role::Assistant,
72 Role::Assistant => Role::System,
73 Role::System => Role::User,
74 }
75 }
76}
77
78impl Display for Role {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 Role::User => write!(f, "user"),
82 Role::Assistant => write!(f, "assistant"),
83 Role::System => write!(f, "system"),
84 }
85 }
86}
87
/// A language model, tagged by the family of models it belongs to.
///
/// Each variant wraps the model type imported from `assistant_settings` for that family.
/// The default is `ZedDotDev` with that family's default model (see the `Default` impl).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model described by [`ZedDotDevModel`].
    ZedDotDev(ZedDotDevModel),
    /// A model described by [`OpenAiModel`].
    OpenAi(OpenAiModel),
    /// A model described by [`AnthropicModel`].
    Anthropic(AnthropicModel),
}
94
95impl Default for LanguageModel {
96 fn default() -> Self {
97 LanguageModel::ZedDotDev(ZedDotDevModel::default())
98 }
99}
100
101impl LanguageModel {
102 pub fn telemetry_id(&self) -> String {
103 match self {
104 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
105 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
106 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
107 }
108 }
109
110 pub fn display_name(&self) -> String {
111 match self {
112 LanguageModel::OpenAi(model) => model.display_name().into(),
113 LanguageModel::Anthropic(model) => model.display_name().into(),
114 LanguageModel::ZedDotDev(model) => model.display_name().into(),
115 }
116 }
117
118 pub fn max_token_count(&self) -> usize {
119 match self {
120 LanguageModel::OpenAi(model) => model.max_token_count(),
121 LanguageModel::Anthropic(model) => model.max_token_count(),
122 LanguageModel::ZedDotDev(model) => model.max_token_count(),
123 }
124 }
125
126 pub fn id(&self) -> &str {
127 match self {
128 LanguageModel::OpenAi(model) => model.id(),
129 LanguageModel::Anthropic(model) => model.id(),
130 LanguageModel::ZedDotDev(model) => model.id(),
131 }
132 }
133}
134
/// A single message in a [`LanguageModelRequest`]: a role plus its text content.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// The role this message is attributed to.
    pub role: Role,
    /// The text of the message.
    pub content: String,
}
140
141impl LanguageModelRequestMessage {
142 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
143 proto::LanguageModelRequestMessage {
144 role: match self.role {
145 Role::User => proto::LanguageModelRole::LanguageModelUser,
146 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
147 Role::System => proto::LanguageModelRole::LanguageModelSystem,
148 } as i32,
149 content: self.content.clone(),
150 tool_calls: Vec::new(),
151 tool_call_id: None,
152 }
153 }
154}
155
/// A request to a language model.
///
/// Only `Serialize` is derived (no `Deserialize`); `to_proto` converts the request into
/// its protobuf form.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// The model the request is addressed to.
    pub model: LanguageModel,
    /// The messages making up the request, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// NOTE(review): presumably stop sequences that end generation — confirm with the providers.
    pub stop: Vec<String>,
    /// NOTE(review): presumably the sampling temperature — confirm valid range with the providers.
    pub temperature: f32,
}
163
164impl LanguageModelRequest {
165 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
166 proto::CompleteWithLanguageModel {
167 model: self.model.id().to_string(),
168 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
169 stop: self.stop.clone(),
170 temperature: self.temperature,
171 tool_choice: None,
172 tools: Vec::new(),
173 }
174 }
175}
176
/// A message received from a language model, used as the `delta` of a
/// [`LanguageModelChoiceDelta`]. Both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// The role of the message, if present.
    pub role: Option<Role>,
    /// The text of the message, if present.
    pub content: Option<String>,
}
182
/// Usage figures deserialized from a language model response.
///
/// NOTE(review): the field names suggest token counts for the prompt, the completion and
/// their total — confirm against the provider's response schema.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
189
/// One choice entry deserialized from a language model response.
///
/// NOTE(review): presumably a chunk of a streamed completion — confirm against the
/// provider's response schema.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// The index of this choice.
    pub index: u32,
    /// The message content carried by this entry.
    pub delta: LanguageModelResponseMessage,
    /// The reason generation finished, if one was reported.
    pub finish_reason: Option<String>,
}
196
/// Metadata stored for a message: its role and its current status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// The role the message is attributed to.
    role: Role,
    /// The status of the message (pending, done, or failed).
    status: MessageStatus,
}
202
/// The status of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// The message is not finished yet.
    Pending,
    /// The message is finished.
    Done,
    /// The message failed; the payload is presumably the error text — confirm at the
    /// sites that construct it.
    Error(SharedString),
}
209
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global by `init`. The derived `Default` starts with the assistant
/// disabled (`enabled == false`).
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled. Changed only through `set_enabled`, which also
    /// shows or hides the assistant namespace in the command palette filter.
    enabled: bool,
}
216
// Marks `Assistant` as gpui global state, so `init` can install it with `set_global` and
// update it through `Assistant::update_global`.
impl Global for Assistant {}
218
219impl Assistant {
220 const NAMESPACE: &'static str = "assistant";
221
222 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
223 if self.enabled == enabled {
224 return;
225 }
226
227 self.enabled = enabled;
228
229 if !enabled {
230 CommandPaletteFilter::update_global(cx, |filter, _cx| {
231 filter.hide_namespace(Self::NAMESPACE);
232 });
233
234 return;
235 }
236
237 CommandPaletteFilter::update_global(cx, |filter, _cx| {
238 filter.show_namespace(Self::NAMESPACE);
239 });
240 }
241}
242
243pub fn init(client: Arc<Client>, cx: &mut AppContext) {
244 cx.set_global(Assistant::default());
245 AssistantSettings::register(cx);
246
247 cx.spawn(|mut cx| {
248 let client = client.clone();
249 async move {
250 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
251 let semantic_index = SemanticIndex::new(
252 EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
253 Arc::new(embedding_provider),
254 &mut cx,
255 )
256 .await?;
257 cx.update(|cx| cx.set_global(semantic_index))
258 }
259 })
260 .detach();
261
262 prompt_library::init(cx);
263 completion_provider::init(client, cx);
264 assistant_slash_command::init(cx);
265 register_slash_commands(cx);
266 assistant_panel::init(cx);
267
268 CommandPaletteFilter::update_global(cx, |filter, _cx| {
269 filter.hide_namespace(Assistant::NAMESPACE);
270 });
271 Assistant::update_global(cx, |assistant, cx| {
272 let settings = AssistantSettings::get_global(cx);
273
274 assistant.set_enabled(settings.enabled, cx);
275 });
276 cx.observe_global::<SettingsStore>(|cx| {
277 Assistant::update_global(cx, |assistant, cx| {
278 let settings = AssistantSettings::get_global(cx);
279 assistant.set_enabled(settings.enabled, cx);
280 });
281 })
282 .detach();
283}
284
285fn register_slash_commands(cx: &mut AppContext) {
286 let slash_command_registry = SlashCommandRegistry::global(cx);
287 slash_command_registry.register_command(file_command::FileSlashCommand, true);
288 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
289 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
290 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
291 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
292 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
293
294 let store = PromptStore::global(cx);
295 cx.background_executor()
296 .spawn(async move {
297 let store = store.await?;
298 slash_command_registry
299 .register_command(prompt_command::PromptSlashCommand::new(store), true);
300 anyhow::Ok(())
301 })
302 .detach_and_log_err(cx);
303}
304
/// Test-only constructor hook (`#[ctor::ctor]`) that initializes `env_logger`, but only
/// when the `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}