add bedrock support

Kujtim Hoxha created

Change summary

go.mod                                 | 14 ++++
go.sum                                 | 28 +++++++++
internal/config/config.go              |  7 ++
internal/llm/agent/agent.go            | 23 +++++++
internal/llm/models/models.go          | 16 +++++
internal/llm/provider/anthropic.go     | 33 +++++++++-
internal/llm/provider/bedrock.go       | 87 ++++++++++++++++++++++++++++
internal/tui/components/repl/editor.go | 16 ++++
8 files changed, 217 insertions(+), 7 deletions(-)

Detailed changes

go.mod 🔗

@@ -47,6 +47,20 @@ require (
 	github.com/alecthomas/chroma/v2 v2.15.0 // indirect
 	github.com/andybalholm/cascadia v1.3.2 // indirect
 	github.com/atotto/clipboard v0.1.4 // indirect
+	github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
+	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
+	github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect
+	github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
+	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect
+	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect
+	github.com/aws/smithy-go v1.20.3 // indirect
 	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
 	github.com/aymerick/douceur v0.2.0 // indirect
 	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect

go.sum 🔗

@@ -28,6 +28,34 @@ github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60M
 github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c=
 github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
 github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
+github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
+github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM=
+github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90=
+github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ=
+github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE=
+github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
 github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
 github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
 github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=

internal/config/config.go 🔗

@@ -36,6 +36,11 @@ type Model struct {
 	// TODO: Maybe support multiple models for different purposes
 }
 
+type AnthropicConfig struct {
+	DisableCache bool `json:"disableCache"`
+	UseBedrock   bool `json:"useBedrock"`
+}
+
 type Provider struct {
 	APIKey  string `json:"apiKey"`
 	Enabled bool   `json:"enabled"`
@@ -130,6 +135,8 @@ func Load(debug bool) error {
 			defaultModelSet = true
 		}
 	}
+
+	viper.SetDefault("providers.bedrock.enabled", true)
 	// TODO: add more providers
 	cfg = &Config{}
 

internal/llm/agent/agent.go 🔗

@@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			return nil, nil, err
 		}
 
+	case models.ProviderBedrock:
+		var err error
+		agentProvider, err = provider.NewBedrockProvider(
+			provider.WithBedrockSystemMessage(
+				prompt.CoderAnthropicSystemPrompt(),
+			),
+			provider.WithBedrockMaxTokens(maxTokens),
+			provider.WithBedrockModel(model),
+		)
+		if err != nil {
+			return nil, nil, err
+		}
+		titleGenerator, err = provider.NewBedrockProvider(
+			provider.WithBedrockSystemMessage(
+				prompt.TitlePrompt(),
+			),
+			provider.WithBedrockMaxTokens(maxTokens),
+			provider.WithBedrockModel(model),
+		)
+		if err != nil {
+			return nil, nil, err
+		}
+
 	}
 
 	return agentProvider, titleGenerator, nil

internal/llm/models/models.go 🔗

@@ -31,11 +31,15 @@ const (
 
 	// GROQ
 	QWENQwq ModelID = "qwen-qwq"
+
+	// Bedrock
+	BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
 )
 
 const (
 	ProviderOpenAI    ModelProvider = "openai"
 	ProviderAnthropic ModelProvider = "anthropic"
+	ProviderBedrock   ModelProvider = "bedrock"
 	ProviderGemini    ModelProvider = "gemini"
 	ProviderGROQ      ModelProvider = "groq"
 )
@@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{
 		CostPer1MOutCached: 0,
 		CostPer1MOut:       0,
 	},
+
+	// Bedrock
+	BedrockClaude37Sonnet: {
+		ID:                 BedrockClaude37Sonnet,
+		Name:               "Bedrock: Claude 3.7 Sonnet",
+		Provider:           ProviderBedrock,
+		APIModel:           "anthropic.claude-3-7-sonnet-20250219-v1:0",
+		CostPer1MIn:        3.0,
+		CostPer1MInCached:  3.75,
+		CostPer1MOutCached: 0.30,
+		CostPer1MOut:       15.0,
+	},
 }

internal/llm/provider/anthropic.go 🔗

@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/anthropics/anthropic-sdk-go/bedrock"
 	"github.com/anthropics/anthropic-sdk-go/option"
 	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
@@ -21,6 +22,8 @@ type anthropicProvider struct {
 	maxTokens     int64
 	apiKey        string
 	systemMessage string
+	useBedrock    bool
+	disableCache  bool
 }
 
 type AnthropicOption func(*anthropicProvider)
@@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption {
 	}
 }
 
+func WithAnthropicBedrock() AnthropicOption {
+	return func(a *anthropicProvider) {
+		a.useBedrock = true
+	}
+}
+
+func WithAnthropicDisableCache() AnthropicOption {
+	return func(a *anthropicProvider) {
+		a.disableCache = true
+	}
+}
+
 func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 	provider := &anthropicProvider{
 		maxTokens: 1024,
@@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 		return nil, errors.New("system message is required")
 	}
 
-	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
+	anthropicOptions := []option.RequestOption{}
+
+	if provider.apiKey != "" {
+		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
+	}
+	if provider.useBedrock {
+		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
+	}
+
+	provider.client = anthropic.NewClient(anthropicOptions...)
 	return provider, nil
 }
 
@@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an
 			},
 		}
 
-		if i == len(tools)-1 {
+		if i == len(tools)-1 && !a.disableCache {
 			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
 				Type: "ephemeral",
 			}
@@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 		switch msg.Role {
 		case message.User:
 			content := anthropic.NewTextBlock(msg.Content().String())
-			if cachedBlocks < 2 {
+			if cachedBlocks < 2 && !a.disableCache {
 				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
 					Type: "ephemeral",
 				}
@@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 			blocks := []anthropic.ContentBlockParamUnion{}
 			if msg.Content().String() != "" {
 				content := anthropic.NewTextBlock(msg.Content().String())
-				if cachedBlocks < 2 {
+				if cachedBlocks < 2 && !a.disableCache {
 					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
 						Type: "ephemeral",
 					}
@@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 
 	return anthropicMessages
 }
-

internal/llm/provider/bedrock.go 🔗

@@ -0,0 +1,87 @@
+package provider
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"strings"
+
+	"github.com/kujtimiihoxha/termai/internal/llm/models"
+	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+	"github.com/kujtimiihoxha/termai/internal/message"
+)
+
+type bedrockProvider struct {
+	childProvider Provider
+	model         models.Model
+	maxTokens     int64
+	systemMessage string
+}
+
+func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	return b.childProvider.SendMessages(ctx, messages, tools)
+}
+
+func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
+	return b.childProvider.StreamResponse(ctx, messages, tools)
+}
+
+func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
+	provider := &bedrockProvider{}
+	for _, opt := range opts {
+		opt(provider)
+	}
+
+	// based on the AWS region prefix the model name with, us, eu, ap, sa, etc.
+	region := os.Getenv("AWS_REGION")
+	if region == "" {
+		region = os.Getenv("AWS_DEFAULT_REGION")
+	}
+
+	if region == "" {
+		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required")
+	}
+	if len(region) < 2 {
+		return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
+	}
+	regionPrefix := region[:2]
+	provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)
+
+	if strings.Contains(string(provider.model.APIModel), "anthropic") {
+		anthropic, err := NewAnthropicProvider(
+			WithAnthropicModel(provider.model),
+			WithAnthropicMaxTokens(provider.maxTokens),
+			WithAnthropicSystemMessage(provider.systemMessage),
+			WithAnthropicBedrock(),
+			WithAnthropicDisableCache(),
+		)
+		provider.childProvider = anthropic
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		return nil, errors.New("unsupported model for bedrock provider")
+	}
+	return provider, nil
+}
+
+type BedrockOption func(*bedrockProvider)
+
+func WithBedrockSystemMessage(message string) BedrockOption {
+	return func(a *bedrockProvider) {
+		a.systemMessage = message
+	}
+}
+
+func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
+	return func(a *bedrockProvider) {
+		a.maxTokens = maxTokens
+	}
+}
+
+func WithBedrockModel(model models.Model) BedrockOption {
+	return func(a *bedrockProvider) {
+		a.model = model
+	}
+}

internal/tui/components/repl/editor.go 🔗

@@ -1,6 +1,7 @@
 package repl
 
 import (
+	"log"
 	"strings"
 
 	"github.com/charmbracelet/bubbles/key"
@@ -138,11 +139,22 @@ func (m *editorCmp) SetSize(width int, height int) {
 
 func (m *editorCmp) Send() tea.Cmd {
 	return func() tea.Msg {
-		messages, _ := m.app.Messages.List(m.sessionID)
+		messages, err := m.app.Messages.List(m.sessionID)
+		log.Printf("error: %v", err)
+		log.Printf("messages: %v", messages)
+
+		if err != nil {
+			return util.ReportError(err)
+		}
 		if hasUnfinishedMessages(messages) {
 			return util.ReportWarn("Assistant is still working on the previous message")
 		}
-		a, _ := agent.NewCoderAgent(m.app)
+		a, err := agent.NewCoderAgent(m.app)
+		log.Printf("error: %v", err)
+		log.Printf("agent: %v", a)
+		if err != nil {
+			return util.ReportError(err)
+		}
 
 		content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
 		go a.Generate(m.sessionID, content)