Merge pull request #557 from charmbracelet/change-summarize-model

Kujtim Hoxha created

fix: summarize provider should be large

Change summary

internal/llm/agent/agent.go | 68 +++++++++++++++++++++-----------------
1 file changed, 37 insertions(+), 31 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -159,11 +159,12 @@ func NewAgent(
 	if err != nil {
 		return nil, err
 	}
+
 	summarizeOpts := []provider.ProviderClientOption{
-		provider.WithModel(config.SelectedModelTypeSmall),
-		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+		provider.WithModel(config.SelectedModelTypeLarge),
+		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 	}
-	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
+	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 	if err != nil {
 		return nil, err
 	}
@@ -224,7 +225,7 @@ func NewAgent(
 		sessions:            sessions,
 		titleProvider:       titleProvider,
 		summarizeProvider:   summarizeProvider,
-		summarizeProviderID: string(smallModelProviderCfg.ID),
+		summarizeProviderID: string(providerCfg.ID),
 		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 		tools:               csync.NewLazySlice(toolFn),
 	}, nil
@@ -904,54 +905,59 @@ func (a *agent) UpdateModel() error {
 		a.providerID = string(currentProviderCfg.ID)
 	}
 
-	// Check if small model provider has changed (affects title and summarize providers)
+	// Check if providers have changed for title (small) and summarize (large)
 	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 	var smallModelProviderCfg config.ProviderConfig
-
 	for p := range cfg.Providers.Seq() {
 		if p.ID == smallModelCfg.Provider {
 			smallModelProviderCfg = p
 			break
 		}
 	}
-
 	if smallModelProviderCfg.ID == "" {
 		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 	}
 
-	// Check if summarize provider has changed
-	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
-		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
-		if smallModel == nil {
-			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
+	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
+	var largeModelProviderCfg config.ProviderConfig
+	for p := range cfg.Providers.Seq() {
+		if p.ID == largeModelCfg.Provider {
+			largeModelProviderCfg = p
+			break
 		}
+	}
+	if largeModelProviderCfg.ID == "" {
+		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
+	}
 
-		// Recreate title provider
-		titleOpts := []provider.ProviderClientOption{
-			provider.WithModel(config.SelectedModelTypeSmall),
-			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
-			// We want the title to be short, so we limit the max tokens
-			provider.WithMaxTokens(40),
-		}
-		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
-		if err != nil {
-			return fmt.Errorf("failed to create new title provider: %w", err)
-		}
+	// Recreate title provider
+	titleOpts := []provider.ProviderClientOption{
+		provider.WithModel(config.SelectedModelTypeSmall),
+		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+		provider.WithMaxTokens(40),
+	}
+	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
+	if err != nil {
+		return fmt.Errorf("failed to create new title provider: %w", err)
+	}
+	a.titleProvider = newTitleProvider
 
-		// Recreate summarize provider
+	// Recreate summarize provider if provider changed (now large model)
+	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
+		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
+		if largeModel == nil {
+			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
+		}
 		summarizeOpts := []provider.ProviderClientOption{
-			provider.WithModel(config.SelectedModelTypeSmall),
-			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+			provider.WithModel(config.SelectedModelTypeLarge),
+			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
 		}
-		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
+		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
 		if err != nil {
 			return fmt.Errorf("failed to create new summarize provider: %w", err)
 		}
-
-		// Update the providers and provider ID
-		a.titleProvider = newTitleProvider
 		a.summarizeProvider = newSummarizeProvider
-		a.summarizeProviderID = string(smallModelProviderCfg.ID)
+		a.summarizeProviderID = string(largeModelProviderCfg.ID)
 	}
 
 	return nil