build(deps): bump github.com/jackc/pgtype from 1.7.0 to 1.8.1

Bumps [github.com/jackc/pgtype](https://github.com/jackc/pgtype) from 1.7.0 to 1.8.1.
- [Release notes](https://github.com/jackc/pgtype/releases)
- [Changelog](https://github.com/jackc/pgtype/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgtype/compare/v1.7.0...v1.8.1)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgtype
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot] 2021-09-01 08:13:04 +00:00 committed by Ondřej Budai
parent 97d6142609
commit 9ceeaa1dfd
205 changed files with 17496 additions and 13126 deletions

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -14,6 +15,9 @@ type AuthenticationCleartextPassword struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
@ -37,3 +41,12 @@ func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationCleartextPassword",
})
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -15,6 +16,9 @@ type AuthenticationMD5Password struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationMD5Password) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
@ -41,3 +45,33 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, src.Salt[:]...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Salt [4]byte
}{
Type: "AuthenticationMD5Password",
Salt: src.Salt,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Salt [4]byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Salt = msg.Salt
return nil
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -14,6 +15,9 @@ type AuthenticationOk struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationOk) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
@ -37,3 +41,12 @@ func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationOK",
})
}

View file

@ -3,6 +3,7 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -16,6 +17,9 @@ type AuthenticationSASL struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASL) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
@ -58,3 +62,14 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
AuthMechanisms []string
}{
Type: "AuthenticationSASL",
AuthMechanisms: src.AuthMechanisms,
})
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -15,6 +16,9 @@ type AuthenticationSASLContinue struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
@ -46,3 +50,32 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLContinue",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@ -15,6 +16,9 @@ type AuthenticationSASLFinal struct {
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
@ -46,3 +50,32 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
return dst
}
// MarshalJSON implements encoding/json.Unmarshaler.
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLFinal",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

View file

@ -12,27 +12,34 @@ type Backend struct {
w io.Writer
// Frontend message flyweights
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
describe Describe
execute Execute
flush Flush
gssEncRequest GSSEncRequest
parse Parse
passwordMessage PasswordMessage
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
copyData CopyData
copyDone CopyDone
describe Describe
execute Execute
flush Flush
gssEncRequest GSSEncRequest
parse Parse
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
const (
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend creates a new Backend.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
@ -54,9 +61,13 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
}
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, err
return nil, translateEOFtoErrUnexpectedEOF(err)
}
code := binary.BigEndian.Uint32(buf)
@ -96,7 +107,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.msgType = header[0]
@ -116,12 +127,28 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msg = &b.execute
case 'f':
msg = &b.copyFail
case 'd':
msg = &b.copyData
case 'c':
msg = &b.copyDone
case 'H':
msg = &b.flush
case 'P':
msg = &b.parse
case 'p':
msg = &b.passwordMessage
switch b.authType {
case AuthTypeSASL:
msg = &SASLInitialResponse{}
case AuthTypeSASLContinue:
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatability
msg = &PasswordMessage{}
}
case 'Q':
msg = &b.query
case 'S':
@ -134,7 +161,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, err
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.partialMsg = false
@ -142,3 +169,36 @@ func (b *Backend) Receive() (FrontendMessage, error) {
err = msg.Decode(msgBody)
return msg, err
}
// SetAuthType sets the authentication type in the backend.
// Since multiple message types can start with 'p', SetAuthType allows
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
func (b *Backend) SetAuthType(authType uint32) error {
switch authType {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
b.authType = authType
default:
return fmt.Errorf("authType not recognized: %d", authType)
}
return nil
}

View file

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/jackc/pgio"
)
@ -181,3 +182,35 @@ func (src Bind) MarshalJSON() ([]byte, error) {
ResultFormatCodes: src.ResultFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Bind) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.DestinationPortal = msg.DestinationPortal
dst.PreparedStatement = msg.PreparedStatement
dst.ParameterFormatCodes = msg.ParameterFormatCodes
dst.Parameters = make([][]byte, len(msg.Parameters))
dst.ResultFormatCodes = msg.ResultFormatCodes
for n, parameter := range msg.Parameters {
dst.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
}
return nil
}

View file

@ -3,6 +3,7 @@ package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@ -62,3 +63,27 @@ func (src Close) MarshalJSON() ([]byte, error) {
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Close) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Close.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

View file

@ -51,3 +51,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) {
CommandTag: string(src.CommandTag),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
CommandTag string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.CommandTag = []byte(msg.CommandTag)
return nil
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyBothResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

View file

@ -42,3 +42,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyData) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyInResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyOutResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

View file

@ -115,3 +115,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) {
Values: formattedValues,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *DataRow) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Values []map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Values = make([][]byte, len(msg.Values))
for n, parameter := range msg.Values {
var err error
dst.Values[n], err = getValueFromJSON(parameter)
if err != nil {
return err
}
}
return nil
}

View file

@ -3,6 +3,7 @@ package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@ -62,3 +63,26 @@ func (src Describe) MarshalJSON() ([]byte, error) {
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Describe) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Describe.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

View file

@ -3,27 +3,29 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
type ErrorResponse struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
@ -56,6 +58,8 @@ func (dst *ErrorResponse) Decode(src []byte) error {
switch k {
case 'S':
dst.Severity = v
case 'V':
dst.SeverityUnlocalized = v
case 'C':
dst.Code = v
case 'M':
@ -123,6 +127,11 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.SeverityUnlocalized != "" {
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
}
if src.Code != "" {
buf.WriteByte('C')
buf.WriteString(src.Code)
@ -210,9 +219,116 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}{
Type: "ErrorResponse",
Severity: src.Severity,
SeverityUnlocalized: src.SeverityUnlocalized,
Code: src.Code,
Message: src.Message,
Detail: src.Detail,
Hint: src.Hint,
Position: src.Position,
InternalPosition: src.InternalPosition,
InternalQuery: src.InternalQuery,
Where: src.Where,
SchemaName: src.SchemaName,
TableName: src.TableName,
ColumnName: src.ColumnName,
DataTypeName: src.DataTypeName,
ConstraintName: src.ConstraintName,
File: src.File,
Line: src.Line,
Routine: src.Routine,
UnknownFields: src.UnknownFields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
dst.UnknownFields = msg.UnknownFields
return nil
}

View file

@ -45,6 +45,7 @@ type Frontend struct {
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
// NewFrontend creates a new Frontend.
@ -146,10 +147,16 @@ func (f *Frontend) Receive() (BackendMessage, error) {
}
// Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSCMCreds = 6
AuthTypeGSS = 7
AuthTypeGSSCont = 8
AuthTypeSSPI = 9
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
@ -159,15 +166,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src[:4])
f.authType = binary.BigEndian.Uint32(src[:4])
switch authType {
switch f.authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return nil, errors.New("AuthTypeGSS is unimplemented")
case AuthTypeGSSCont:
return nil, errors.New("AuthTypeGSSCont is unimplemented")
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
@ -175,6 +190,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", authType)
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
}
}
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
func (f *Frontend) GetAuthType() uint32 {
return f.authType
}

View file

@ -81,3 +81,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
Result: formattedValue,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Result map[string]string
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.Result, err = getValueFromJSON(msg.Result)
return err
}

View file

@ -14,6 +14,9 @@ type PasswordMessage struct {
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *PasswordMessage) Decode(src []byte) error {

View file

@ -1,6 +1,10 @@
package pgproto3
import "fmt"
import (
"encoding/hex"
"errors"
"fmt"
)
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
@ -23,6 +27,11 @@ type BackendMessage interface {
Backend() // no-op method to distinguish frontend from backend methods
}
type AuthenticationResponseMessage interface {
BackendMessage
AuthenticationResponse() // no-op method to distinguish authentication responses
}
type invalidMessageLenErr struct {
messageType string
expectedLen int
@ -40,3 +49,17 @@ type invalidMessageFormatErr struct {
func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {
return nil, nil
}
if text, ok := v["text"]; ok {
return []byte(text), nil
}
if binary, ok := v["binary"]; ok {
return hex.DecodeString(binary)
}
return nil, errors.New("unknown protocol representation")
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/json"
"errors"
)
type ReadyForQuery struct {
@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
TxStatus: string(src.TxStatus),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
TxStatus string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.TxStatus) != 1 {
return errors.New("invalid length for ReadyForQuery.TxStatus")
}
dst.TxStatus = msg.TxStatus[0]
return nil
}

View file

@ -132,3 +132,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) {
Fields: src.Fields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
var msg struct {
Fields []struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Fields = make([]FieldDescription, len(msg.Fields))
for n, field := range msg.Fields {
dst.Fields[n] = FieldDescription{
Name: []byte(field.Name),
TableOID: field.TableOID,
TableAttributeNumber: field.TableAttributeNumber,
DataTypeOID: field.DataTypeOID,
DataTypeSize: field.DataTypeSize,
TypeModifier: field.TypeModifier,
Format: field.Format,
}
}
return nil
}

View file

@ -67,3 +67,28 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
AuthMechanism string
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.AuthMechanism = msg.AuthMechanism
if msg.Data != "" {
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
}
return nil
}

View file

@ -41,3 +41,21 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if msg.Data != "" {
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
}
return nil
}