use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

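/// A model entry configured by the user in settings; these are merged with,
/// and take precedence over, the models reported by the Ollama server.
///
/// A minimal sketch of the corresponding `settings.json` entry (the exact
/// nesting under `language_models.ollama` and the values shown are
/// assumptions for illustration):
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "available_models": [
///         { "name": "llama3.1:latest", "display_name": "Llama 3.1", "max_tokens": 8192 }
///       ]
///     }
///   }
/// }
/// ```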
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length (also known as `num_ctx` or `n_ctx`)
    pub max_tokens: usize,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
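    /// Ollama exposes no real authentication, so a server that is reachable
    /// and reports at least one installed model is treated as "authenticated".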
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

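    /// Fetches the list of locally installed models from the Ollama server
    /// (via `get_models`, which presumably wraps Ollama's `GET /api/tags`
    /// endpoint), filters out embedding models, and caches the sorted result.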
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API.
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings.
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: None,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

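    /// Asks the server to load the model into memory before the first
    /// completion request. `preload_model` most likely sends an empty request
    /// for the model, which is Ollama's documented way of warming one up.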
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
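    /// Converts Zed's provider-agnostic request into Ollama's chat format.
    /// The result roughly corresponds to this `/api/chat` request body
    /// (a sketch; the values shown are illustrative):
    ///
    /// ```json
    /// {
    ///   "model": "llama3.1:latest",
    ///   "messages": [{ "role": "user", "content": "Hello" }],
    ///   "stream": true,
    ///   "options": { "num_ctx": 8192, "temperature": 1.0 }
    /// }
    /// ```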
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
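        // Instead, assume ~4 characters per token as a rough heuristic, so a
        // 400-character conversation is estimated at ~100 tokens.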
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

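    /// Streams a chat completion from the server. Each streamed delta carries
    /// a partial `ChatMessage`; its text content is extracted and surfaced as
    /// a `LanguageModelCompletionEvent::Text` event.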
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

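    /// Forces a tool call by issuing a single non-streaming completion with
    /// exactly one tool attached, then scanning the assistant's `tool_calls`
    /// for an invocation of that tool and yielding its arguments as JSON.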
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}