1package bedrock
2
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)
27
const (
	// DefaultVersion is the "anthropic_version" value the Bedrock middleware
	// inserts into request bodies that do not already carry one.
	DefaultVersion = "bedrock-2023-05-31"
)
29
30var DefaultEndpoints = map[string]bool{
31 "/v1/complete": true,
32 "/v1/messages": true,
33}
34
// eventstreamChunk is the JSON payload of a Bedrock "chunk" event-stream message.
type eventstreamChunk struct {
	// Bytes is the base64-encoded (standard alphabet) Anthropic event JSON.
	Bytes string `json:"bytes"`
	// P is decoded but never read here.
	// NOTE(review): presumably padding added by the service; confirm.
	P string `json:"p"`
}
39
40type eventstreamDecoder struct {
41 eventstream.Decoder
42
43 rc io.ReadCloser
44 evt ssestream.Event
45 err error
46}
47
48func (e *eventstreamDecoder) Close() error {
49 return e.rc.Close()
50}
51
52func (e *eventstreamDecoder) Err() error {
53 return e.err
54}
55
56func (e *eventstreamDecoder) Next() bool {
57 if e.err != nil {
58 return false
59 }
60
61 msg, err := e.Decoder.Decode(e.rc, nil)
62 if err != nil {
63 e.err = err
64 return false
65 }
66
67 messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
68 if messageType == nil {
69 e.err = fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
70 return false
71 }
72
73 switch messageType.String() {
74 case eventstreamapi.EventMessageType:
75 eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader)
76 if eventType == nil {
77 e.err = fmt.Errorf("%s event header not present", eventstreamapi.EventTypeHeader)
78 return false
79 }
80
81 if eventType.String() == "chunk" {
82 chunk := eventstreamChunk{}
83 err = json.Unmarshal(msg.Payload, &chunk)
84 if err != nil {
85 e.err = err
86 return false
87 }
88 decoded, err := base64.StdEncoding.DecodeString(chunk.Bytes)
89 if err != nil {
90 e.err = err
91 return false
92 }
93 e.evt = ssestream.Event{
94 Type: gjson.GetBytes(decoded, "type").String(),
95 Data: decoded,
96 }
97 }
98
99 case eventstreamapi.ExceptionMessageType:
100 // See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/service/iotsitewise/deserializers.go#L15511C1-L15567C2
101 exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
102 if exceptionType == nil {
103 e.err = fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader)
104 return false
105 }
106
107 // See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/aws/protocol/restjson/decoder_util.go#L15-L48k
108 var errInfo struct {
109 Code string
110 Type string `json:"__type"`
111 Message string
112 }
113 err = json.Unmarshal(msg.Payload, &errInfo)
114 if err != nil && err != io.EOF {
115 e.err = fmt.Errorf("received exception %s: parsing exception payload failed: %w", exceptionType.String(), err)
116 return false
117 }
118
119 errorCode := "UnknownError"
120 errorMessage := errorCode
121 if ev := exceptionType.String(); len(ev) > 0 {
122 errorCode = ev
123 } else if len(errInfo.Code) > 0 {
124 errorCode = errInfo.Code
125 } else if len(errInfo.Type) > 0 {
126 errorCode = errInfo.Type
127 }
128
129 if len(errInfo.Message) > 0 {
130 errorMessage = errInfo.Message
131 }
132 e.err = fmt.Errorf("received exception %s: %s", errorCode, errorMessage)
133 return false
134
135 case eventstreamapi.ErrorMessageType:
136 errorCode := "UnknownError"
137 errorMessage := errorCode
138 if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
139 errorCode = header.String()
140 }
141 if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
142 errorMessage = header.String()
143 }
144 e.err = fmt.Errorf("received error %s: %s", errorCode, errorMessage)
145 return false
146 }
147
148 return true
149}
150
151func (e *eventstreamDecoder) Event() ssestream.Event {
152 return e.evt
153}
154
155var (
156 _ ssestream.Decoder = &eventstreamDecoder{}
157)
158
159func init() {
160 ssestream.RegisterDecoder("application/vnd.amazon.eventstream", func(rc io.ReadCloser) ssestream.Decoder {
161 return &eventstreamDecoder{rc: rc}
162 })
163}
164
165// WithLoadDefaultConfig returns a request option which loads the default config for Amazon and registers
166// middleware that intercepts request to the Messages API so that this SDK can be used with Amazon Bedrock.
167//
168// If you already have an [aws.Config], it is recommended that you instead call [WithConfig] directly.
169func WithLoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) option.RequestOption {
170 cfg, err := config.LoadDefaultConfig(ctx, optFns...)
171 if err != nil {
172 panic(err)
173 }
174 return WithConfig(cfg)
175}
176
177// WithConfig returns a request option which uses the provided config and registers middleware that
178// intercepts request to the Messages API so that this SDK can be used with Amazon Bedrock.
179func WithConfig(cfg aws.Config) option.RequestOption {
180 signer := v4.NewSigner()
181 middleware := bedrockMiddleware(signer, cfg)
182
183 return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
184 return rc.Apply(
185 option.WithBaseURL(fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", cfg.Region)),
186 option.WithMiddleware(middleware),
187 )
188 })
189}
190
191func bedrockMiddleware(signer *v4.Signer, cfg aws.Config) option.Middleware {
192 return func(r *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
193 var body []byte
194 if r.Body != nil {
195 body, err = io.ReadAll(r.Body)
196 if err != nil {
197 return nil, err
198 }
199 r.Body.Close()
200
201 if !gjson.GetBytes(body, "anthropic_version").Exists() {
202 body, _ = sjson.SetBytes(body, "anthropic_version", DefaultVersion)
203 }
204
205 if r.Method == http.MethodPost && DefaultEndpoints[r.URL.Path] {
206 model := gjson.GetBytes(body, "model").String()
207 stream := gjson.GetBytes(body, "stream").Bool()
208
209 body, _ = sjson.DeleteBytes(body, "model")
210 body, _ = sjson.DeleteBytes(body, "stream")
211
212 var path string
213 if stream {
214 path = fmt.Sprintf("/model/%s/invoke-with-response-stream", model)
215 } else {
216 path = fmt.Sprintf("/model/%s/invoke", model)
217 }
218
219 r.URL.Path = path
220 }
221
222 reader := bytes.NewReader(body)
223 r.Body = io.NopCloser(reader)
224 r.GetBody = func() (io.ReadCloser, error) {
225 _, err := reader.Seek(0, 0)
226 return io.NopCloser(reader), err
227 }
228 r.ContentLength = int64(len(body))
229 }
230
231 ctx := r.Context()
232 credentials, err := cfg.Credentials.Retrieve(ctx)
233 if err != nil {
234 return nil, err
235 }
236
237 hash := sha256.Sum256(body)
238 err = signer.SignHTTP(ctx, credentials, r, hex.EncodeToString(hash[:]), "bedrock", cfg.Region, time.Now())
239 if err != nil {
240 return nil, err
241 }
242
243 return next(r)
244 }
245}