bedrock.go

  1package bedrock
  2
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)
 27
 28const DefaultVersion = "bedrock-2023-05-31"
 29
 30var DefaultEndpoints = map[string]bool{
 31	"/v1/complete": true,
 32	"/v1/messages": true,
 33}
 34
 35type eventstreamChunk struct {
 36	Bytes string `json:"bytes"`
 37	P     string `json:"p"`
 38}
 39
 40type eventstreamDecoder struct {
 41	eventstream.Decoder
 42
 43	rc  io.ReadCloser
 44	evt ssestream.Event
 45	err error
 46}
 47
 48func (e *eventstreamDecoder) Close() error {
 49	return e.rc.Close()
 50}
 51
 52func (e *eventstreamDecoder) Err() error {
 53	return e.err
 54}
 55
 56func (e *eventstreamDecoder) Next() bool {
 57	if e.err != nil {
 58		return false
 59	}
 60
 61	msg, err := e.Decoder.Decode(e.rc, nil)
 62	if err != nil {
 63		e.err = err
 64		return false
 65	}
 66
 67	messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
 68	if messageType == nil {
 69		e.err = fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
 70		return false
 71	}
 72
 73	switch messageType.String() {
 74	case eventstreamapi.EventMessageType:
 75		eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader)
 76		if eventType == nil {
 77			e.err = fmt.Errorf("%s event header not present", eventstreamapi.EventTypeHeader)
 78			return false
 79		}
 80
 81		if eventType.String() == "chunk" {
 82			chunk := eventstreamChunk{}
 83			err = json.Unmarshal(msg.Payload, &chunk)
 84			if err != nil {
 85				e.err = err
 86				return false
 87			}
 88			decoded, err := base64.StdEncoding.DecodeString(chunk.Bytes)
 89			if err != nil {
 90				e.err = err
 91				return false
 92			}
 93			e.evt = ssestream.Event{
 94				Type: gjson.GetBytes(decoded, "type").String(),
 95				Data: decoded,
 96			}
 97		}
 98
 99	case eventstreamapi.ExceptionMessageType:
100		// See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/service/iotsitewise/deserializers.go#L15511C1-L15567C2
101		exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader)
102		if exceptionType == nil {
103			e.err = fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader)
104			return false
105		}
106
107		// See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/aws/protocol/restjson/decoder_util.go#L15-L48k
108		var errInfo struct {
109			Code    string
110			Type    string `json:"__type"`
111			Message string
112		}
113		err = json.Unmarshal(msg.Payload, &errInfo)
114		if err != nil && err != io.EOF {
115			e.err = fmt.Errorf("received exception %s: parsing exception payload failed: %w", exceptionType.String(), err)
116			return false
117		}
118
119		errorCode := "UnknownError"
120		errorMessage := errorCode
121		if ev := exceptionType.String(); len(ev) > 0 {
122			errorCode = ev
123		} else if len(errInfo.Code) > 0 {
124			errorCode = errInfo.Code
125		} else if len(errInfo.Type) > 0 {
126			errorCode = errInfo.Type
127		}
128
129		if len(errInfo.Message) > 0 {
130			errorMessage = errInfo.Message
131		}
132		e.err = fmt.Errorf("received exception %s: %s", errorCode, errorMessage)
133		return false
134
135	case eventstreamapi.ErrorMessageType:
136		errorCode := "UnknownError"
137		errorMessage := errorCode
138		if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
139			errorCode = header.String()
140		}
141		if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
142			errorMessage = header.String()
143		}
144		e.err = fmt.Errorf("received error %s: %s", errorCode, errorMessage)
145		return false
146	}
147
148	return true
149}
150
151func (e *eventstreamDecoder) Event() ssestream.Event {
152	return e.evt
153}
154
155var (
156	_ ssestream.Decoder = &eventstreamDecoder{}
157)
158
159func init() {
160	ssestream.RegisterDecoder("application/vnd.amazon.eventstream", func(rc io.ReadCloser) ssestream.Decoder {
161		return &eventstreamDecoder{rc: rc}
162	})
163}
164
165// WithLoadDefaultConfig returns a request option which loads the default config for Amazon and registers
166// middleware that intercepts request to the Messages API so that this SDK can be used with Amazon Bedrock.
167//
168// If you already have an [aws.Config], it is recommended that you instead call [WithConfig] directly.
169func WithLoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) option.RequestOption {
170	cfg, err := config.LoadDefaultConfig(ctx, optFns...)
171	if err != nil {
172		panic(err)
173	}
174	return WithConfig(cfg)
175}
176
177// WithConfig returns a request option which uses the provided config  and registers middleware that
178// intercepts request to the Messages API so that this SDK can be used with Amazon Bedrock.
179func WithConfig(cfg aws.Config) option.RequestOption {
180	signer := v4.NewSigner()
181	middleware := bedrockMiddleware(signer, cfg)
182
183	return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
184		return rc.Apply(
185			option.WithBaseURL(fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", cfg.Region)),
186			option.WithMiddleware(middleware),
187		)
188	})
189}
190
191func bedrockMiddleware(signer *v4.Signer, cfg aws.Config) option.Middleware {
192	return func(r *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
193		var body []byte
194		if r.Body != nil {
195			body, err = io.ReadAll(r.Body)
196			if err != nil {
197				return nil, err
198			}
199			r.Body.Close()
200
201			if !gjson.GetBytes(body, "anthropic_version").Exists() {
202				body, _ = sjson.SetBytes(body, "anthropic_version", DefaultVersion)
203			}
204
205			if r.Method == http.MethodPost && DefaultEndpoints[r.URL.Path] {
206				model := gjson.GetBytes(body, "model").String()
207				stream := gjson.GetBytes(body, "stream").Bool()
208
209				body, _ = sjson.DeleteBytes(body, "model")
210				body, _ = sjson.DeleteBytes(body, "stream")
211
212				var path string
213				if stream {
214					path = fmt.Sprintf("/model/%s/invoke-with-response-stream", model)
215				} else {
216					path = fmt.Sprintf("/model/%s/invoke", model)
217				}
218
219				r.URL.Path = path
220			}
221
222			reader := bytes.NewReader(body)
223			r.Body = io.NopCloser(reader)
224			r.GetBody = func() (io.ReadCloser, error) {
225				_, err := reader.Seek(0, 0)
226				return io.NopCloser(reader), err
227			}
228			r.ContentLength = int64(len(body))
229		}
230
231		ctx := r.Context()
232		credentials, err := cfg.Credentials.Retrieve(ctx)
233		if err != nil {
234			return nil, err
235		}
236
237		hash := sha256.Sum256(body)
238		err = signer.SignHTTP(ctx, credentials, r, hex.EncodeToString(hash[:]), "bedrock", cfg.Region, time.Now())
239		if err != nil {
240			return nil, err
241		}
242
243		return next(r)
244	}
245}