Write an OpenAPI spec for the worker API and use `deepmap/oapi-codegen`
to generate server-side scaffolding for the `labstack/echo` server.
Incidentally, Echo by default returns errors in the same format that the
worker API has always used:
{ "message": "..." }
The API itself is unchanged to keep this change easy to review. It
will be adapted to better suit our needs in future commits.
311 lines
8.9 KiB
Go
311 lines
8.9 KiB
Go
package echo
|
|
|
|
import (
|
|
"encoding"
|
|
"encoding/json"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
type (
|
|
// Binder is the interface that wraps the Bind method.
|
|
Binder interface {
|
|
Bind(i interface{}, c Context) error
|
|
}
|
|
|
|
// DefaultBinder is the default implementation of the Binder interface.
|
|
DefaultBinder struct{}
|
|
|
|
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
|
|
// Types that don't implement this, but do implement encoding.TextUnmarshaler
|
|
// will use that interface instead.
|
|
BindUnmarshaler interface {
|
|
// UnmarshalParam decodes and assigns a value from an form or query param.
|
|
UnmarshalParam(param string) error
|
|
}
|
|
)
|
|
|
|
// Bind implements the `Binder#Bind` function.
|
|
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
|
|
req := c.Request()
|
|
|
|
names := c.ParamNames()
|
|
values := c.ParamValues()
|
|
params := map[string][]string{}
|
|
for i, name := range names {
|
|
params[name] = []string{values[i]}
|
|
}
|
|
if err := b.bindData(i, params, "param"); err != nil {
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
if err = b.bindData(i, c.QueryParams(), "query"); err != nil {
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
if req.ContentLength == 0 {
|
|
return
|
|
}
|
|
ctype := req.Header.Get(HeaderContentType)
|
|
switch {
|
|
case strings.HasPrefix(ctype, MIMEApplicationJSON):
|
|
if err = json.NewDecoder(req.Body).Decode(i); err != nil {
|
|
if ute, ok := err.(*json.UnmarshalTypeError); ok {
|
|
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
|
|
} else if se, ok := err.(*json.SyntaxError); ok {
|
|
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
|
|
}
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
|
|
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
|
|
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
|
|
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
|
|
} else if se, ok := err.(*xml.SyntaxError); ok {
|
|
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
|
|
}
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
|
|
params, err := c.FormParams()
|
|
if err != nil {
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
if err = b.bindData(i, params, "form"); err != nil {
|
|
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
|
}
|
|
default:
|
|
return ErrUnsupportedMediaType
|
|
}
|
|
return
|
|
}
|
|
|
|
func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error {
|
|
if ptr == nil || len(data) == 0 {
|
|
return nil
|
|
}
|
|
typ := reflect.TypeOf(ptr).Elem()
|
|
val := reflect.ValueOf(ptr).Elem()
|
|
|
|
if m, ok := ptr.(*map[string]interface{}); ok {
|
|
for k, v := range data {
|
|
(*m)[k] = v[0]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if typ.Kind() != reflect.Struct {
|
|
return errors.New("binding element must be a struct")
|
|
}
|
|
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
typeField := typ.Field(i)
|
|
structField := val.Field(i)
|
|
if !structField.CanSet() {
|
|
continue
|
|
}
|
|
structFieldKind := structField.Kind()
|
|
inputFieldName := typeField.Tag.Get(tag)
|
|
|
|
if inputFieldName == "" {
|
|
inputFieldName = typeField.Name
|
|
// If tag is nil, we inspect if the field is a struct.
|
|
if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct {
|
|
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
inputValue, exists := data[inputFieldName]
|
|
if !exists {
|
|
// Go json.Unmarshal supports case insensitive binding. However the
|
|
// url params are bound case sensitive which is inconsistent. To
|
|
// fix this we must check all of the map values in a
|
|
// case-insensitive search.
|
|
inputFieldName = strings.ToLower(inputFieldName)
|
|
for k, v := range data {
|
|
if strings.ToLower(k) == inputFieldName {
|
|
inputValue = v
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !exists {
|
|
continue
|
|
}
|
|
|
|
// Call this first, in case we're dealing with an alias to an array type
|
|
if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
numElems := len(inputValue)
|
|
if structFieldKind == reflect.Slice && numElems > 0 {
|
|
sliceOf := structField.Type().Elem().Kind()
|
|
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
|
|
for j := 0; j < numElems; j++ {
|
|
if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
val.Field(i).Set(slice)
|
|
} else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
|
|
return err
|
|
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
|
|
// But also call it here, in case we're dealing with an array of BindUnmarshalers
|
|
if ok, err := unmarshalField(valueKind, val, structField); ok {
|
|
return err
|
|
}
|
|
|
|
switch valueKind {
|
|
case reflect.Ptr:
|
|
return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
|
|
case reflect.Int:
|
|
return setIntField(val, 0, structField)
|
|
case reflect.Int8:
|
|
return setIntField(val, 8, structField)
|
|
case reflect.Int16:
|
|
return setIntField(val, 16, structField)
|
|
case reflect.Int32:
|
|
return setIntField(val, 32, structField)
|
|
case reflect.Int64:
|
|
return setIntField(val, 64, structField)
|
|
case reflect.Uint:
|
|
return setUintField(val, 0, structField)
|
|
case reflect.Uint8:
|
|
return setUintField(val, 8, structField)
|
|
case reflect.Uint16:
|
|
return setUintField(val, 16, structField)
|
|
case reflect.Uint32:
|
|
return setUintField(val, 32, structField)
|
|
case reflect.Uint64:
|
|
return setUintField(val, 64, structField)
|
|
case reflect.Bool:
|
|
return setBoolField(val, structField)
|
|
case reflect.Float32:
|
|
return setFloatField(val, 32, structField)
|
|
case reflect.Float64:
|
|
return setFloatField(val, 64, structField)
|
|
case reflect.String:
|
|
structField.SetString(val)
|
|
default:
|
|
return errors.New("unknown type")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
|
|
switch valueKind {
|
|
case reflect.Ptr:
|
|
return unmarshalFieldPtr(val, field)
|
|
default:
|
|
return unmarshalFieldNonPtr(val, field)
|
|
}
|
|
}
|
|
|
|
// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler
|
|
func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
|
|
ptr := reflect.New(field.Type())
|
|
if ptr.CanInterface() {
|
|
iface := ptr.Interface()
|
|
if unmarshaler, ok := iface.(BindUnmarshaler); ok {
|
|
return unmarshaler, ok
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
// textUnmarshaler reports whether a pointer to field's type implements
// encoding.TextUnmarshaler. When it does, the returned unmarshaler is that
// pointer. It refers to a newly allocated zero value of the type, not to field
// itself.
func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) {
	fresh := reflect.New(field.Type())
	if !fresh.CanInterface() {
		return nil, false
	}
	u, ok := fresh.Interface().(encoding.TextUnmarshaler)
	if !ok {
		return nil, false
	}
	return u, true
}
|
|
|
|
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
|
|
if unmarshaler, ok := bindUnmarshaler(field); ok {
|
|
err := unmarshaler.UnmarshalParam(value)
|
|
field.Set(reflect.ValueOf(unmarshaler).Elem())
|
|
return true, err
|
|
}
|
|
if unmarshaler, ok := textUnmarshaler(field); ok {
|
|
err := unmarshaler.UnmarshalText([]byte(value))
|
|
field.Set(reflect.ValueOf(unmarshaler).Elem())
|
|
return true, err
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
|
|
if field.IsNil() {
|
|
// Initialize the pointer to a nil value
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
}
|
|
return unmarshalFieldNonPtr(value, field.Elem())
|
|
}
|
|
|
|
// setIntField parses value as a base-10 signed integer that fits in bitSize
// bits and stores it in field. An empty value is treated as "0". On a parse
// error, field is left unchanged and the strconv error is returned.
func setIntField(value string, bitSize int, field reflect.Value) error {
	if len(value) == 0 {
		value = "0"
	}
	parsed, err := strconv.ParseInt(value, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetInt(parsed)
	return nil
}
|
|
|
|
// setUintField parses value as a base-10 unsigned integer that fits in bitSize
// bits and stores it in field. An empty value is treated as "0". On a parse
// error, field is left unchanged and the strconv error is returned.
func setUintField(value string, bitSize int, field reflect.Value) error {
	if len(value) == 0 {
		value = "0"
	}
	parsed, err := strconv.ParseUint(value, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetUint(parsed)
	return nil
}
|
|
|
|
// setBoolField parses value with strconv.ParseBool and stores the result in
// field. An empty value means false. On a parse error, field is left unchanged
// and the strconv error is returned.
func setBoolField(value string, field reflect.Value) error {
	// Parsing the "false" default can never fail, so store it directly.
	if value == "" {
		field.SetBool(false)
		return nil
	}
	parsed, err := strconv.ParseBool(value)
	if err != nil {
		return err
	}
	field.SetBool(parsed)
	return nil
}
|
|
|
|
// setFloatField parses value as a floating-point number of the given bitSize
// and stores it in field. An empty value means 0. On a parse error, field is
// left unchanged and the strconv error is returned.
func setFloatField(value string, bitSize int, field reflect.Value) error {
	// Parsing the "0.0" default can never fail, so store zero directly.
	if value == "" {
		field.SetFloat(0)
		return nil
	}
	parsed, err := strconv.ParseFloat(value, bitSize)
	if err != nil {
		return err
	}
	field.SetFloat(parsed)
	return nil
}
|