build(deps): bump github.com/apache/thrift from 0.12.0 to 0.13.0 (#15947)

Bumps [github.com/apache/thrift](https://github.com/apache/thrift) from 0.12.0 to 0.13.0.
- [Release notes](https://github.com/apache/thrift/releases)
- [Changelog](https://github.com/apache/thrift/blob/master/CHANGES.md)
- [Commits](https://github.com/apache/thrift/compare/v0.12.0...v0.13.0)

---
updated-dependencies:
- dependency-name: github.com/apache/thrift
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
dependabot[bot]
2023-02-18 13:03:22 +08:00
committed by GitHub
parent 31021c79bb
commit ebdbbdca1f
17 changed files with 1233 additions and 30 deletions

View File

@@ -1,5 +1,5 @@
Apache Thrift
Copyright 2006-2017 The Apache Software Foundation.
Copyright (C) 2006 - 2019, The Apache Software Foundation
This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).

View File

@@ -28,6 +28,9 @@ const (
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
)
var defaultApplicationExceptionMessage = map[int32]string{
@@ -39,6 +42,9 @@ var defaultApplicationExceptionMessage = map[int32]string{
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
INVALID_TRANSFORM: "Invalid transform",
INVALID_PROTOCOL: "Invalid protocol",
UNSUPPORTED_CLIENT_TYPE: "Unsupported client type",
}
// Application level Thrift exception

View File

@@ -32,8 +32,6 @@ import (
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
reader io.Reader
writer io.Writer
strictRead bool
strictWrite bool
buffer [64]byte
@@ -55,8 +53,6 @@ func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProt
} else {
p.trans = NewTRichTransport(t)
}
p.reader = p.trans
p.writer = p.trans
return p
}
@@ -192,21 +188,21 @@ func (p *TBinaryProtocol) WriteByte(value int8) error {
func (p *TBinaryProtocol) WriteI16(value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.writer.Write(v)
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI32(value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.writer.Write(v)
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI64(value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.writer.Write(v)
_, err := p.trans.Write(v)
return NewTProtocolException(err)
}
@@ -228,7 +224,7 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error {
if e != nil {
return e
}
_, err := p.writer.Write(value)
_, err := p.trans.Write(value)
return NewTProtocolException(err)
}
@@ -468,7 +464,7 @@ func (p *TBinaryProtocol) Transport() TTransport {
}
func (p *TBinaryProtocol) readAll(buf []byte) error {
_, err := io.ReadFull(p.reader, buf)
_, err := io.ReadFull(p.trans, buf)
return NewTProtocolException(err)
}

View File

@@ -24,6 +24,16 @@ func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClien
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
// Set headers from context object on THeaderProtocol
if headerProt, ok := oprot.(*THeaderProtocol); ok {
headerProt.ClearWriteHeaders()
for _, key := range GetWriteHeaderList(ctx) {
if value, ok := GetHeader(ctx, key); ok {
headerProt.SetWriteHeader(key, value)
}
}
}
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}

View File

@@ -93,7 +93,21 @@ func (p *TFramedTransport) Read(buf []byte) (l int, err error) {
l, err = p.Read(tmp)
copy(buf, tmp)
if err == nil {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
// Note: It's important to only return an error when l
// is zero.
// In io.Reader.Read interface, it's perfectly fine to
// return partial data and nil error, which means
// "This is all the data we have right now without
// blocking. If you need the full data, call Read again
// or use io.ReadFull instead".
// Returning partial data with an error actually means
// there's no more data after the partial data just
// returned, which is not true in this case
// (it might be that the other end just haven't written
// them yet).
if l == 0 {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
}
return
}
}

View File

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
// See https://godoc.org/context#WithValue on why do we need the unexported typedefs.
type (
headerKey string
headerKeyList int
)
// Values for headerKeyList.
const (
headerKeyListRead headerKeyList = iota
headerKeyListWrite
)
// SetHeader sets a header in the context.
func SetHeader(ctx context.Context, key, value string) context.Context {
return context.WithValue(
ctx,
headerKey(key),
value,
)
}
// GetHeader returns a value of the given header from the context.
func GetHeader(ctx context.Context, key string) (value string, ok bool) {
if v := ctx.Value(headerKey(key)); v != nil {
value, ok = v.(string)
}
return
}
// SetReadHeaderList sets the key list of read THeaders in the context.
func SetReadHeaderList(ctx context.Context, keys []string) context.Context {
return context.WithValue(
ctx,
headerKeyListRead,
keys,
)
}
// GetReadHeaderList returns the key list of read THeaders from the context.
func GetReadHeaderList(ctx context.Context) []string {
if v := ctx.Value(headerKeyListRead); v != nil {
if value, ok := v.([]string); ok {
return value
}
}
return nil
}
// SetWriteHeaderList sets the key list of THeaders to write in the context.
func SetWriteHeaderList(ctx context.Context, keys []string) context.Context {
return context.WithValue(
ctx,
headerKeyListWrite,
keys,
)
}
// GetWriteHeaderList returns the key list of THeaders to write from the context.
func GetWriteHeaderList(ctx context.Context) []string {
if v := ctx.Value(headerKeyListWrite); v != nil {
if value, ok := v.([]string); ok {
return value
}
}
return nil
}
// AddReadTHeaderToContext adds the whole THeader headers into context.
func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context {
keys := make([]string, 0, len(headers))
for key, value := range headers {
ctx = SetHeader(ctx, key, value)
keys = append(keys, key)
}
return SetReadHeaderList(ctx, keys)
}

View File

@@ -0,0 +1,305 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
// THeaderProtocol is a thrift protocol that implements THeader:
// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
//
// It supports either binary or compact protocol as the wrapped protocol.
//
// Most of the THeader handlings are happening inside THeaderTransport.
type THeaderProtocol struct {
transport *THeaderTransport
// Will be initialized on first read/write.
protocol TProtocol
}
// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
// transport. The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
t := NewTHeaderTransport(trans)
p, _ := THeaderProtocolDefault.GetProtocol(t)
return &THeaderProtocol{
transport: t,
protocol: p,
}
}
type tHeaderProtocolFactory struct{}
func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTHeaderProtocol(trans)
}
// NewTHeaderProtocolFactory creates a factory for THeader.
//
// It's a wrapper for NewTHeaderProtocol
func NewTHeaderProtocolFactory() TProtocolFactory {
return tHeaderProtocolFactory{}
}
// Transport returns the underlying transport.
//
// It's guaranteed to be of type *THeaderTransport.
func (p *THeaderProtocol) Transport() TTransport {
return p.transport
}
// GetReadHeaders returns the THeaderMap read from transport.
func (p *THeaderProtocol) GetReadHeaders() THeaderMap {
return p.transport.GetReadHeaders()
}
// SetWriteHeader sets a header for write.
func (p *THeaderProtocol) SetWriteHeader(key, value string) {
p.transport.SetWriteHeader(key, value)
}
// ClearWriteHeaders clears all write headers previously set.
func (p *THeaderProtocol) ClearWriteHeaders() {
p.transport.ClearWriteHeaders()
}
// AddTransform add a transform for writing.
func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error {
return p.transport.AddTransform(transform)
}
func (p *THeaderProtocol) Flush(ctx context.Context) error {
return p.transport.Flush(ctx)
}
func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error {
newProto, err := p.transport.Protocol().GetProtocol(p.transport)
if err != nil {
return err
}
p.protocol = newProto
p.transport.SequenceID = seqID
return p.protocol.WriteMessageBegin(name, typeID, seqID)
}
func (p *THeaderProtocol) WriteMessageEnd() error {
if err := p.protocol.WriteMessageEnd(); err != nil {
return err
}
return p.transport.Flush(context.Background())
}
func (p *THeaderProtocol) WriteStructBegin(name string) error {
return p.protocol.WriteStructBegin(name)
}
func (p *THeaderProtocol) WriteStructEnd() error {
return p.protocol.WriteStructEnd()
}
func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error {
return p.protocol.WriteFieldBegin(name, typeID, id)
}
func (p *THeaderProtocol) WriteFieldEnd() error {
return p.protocol.WriteFieldEnd()
}
func (p *THeaderProtocol) WriteFieldStop() error {
return p.protocol.WriteFieldStop()
}
func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
return p.protocol.WriteMapBegin(keyType, valueType, size)
}
func (p *THeaderProtocol) WriteMapEnd() error {
return p.protocol.WriteMapEnd()
}
func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error {
return p.protocol.WriteListBegin(elemType, size)
}
func (p *THeaderProtocol) WriteListEnd() error {
return p.protocol.WriteListEnd()
}
func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error {
return p.protocol.WriteSetBegin(elemType, size)
}
func (p *THeaderProtocol) WriteSetEnd() error {
return p.protocol.WriteSetEnd()
}
func (p *THeaderProtocol) WriteBool(value bool) error {
return p.protocol.WriteBool(value)
}
func (p *THeaderProtocol) WriteByte(value int8) error {
return p.protocol.WriteByte(value)
}
func (p *THeaderProtocol) WriteI16(value int16) error {
return p.protocol.WriteI16(value)
}
func (p *THeaderProtocol) WriteI32(value int32) error {
return p.protocol.WriteI32(value)
}
func (p *THeaderProtocol) WriteI64(value int64) error {
return p.protocol.WriteI64(value)
}
func (p *THeaderProtocol) WriteDouble(value float64) error {
return p.protocol.WriteDouble(value)
}
func (p *THeaderProtocol) WriteString(value string) error {
return p.protocol.WriteString(value)
}
func (p *THeaderProtocol) WriteBinary(value []byte) error {
return p.protocol.WriteBinary(value)
}
// ReadFrame calls underlying THeaderTransport's ReadFrame function.
func (p *THeaderProtocol) ReadFrame() error {
return p.transport.ReadFrame()
}
func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
if err = p.transport.ReadFrame(); err != nil {
return
}
var newProto TProtocol
newProto, err = p.transport.Protocol().GetProtocol(p.transport)
if err != nil {
tAppExc, ok := err.(TApplicationException)
if !ok {
return
}
if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil {
return
}
if e := tAppExc.Write(p.protocol); e != nil {
return
}
if e := p.protocol.WriteMessageEnd(); e != nil {
return
}
if e := p.transport.Flush(context.Background()); e != nil {
return
}
return
}
p.protocol = newProto
return p.protocol.ReadMessageBegin()
}
func (p *THeaderProtocol) ReadMessageEnd() error {
return p.protocol.ReadMessageEnd()
}
func (p *THeaderProtocol) ReadStructBegin() (name string, err error) {
return p.protocol.ReadStructBegin()
}
func (p *THeaderProtocol) ReadStructEnd() error {
return p.protocol.ReadStructEnd()
}
func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) {
return p.protocol.ReadFieldBegin()
}
func (p *THeaderProtocol) ReadFieldEnd() error {
return p.protocol.ReadFieldEnd()
}
func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
return p.protocol.ReadMapBegin()
}
func (p *THeaderProtocol) ReadMapEnd() error {
return p.protocol.ReadMapEnd()
}
func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) {
return p.protocol.ReadListBegin()
}
func (p *THeaderProtocol) ReadListEnd() error {
return p.protocol.ReadListEnd()
}
func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.protocol.ReadSetBegin()
}
func (p *THeaderProtocol) ReadSetEnd() error {
return p.protocol.ReadSetEnd()
}
func (p *THeaderProtocol) ReadBool() (value bool, err error) {
return p.protocol.ReadBool()
}
func (p *THeaderProtocol) ReadByte() (value int8, err error) {
return p.protocol.ReadByte()
}
func (p *THeaderProtocol) ReadI16() (value int16, err error) {
return p.protocol.ReadI16()
}
func (p *THeaderProtocol) ReadI32() (value int32, err error) {
return p.protocol.ReadI32()
}
func (p *THeaderProtocol) ReadI64() (value int64, err error) {
return p.protocol.ReadI64()
}
func (p *THeaderProtocol) ReadDouble() (value float64, err error) {
return p.protocol.ReadDouble()
}
func (p *THeaderProtocol) ReadString() (value string, err error) {
return p.protocol.ReadString()
}
func (p *THeaderProtocol) ReadBinary() (value []byte, err error) {
return p.protocol.ReadBinary()
}
func (p *THeaderProtocol) Skip(fieldType TType) error {
return p.protocol.Skip(fieldType)
}

View File

@@ -0,0 +1,723 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"bytes"
"compress/zlib"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
)
// Size in bytes for 32-bit ints.
const size32 = 4
type headerMeta struct {
MagicFlags uint32
SequenceID int32
HeaderLength uint16
}
const headerMetaSize = 10
type clientType int
const (
clientUnknown clientType = iota
clientHeaders
clientFramedBinary
clientUnframedBinary
clientFramedCompact
clientUnframedCompact
)
// Constants defined in THeader format:
// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
const (
THeaderHeaderMagic uint32 = 0x0fff0000
THeaderHeaderMask uint32 = 0xffff0000
THeaderFlagsMask uint32 = 0x0000ffff
THeaderMaxFrameSize uint32 = 0x3fffffff
)
// THeaderMap is the type of the header map in THeader transport.
type THeaderMap map[string]string
// THeaderProtocolID is the wrapped protocol id used in THeader.
type THeaderProtocolID int32
// Supported THeaderProtocolID values.
const (
THeaderProtocolBinary THeaderProtocolID = 0x00
THeaderProtocolCompact THeaderProtocolID = 0x02
THeaderProtocolDefault = THeaderProtocolBinary
)
// GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
switch id {
default:
return nil, NewTApplicationException(
INVALID_PROTOCOL,
fmt.Sprintf("THeader protocol id %d not supported", id),
)
case THeaderProtocolBinary:
return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
case THeaderProtocolCompact:
return NewTCompactProtocol(trans), nil
}
}
// THeaderTransformID defines the numeric id of the transform used.
type THeaderTransformID int32
// THeaderTransformID values
const (
TransformNone THeaderTransformID = iota // 0, no special handling
TransformZlib // 1, zlib
// Rest of the values are not currently supported, namely HMAC and Snappy.
)
var supportedTransformIDs = map[THeaderTransformID]bool{
TransformNone: true,
TransformZlib: true,
}
// TransformReader is an io.ReadCloser that handles transforms reading.
type TransformReader struct {
io.Reader
closers []io.Closer
}
var _ io.ReadCloser = (*TransformReader)(nil)
// NewTransformReaderWithCapacity initializes a TransformReader with expected
// closers capacity.
//
// If you don't know the closers capacity beforehand, just use
//
// &TransformReader{Reader: baseReader}
//
// instead would be sufficient.
func NewTransformReaderWithCapacity(baseReader io.Reader, capacity int) *TransformReader {
return &TransformReader{
Reader: baseReader,
closers: make([]io.Closer, 0, capacity),
}
}
// Close calls the underlying closers in appropriate order,
// stops at and returns the first error encountered.
func (tr *TransformReader) Close() error {
// Call closers in reversed order
for i := len(tr.closers) - 1; i >= 0; i-- {
if err := tr.closers[i].Close(); err != nil {
return err
}
}
return nil
}
// AddTransform adds a transform.
func (tr *TransformReader) AddTransform(id THeaderTransformID) error {
switch id {
default:
return NewTApplicationException(
INVALID_TRANSFORM,
fmt.Sprintf("THeaderTransformID %d not supported", id),
)
case TransformNone:
// no-op
case TransformZlib:
readCloser, err := zlib.NewReader(tr.Reader)
if err != nil {
return err
}
tr.Reader = readCloser
tr.closers = append(tr.closers, readCloser)
}
return nil
}
// TransformWriter is an io.WriteCloser that handles transforms writing.
type TransformWriter struct {
io.Writer
closers []io.Closer
}
var _ io.WriteCloser = (*TransformWriter)(nil)
// NewTransformWriter creates a new TransformWriter with base writer and transforms.
func NewTransformWriter(baseWriter io.Writer, transforms []THeaderTransformID) (io.WriteCloser, error) {
writer := &TransformWriter{
Writer: baseWriter,
closers: make([]io.Closer, 0, len(transforms)),
}
for _, id := range transforms {
if err := writer.AddTransform(id); err != nil {
return nil, err
}
}
return writer, nil
}
// Close calls the underlying closers in appropriate order,
// stops at and returns the first error encountered.
func (tw *TransformWriter) Close() error {
// Call closers in reversed order
for i := len(tw.closers) - 1; i >= 0; i-- {
if err := tw.closers[i].Close(); err != nil {
return err
}
}
return nil
}
// AddTransform adds a transform.
func (tw *TransformWriter) AddTransform(id THeaderTransformID) error {
switch id {
default:
return NewTApplicationException(
INVALID_TRANSFORM,
fmt.Sprintf("THeaderTransformID %d not supported", id),
)
case TransformNone:
// no-op
case TransformZlib:
writeCloser := zlib.NewWriter(tw.Writer)
tw.Writer = writeCloser
tw.closers = append(tw.closers, writeCloser)
}
return nil
}
// THeaderInfoType is the type id of the info headers.
type THeaderInfoType int32
// Supported THeaderInfoType values.
const (
_ THeaderInfoType = iota // Skip 0
InfoKeyValue // 1
// Rest of the info types are not supported.
)
// THeaderTransport is a Transport mode that implements THeader.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
type THeaderTransport struct {
SequenceID int32
Flags uint32
transport TTransport
// THeaderMap for read and write
readHeaders THeaderMap
writeHeaders THeaderMap
// Reading related variables.
reader *bufio.Reader
// When frame is detected, we read the frame fully into frameBuffer.
frameBuffer bytes.Buffer
// When it's non-nil, Read should read from frameReader instead of
// reader, and EOF error indicates end of frame instead of end of all
// transport.
frameReader io.ReadCloser
// Writing related variables
writeBuffer bytes.Buffer
writeTransforms []THeaderTransformID
clientType clientType
protocolID THeaderProtocolID
// buffer is used in the following scenarios to avoid repetitive
// allocations, while 4 is big enough for all those scenarios:
//
// * header padding (max size 4)
// * write the frame size (size 4)
buffer [4]byte
}
var _ TTransport = (*THeaderTransport)(nil)
// NewTHeaderTransport creates THeaderTransport from the underlying transport.
//
// Please note that THeaderTransport handles framing and zlib by itself,
// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
//
// If trans is already a *THeaderTransport, it will be returned as is.
func NewTHeaderTransport(trans TTransport) *THeaderTransport {
if ht, ok := trans.(*THeaderTransport); ok {
return ht
}
return &THeaderTransport{
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
protocolID: THeaderProtocolDefault,
}
}
// Open calls the underlying transport's Open function.
func (t *THeaderTransport) Open() error {
return t.transport.Open()
}
// IsOpen calls the underlying transport's IsOpen function.
func (t *THeaderTransport) IsOpen() bool {
return t.transport.IsOpen()
}
// ReadFrame tries to read the frame header, guess the client type, and handle
// unframed clients.
func (t *THeaderTransport) ReadFrame() error {
if !t.needReadFrame() {
// No need to read frame, skipping.
return nil
}
// Peek and handle the first 32 bits.
// They could either be the length field of a framed message,
// or the first bytes of an unframed message.
buf, err := t.reader.Peek(size32)
if err != nil {
return err
}
frameSize := binary.BigEndian.Uint32(buf)
if frameSize&VERSION_MASK == VERSION_1 {
t.clientType = clientUnframedBinary
return nil
}
if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION {
t.clientType = clientUnframedCompact
return nil
}
// At this point it should be a framed message,
// sanity check on frameSize then discard the peeked part.
if frameSize > THeaderMaxFrameSize {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
errors.New("frame too large"),
)
}
t.reader.Discard(size32)
// Read the frame fully into frameBuffer.
_, err = io.Copy(
&t.frameBuffer,
io.LimitReader(t.reader, int64(frameSize)),
)
if err != nil {
return err
}
t.frameReader = ioutil.NopCloser(&t.frameBuffer)
// Peek and handle the next 32 bits.
buf = t.frameBuffer.Bytes()[:size32]
version := binary.BigEndian.Uint32(buf)
if version&THeaderHeaderMask == THeaderHeaderMagic {
t.clientType = clientHeaders
return t.parseHeaders(frameSize)
}
if version&VERSION_MASK == VERSION_1 {
t.clientType = clientFramedBinary
return nil
}
if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION {
t.clientType = clientFramedCompact
return nil
}
if err := t.endOfFrame(); err != nil {
return err
}
return NewTProtocolExceptionWithType(
NOT_IMPLEMENTED,
errors.New("unsupported client transport type"),
)
}
// endOfFrame does end of frame handling.
//
// It closes frameReader, and also resets frame related states.
func (t *THeaderTransport) endOfFrame() error {
defer func() {
t.frameBuffer.Reset()
t.frameReader = nil
}()
return t.frameReader.Close()
}
func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
if t.clientType != clientHeaders {
return nil
}
var err error
var meta headerMeta
if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil {
return err
}
frameSize -= headerMetaSize
t.Flags = meta.MagicFlags & THeaderFlagsMask
t.SequenceID = meta.SequenceID
headerLength := int64(meta.HeaderLength) * 4
if int64(frameSize) < headerLength {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
errors.New("header size is larger than the whole frame"),
)
}
headerBuf := NewTMemoryBuffer()
_, err = io.Copy(headerBuf, io.LimitReader(&t.frameBuffer, headerLength))
if err != nil {
return err
}
hp := NewTCompactProtocol(headerBuf)
// At this point the header is already read into headerBuf,
// and t.frameBuffer starts from the actual payload.
protoID, err := hp.readVarint32()
if err != nil {
return err
}
t.protocolID = THeaderProtocolID(protoID)
var transformCount int32
transformCount, err = hp.readVarint32()
if err != nil {
return err
}
if transformCount > 0 {
reader := NewTransformReaderWithCapacity(
&t.frameBuffer,
int(transformCount),
)
t.frameReader = reader
transformIDs := make([]THeaderTransformID, transformCount)
for i := 0; i < int(transformCount); i++ {
id, err := hp.readVarint32()
if err != nil {
return err
}
transformIDs[i] = THeaderTransformID(id)
}
// The transform IDs on the wire was added based on the order of
// writing, so on the reading side we need to reverse the order.
for i := transformCount - 1; i >= 0; i-- {
id := transformIDs[i]
if err := reader.AddTransform(id); err != nil {
return err
}
}
}
// The info part does not use the transforms yet, so it's
// important to continue using headerBuf.
headers := make(THeaderMap)
for {
infoType, err := hp.readVarint32()
if err == io.EOF {
break
}
if err != nil {
return err
}
if THeaderInfoType(infoType) == InfoKeyValue {
count, err := hp.readVarint32()
if err != nil {
return err
}
for i := 0; i < int(count); i++ {
key, err := hp.ReadString()
if err != nil {
return err
}
value, err := hp.ReadString()
if err != nil {
return err
}
headers[key] = value
}
} else {
// Skip reading info section on the first
// unsupported info type.
break
}
}
t.readHeaders = headers
return nil
}
func (t *THeaderTransport) needReadFrame() bool {
if t.clientType == clientUnknown {
// This is a new connection that's never read before.
return true
}
if t.isFramed() && t.frameReader == nil {
// We just finished the last frame.
return true
}
return false
}
func (t *THeaderTransport) Read(p []byte) (read int, err error) {
err = t.ReadFrame()
if err != nil {
return
}
if t.frameReader != nil {
read, err = t.frameReader.Read(p)
if err == io.EOF {
err = t.endOfFrame()
if err != nil {
return
}
if read < len(p) {
var nextRead int
nextRead, err = t.Read(p[read:])
read += nextRead
}
}
return
}
return t.reader.Read(p)
}
// Write writes data to the write buffer.
//
// You need to call Flush to actually write them to the transport.
func (t *THeaderTransport) Write(p []byte) (int, error) {
return t.writeBuffer.Write(p)
}
// Flush writes the appropriate header and the write buffer to the underlying transport.
func (t *THeaderTransport) Flush(ctx context.Context) error {
if t.writeBuffer.Len() == 0 {
return nil
}
defer t.writeBuffer.Reset()
switch t.clientType {
default:
fallthrough
case clientUnknown:
t.clientType = clientHeaders
fallthrough
case clientHeaders:
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
return NewTTransportExceptionFromError(err)
}
for _, transform := range t.writeTransforms {
if _, err := hp.writeVarint32(int32(transform)); err != nil {
return NewTTransportExceptionFromError(err)
}
}
if len(t.writeHeaders) > 0 {
if _, err := hp.writeVarint32(int32(InfoKeyValue)); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeHeaders))); err != nil {
return NewTTransportExceptionFromError(err)
}
for key, value := range t.writeHeaders {
if err := hp.WriteString(key); err != nil {
return NewTTransportExceptionFromError(err)
}
if err := hp.WriteString(value); err != nil {
return NewTTransportExceptionFromError(err)
}
}
}
padding := 4 - headers.Len()%4
if padding < 4 {
buf := t.buffer[:padding]
for i := range buf {
buf[i] = 0
}
if _, err := headers.Write(buf); err != nil {
return NewTTransportExceptionFromError(err)
}
}
var payload bytes.Buffer
meta := headerMeta{
MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
SequenceID: t.SequenceID,
HeaderLength: uint16(headers.Len() / 4),
}
if err := binary.Write(&payload, binary.BigEndian, meta); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := io.Copy(&payload, headers); err != nil {
return NewTTransportExceptionFromError(err)
}
writer, err := NewTransformWriter(&payload, t.writeTransforms)
if err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := io.Copy(writer, &t.writeBuffer); err != nil {
return NewTTransportExceptionFromError(err)
}
if err := writer.Close(); err != nil {
return NewTTransportExceptionFromError(err)
}
// First write frame length
buf := t.buffer[:size32]
binary.BigEndian.PutUint32(buf, uint32(payload.Len()))
if _, err := t.transport.Write(buf); err != nil {
return NewTTransportExceptionFromError(err)
}
// Then write the payload
if _, err := io.Copy(t.transport, &payload); err != nil {
return NewTTransportExceptionFromError(err)
}
case clientFramedBinary, clientFramedCompact:
buf := t.buffer[:size32]
binary.BigEndian.PutUint32(buf, uint32(t.writeBuffer.Len()))
if _, err := t.transport.Write(buf); err != nil {
return NewTTransportExceptionFromError(err)
}
fallthrough
case clientUnframedBinary, clientUnframedCompact:
if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil {
return NewTTransportExceptionFromError(err)
}
}
select {
default:
case <-ctx.Done():
return NewTTransportExceptionFromError(ctx.Err())
}
return t.transport.Flush(ctx)
}
// Close closes the transport, along with its underlying transport.
func (t *THeaderTransport) Close() error {
if err := t.Flush(context.Background()); err != nil {
return err
}
return t.transport.Close()
}
// RemainingBytes calls underlying transport's RemainingBytes.
//
// Even in framed cases, because of all the possible compression transforms
// involved, the remaining frame size is likely to be different from the actual
// remaining readable bytes, so we don't bother to keep tracking the remaining
// frame size by ourselves and just use the underlying transport's
// RemainingBytes directly.
func (t *THeaderTransport) RemainingBytes() uint64 {
return t.transport.RemainingBytes()
}
// GetReadHeaders returns the THeaderMap read from transport.
func (t *THeaderTransport) GetReadHeaders() THeaderMap {
return t.readHeaders
}
// SetWriteHeader sets a header for write.
func (t *THeaderTransport) SetWriteHeader(key, value string) {
t.writeHeaders[key] = value
}
// ClearWriteHeaders clears all write headers previously set.
func (t *THeaderTransport) ClearWriteHeaders() {
t.writeHeaders = make(THeaderMap)
}
// AddTransform add a transform for writing.
func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error {
if !supportedTransformIDs[transform] {
return NewTProtocolExceptionWithType(
NOT_IMPLEMENTED,
fmt.Errorf("THeaderTransformID %d not supported", transform),
)
}
t.writeTransforms = append(t.writeTransforms, transform)
return nil
}
// Protocol returns the wrapped protocol id used in this THeaderTransport.
func (t *THeaderTransport) Protocol() THeaderProtocolID {
switch t.clientType {
default:
return t.protocolID
case clientFramedBinary, clientUnframedBinary:
return THeaderProtocolBinary
case clientFramedCompact, clientUnframedCompact:
return THeaderProtocolCompact
}
}
func (t *THeaderTransport) isFramed() bool {
switch t.clientType {
default:
return false
case clientHeaders, clientFramedBinary, clientFramedCompact:
return true
}
}
// THeaderTransportFactory is a TTransportFactory implementation to create
// THeaderTransport.
type THeaderTransportFactory struct {
// The underlying factory, could be nil.
Factory TTransportFactory
}
// NewTHeaderTransportFactory creates a new *THeaderTransportFactory.
func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory {
return &THeaderTransportFactory{
Factory: factory,
}
}
// GetTransport implements TTransportFactory.
func (f *THeaderTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if f.Factory != nil {
t, err := f.Factory.GetTransport(trans)
if err != nil {
return nil, err
}
return NewTHeaderTransport(t), nil
}
return NewTHeaderTransport(trans), nil
}

View File

@@ -32,10 +32,7 @@ const (
// for references to _ParseContext see tsimplejson_protocol.go
// JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
// suitable for parsing by scripting languages. It should not be
// confused with the full-featured TJSONProtocol.
// Utilizes Simple JSON protocol
//
type TJSONProtocol struct {
*TSimpleJSONProtocol

View File

@@ -41,6 +41,8 @@ package thrift
func Float32Ptr(v float32) *float32 { return &v }
func Float64Ptr(v float64) *float64 { return &v }
func IntPtr(v int) *int { return &v }
func Int8Ptr(v int8) *int8 { return &v }
func Int16Ptr(v int16) *int16 { return &v }
func Int32Ptr(v int32) *int32 { return &v }
func Int64Ptr(v int64) *int64 { return &v }
func StringPtr(v string) *string { return &v }

View File

@@ -96,8 +96,6 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
}
switch fieldType {
case STOP:
return
case BOOL:
_, err = self.ReadBool()
return

View File

@@ -60,7 +60,7 @@ func (p _ParseContext) String() string {
return "UNKNOWN-PARSE-CONTEXT"
}
// JSON protocol implementation for thrift.
// Simple JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
// suitable for parsing by scripting languages. It should not be
@@ -1316,7 +1316,7 @@ func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) {
func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
for i := 0; i < len(b); i++ {
a, _ := p.reader.Peek(i + 1)
if len(a) == 0 || a[i] != b[i] {
if len(a) < (i+1) || a[i] != b[i] {
return false
}
}

View File

@@ -42,6 +42,9 @@ type TSimpleServer struct {
outputTransportFactory TTransportFactory
inputProtocolFactory TProtocolFactory
outputProtocolFactory TProtocolFactory
// Headers to auto forward in THeaderProtocol
forwardHeaders []string
}
func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
@@ -125,6 +128,26 @@ func (p *TSimpleServer) Listen() error {
return p.serverTransport.Listen()
}
// SetForwardHeaders sets the list of header keys that will be auto forwarded
// while using THeaderProtocol.
//
// "forward" means that when the server is also a client to other upstream
// thrift servers, the context object user gets in the processor functions will
// have both read and write headers set, with write headers being forwarded.
// Users can always override the write headers by calling SetWriteHeaderList
// before calling thrift client functions.
func (p *TSimpleServer) SetForwardHeaders(headers []string) {
size := len(headers)
if size == 0 {
p.forwardHeaders = nil
return
}
keys := make([]string, size)
copy(keys, headers)
p.forwardHeaders = keys
}
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
@@ -187,12 +210,25 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
if err != nil {
return err
}
outputTransport, err := p.outputTransportFactory.GetTransport(client)
if err != nil {
return err
}
inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport)
var outputTransport TTransport
var outputProtocol TProtocol
// for THeaderProtocol, we must use the same protocol instance for
// input and output so that the response is in the same dialect that
// the server detected the request was in.
headerProtocol, ok := inputProtocol.(*THeaderProtocol)
if ok {
outputProtocol = inputProtocol
} else {
oTrans, err := p.outputTransportFactory.GetTransport(client)
if err != nil {
return err
}
outputTransport = oTrans
outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)
}
defer func() {
if e := recover(); e != nil {
log.Printf("panic in processor: %s: %s", e, debug.Stack())
@@ -210,7 +246,22 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
return nil
}
ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol)
ctx := defaultCtx
if headerProtocol != nil {
// We need to call ReadFrame here, otherwise we won't
// get any headers on the AddReadTHeaderToContext call.
//
// ReadFrame is safe to be called multiple times so it
// won't break when it's called again later when we
// actually start to read the message.
if err := headerProtocol.ReadFrame(); err != nil {
return err
}
ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders())
ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
}
ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
return nil
} else if err != nil {

View File

@@ -162,5 +162,5 @@ func (p *TSocket) Interrupt() error {
func (p *TSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the truth is, we just don't know unless framed is used
}