bedrock.go

 1package provider
 2
 3import (
 4	"context"
 5	"errors"
 6	"fmt"
 7
 8	"github.com/charmbracelet/crush/internal/llm/tools"
 9	"github.com/charmbracelet/crush/internal/message"
10)
11
12type bedrockProvider struct {
13	*baseProvider
14	region        string
15	childProvider Provider
16}
17
18func NewBedrockProvider(base *baseProvider) Provider {
19	// Get AWS region from environment
20	region := base.extraParams["region"]
21	if region == "" {
22		region = "us-east-1" // default region
23	}
24
25	return &bedrockProvider{
26		baseProvider:  base,
27		childProvider: NewAnthropicProvider(base, true),
28	}
29}
30
31func (b *bedrockProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
32	if len(b.region) < 2 {
33		return nil, errors.New("no region selected")
34	}
35	regionPrefix := b.region[:2]
36	modelName := model
37	model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
38	messages = b.cleanMessages(messages)
39	return b.childProvider.Send(ctx, model, messages, tools)
40}
41
42func (b *bedrockProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
43	if len(b.region) < 2 {
44		eventChan := make(chan ProviderEvent)
45		go func() {
46			eventChan <- ProviderEvent{
47				Type:  EventError,
48				Error: errors.New("no region selected"),
49			}
50			close(eventChan)
51		}()
52		return eventChan
53	}
54	regionPrefix := b.region[:2]
55	modelName := model
56	model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
57	messages = b.cleanMessages(messages)
58	return b.childProvider.Stream(ctx, model, messages, tools)
59}