use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tools
    pub supports_tools: Option<bool>,
    /// Whether to enable thinking mode
    pub supports_thinking: Option<bool>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

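    /// Queries the Ollama server for installed models (also serving as a connectivity
    /// check), fetches each model's capabilities via `show_model`, and stores the
    /// results in `available_models`.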
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = Arc::clone(&self.http_client);
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let tasks = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = ollama::Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Rate-limit capability fetches
            // since there is an arbitrary number of models available
            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()?;

            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = ollama_models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                    supports_tools: model.supports_tools,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
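    /// Converts a `LanguageModelRequest` into an Ollama `ChatRequest`, mapping each
    /// message role to the corresponding Ollama chat message, preserving assistant
    /// "thinking" content, and forwarding model options (num_ctx, stop, temperature)
    /// and tool definitions.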
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => {
                        let content = msg.string_contents();
                        let thinking = msg.content.into_iter().find_map(|content| match content {
                            MessageContent::Thinking { text, .. } if !text.is_empty() => Some(text),
                            _ => None,
                        });
                        ChatMessage::Assistant {
                            content,
                            tool_calls: None,
                            thinking,
                        }
                    }
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            think: self.model.supports_thinking,
            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools.unwrap_or(false)
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => false,
            LanguageModelToolChoice::Any => false,
            LanguageModelToolChoice::None => false,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = map_to_language_model_completion_events(stream);
            Ok(stream)
        });

        future.map_ok(|f| f.boxed()).boxed()
    }
}

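/// Translates the raw Ollama response stream into `LanguageModelCompletionEvent`s,
/// emitting `Thinking`, `Text`, and `ToolUse` events as deltas arrive and a final
/// `Stop` event (with `ToolUse` or `EndTurn` as the reason) when the response is done.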
fn map_to_language_model_completion_events(
    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // Used for creating unique tool use ids
    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    struct State {
        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
        used_tools: bool,
    }

    // We need to create a ToolUse and Stop event from a single
    // response from the original stream
    let stream = stream::unfold(
        State {
            stream,
            used_tools: false,
        },
        async move |mut state| {
            let response = state.stream.next().await?;

            let delta = match response {
                Ok(delta) => delta,
                Err(e) => {
                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
                    return Some((vec![event], state));
                }
            };

            let mut events = Vec::new();

            match delta.message {
                ChatMessage::User { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::System { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Assistant {
                    content,
                    tool_calls,
                    thinking,
                } => {
                    if let Some(text) = thinking {
                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
                            text,
                            signature: None,
                        }));
                    }

                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
                        match tool_call {
                            OllamaToolCall::Function(function) => {
                                let tool_id = format!(
                                    "{}-{}",
                                    &function.name,
                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
                                );
                                let event =
                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                        id: LanguageModelToolUseId::from(tool_id),
                                        name: Arc::from(function.name),
                                        raw_input: function.arguments.to_string(),
                                        input: function.arguments,
                                        is_input_complete: true,
                                    });
                                events.push(Ok(event));
                                state.used_tools = true;
                            }
                        }
                    } else if !content.is_empty() {
                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                    }
                }
            };

            if delta.done {
                if state.used_tools {
                    state.used_tools = false;
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                } else {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }

            Some((events, state))
        },
    );

    stream.flat_map(futures::stream::iter)
}

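/// Configuration UI for the Ollama provider: shows setup instructions, links to
/// the Ollama site and model library, and the current connection status.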
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        })
                )
                .into_any()
        }
    }
}

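/// Converts a `LanguageModelRequestTool` into Ollama's function-tool representation.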
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
    ollama::OllamaTool::Function {
        function: OllamaFunctionTool {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        },
    }
}