1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod search;
8mod slash_command;
9mod streaming_diff;
10
11pub use assistant_panel::AssistantPanel;
12
13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
14use client::{proto, Client};
15use command_palette_hooks::CommandPaletteFilter;
16pub(crate) use completion_provider::*;
17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
18pub(crate) use saved_conversation::*;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{
22 fmt::{self, Display},
23 sync::Arc,
24};
25
// Declares the assistant's actions via gpui's `actions!` macro; the first
// argument (`assistant`) is the namespace and the bracketed list names the
// individual actions.
// NOTE(review): this namespace presumably corresponds to the `"assistant"`
// namespace that `Assistant::set_enabled` shows/hides in the command palette
// filter — confirm against the macro's expansion.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand
    ]
);
43
44#[derive(
45 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
46)]
47struct MessageId(usize);
48
49#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
50#[serde(rename_all = "lowercase")]
51pub enum Role {
52 User,
53 Assistant,
54 System,
55}
56
57impl Role {
58 pub fn cycle(&mut self) {
59 *self = match self {
60 Role::User => Role::Assistant,
61 Role::Assistant => Role::System,
62 Role::System => Role::User,
63 }
64 }
65}
66
67impl Display for Role {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 Role::User => write!(f, "user"),
71 Role::Assistant => write!(f, "assistant"),
72 Role::System => write!(f, "system"),
73 }
74 }
75}
76
77#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
78pub enum LanguageModel {
79 ZedDotDev(ZedDotDevModel),
80 OpenAi(OpenAiModel),
81 Anthropic(AnthropicModel),
82}
83
84impl Default for LanguageModel {
85 fn default() -> Self {
86 LanguageModel::ZedDotDev(ZedDotDevModel::default())
87 }
88}
89
90impl LanguageModel {
91 pub fn telemetry_id(&self) -> String {
92 match self {
93 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
94 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
95 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
96 }
97 }
98
99 pub fn display_name(&self) -> String {
100 match self {
101 LanguageModel::OpenAi(model) => model.display_name().into(),
102 LanguageModel::Anthropic(model) => model.display_name().into(),
103 LanguageModel::ZedDotDev(model) => model.display_name().into(),
104 }
105 }
106
107 pub fn max_token_count(&self) -> usize {
108 match self {
109 LanguageModel::OpenAi(model) => model.max_token_count(),
110 LanguageModel::Anthropic(model) => model.max_token_count(),
111 LanguageModel::ZedDotDev(model) => model.max_token_count(),
112 }
113 }
114
115 pub fn id(&self) -> &str {
116 match self {
117 LanguageModel::OpenAi(model) => model.id(),
118 LanguageModel::Anthropic(model) => model.id(),
119 LanguageModel::ZedDotDev(model) => model.id(),
120 }
121 }
122}
123
124#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
125pub struct LanguageModelRequestMessage {
126 pub role: Role,
127 pub content: String,
128}
129
130impl LanguageModelRequestMessage {
131 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
132 proto::LanguageModelRequestMessage {
133 role: match self.role {
134 Role::User => proto::LanguageModelRole::LanguageModelUser,
135 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
136 Role::System => proto::LanguageModelRole::LanguageModelSystem,
137 } as i32,
138 content: self.content.clone(),
139 tool_calls: Vec::new(),
140 tool_call_id: None,
141 }
142 }
143}
144
145#[derive(Debug, Default, Serialize)]
146pub struct LanguageModelRequest {
147 pub model: LanguageModel,
148 pub messages: Vec<LanguageModelRequestMessage>,
149 pub stop: Vec<String>,
150 pub temperature: f32,
151}
152
153impl LanguageModelRequest {
154 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
155 proto::CompleteWithLanguageModel {
156 model: self.model.id().to_string(),
157 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
158 stop: self.stop.clone(),
159 temperature: self.temperature,
160 tool_choice: None,
161 tools: Vec::new(),
162 }
163 }
164}
165
166#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
167pub struct LanguageModelResponseMessage {
168 pub role: Option<Role>,
169 pub content: Option<String>,
170}
171
172#[derive(Deserialize, Debug)]
173pub struct LanguageModelUsage {
174 pub prompt_tokens: u32,
175 pub completion_tokens: u32,
176 pub total_tokens: u32,
177}
178
179#[derive(Deserialize, Debug)]
180pub struct LanguageModelChoiceDelta {
181 pub index: u32,
182 pub delta: LanguageModelResponseMessage,
183 pub finish_reason: Option<String>,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187struct MessageMetadata {
188 role: Role,
189 status: MessageStatus,
190}
191
192#[derive(Clone, Debug, Serialize, Deserialize)]
193enum MessageStatus {
194 Pending,
195 Done,
196 Error(SharedString),
197}
198
199/// The state pertaining to the Assistant.
200#[derive(Default)]
201struct Assistant {
202 /// Whether the Assistant is enabled.
203 enabled: bool,
204}
205
206impl Global for Assistant {}
207
208impl Assistant {
209 const NAMESPACE: &'static str = "assistant";
210
211 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
212 if self.enabled == enabled {
213 return;
214 }
215
216 self.enabled = enabled;
217
218 if !enabled {
219 CommandPaletteFilter::update_global(cx, |filter, _cx| {
220 filter.hide_namespace(Self::NAMESPACE);
221 });
222
223 return;
224 }
225
226 CommandPaletteFilter::update_global(cx, |filter, _cx| {
227 filter.show_namespace(Self::NAMESPACE);
228 });
229 }
230}
231
232pub fn init(client: Arc<Client>, cx: &mut AppContext) {
233 cx.set_global(Assistant::default());
234 AssistantSettings::register(cx);
235 completion_provider::init(client, cx);
236 assistant_slash_command::init(cx);
237 assistant_panel::init(cx);
238
239 CommandPaletteFilter::update_global(cx, |filter, _cx| {
240 filter.hide_namespace(Assistant::NAMESPACE);
241 });
242 Assistant::update_global(cx, |assistant, cx| {
243 let settings = AssistantSettings::get_global(cx);
244
245 assistant.set_enabled(settings.enabled, cx);
246 });
247 cx.observe_global::<SettingsStore>(|cx| {
248 Assistant::update_global(cx, |assistant, cx| {
249 let settings = AssistantSettings::get_global(cx);
250
251 assistant.set_enabled(settings.enabled, cx);
252 });
253 })
254 .detach();
255}
256
/// Test-only constructor hook: initializes `env_logger`, but only when the
/// `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}