1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7
8 "github.com/charmbracelet/crush/internal/llm/tools"
9 "github.com/charmbracelet/crush/internal/message"
10)
11
12type bedrockProvider struct {
13 *baseProvider
14 region string
15 childProvider Provider
16}
17
18func NewBedrockProvider(base *baseProvider) Provider {
19 // Get AWS region from environment
20 region := base.extraParams["region"]
21 if region == "" {
22 region = "us-east-1" // default region
23 }
24
25 return &bedrockProvider{
26 baseProvider: base,
27 childProvider: NewAnthropicProvider(base, true),
28 }
29}
30
31func (b *bedrockProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
32 if len(b.region) < 2 {
33 return nil, errors.New("no region selected")
34 }
35 regionPrefix := b.region[:2]
36 modelName := model
37 model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
38 messages = b.cleanMessages(messages)
39 return b.childProvider.Send(ctx, model, messages, tools)
40}
41
42func (b *bedrockProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
43 if len(b.region) < 2 {
44 eventChan := make(chan ProviderEvent)
45 go func() {
46 eventChan <- ProviderEvent{
47 Type: EventError,
48 Error: errors.New("no region selected"),
49 }
50 close(eventChan)
51 }()
52 return eventChan
53 }
54 regionPrefix := b.region[:2]
55 modelName := model
56 model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
57 messages = b.cleanMessages(messages)
58 return b.childProvider.Stream(ctx, model, messages, tools)
59}