go.mod: Update oapi-codegen and kin-openapi

This commit is contained in:
sanne 2022-01-11 19:00:14 +01:00 committed by Sanne Raymaekers
parent add17bba45
commit a83cf95d5b
156 changed files with 29663 additions and 2248 deletions

View file

@ -20,8 +20,11 @@ import (
"os"
"path"
"path/filepath"
"runtime/debug"
"strings"
"gopkg.in/yaml.v2"
"github.com/deepmap/oapi-codegen/pkg/codegen"
"github.com/deepmap/oapi-codegen/pkg/util"
)
@ -31,45 +34,79 @@ func errExit(format string, args ...interface{}) {
os.Exit(1)
}
var (
flagPackageName string
flagGenerate string
flagOutputFile string
flagIncludeTags string
flagExcludeTags string
flagTemplatesDir string
flagImportMapping string
flagExcludeSchemas string
flagConfigFile string
flagAliasTypes bool
flagPrintVersion bool
)
type configuration struct {
PackageName string `yaml:"package"`
GenerateTargets []string `yaml:"generate"`
OutputFile string `yaml:"output"`
IncludeTags []string `yaml:"include-tags"`
ExcludeTags []string `yaml:"exclude-tags"`
TemplatesDir string `yaml:"templates"`
ImportMapping map[string]string `yaml:"import-mapping"`
ExcludeSchemas []string `yaml:"exclude-schemas"`
}
func main() {
var (
packageName string
generate string
outputFile string
includeTags string
excludeTags string
templatesDir string
importMapping string
excludeSchemas string
)
flag.StringVar(&packageName, "package", "", "The package name for generated code")
flag.StringVar(&generate, "generate", "types,client,server,spec",
flag.StringVar(&flagPackageName, "package", "", "The package name for generated code")
flag.StringVar(&flagGenerate, "generate", "types,client,server,spec",
`Comma-separated list of code to generate; valid options: "types", "client", "chi-server", "server", "spec", "skip-fmt", "skip-prune"`)
flag.StringVar(&outputFile, "o", "", "Where to output generated code, stdout is default")
flag.StringVar(&includeTags, "include-tags", "", "Only include operations with the given tags. Comma-separated list of tags.")
flag.StringVar(&excludeTags, "exclude-tags", "", "Exclude operations that are tagged with the given tags. Comma-separated list of tags.")
flag.StringVar(&templatesDir, "templates", "", "Path to directory containing user templates")
flag.StringVar(&importMapping, "import-mapping", "", "A dict from the external reference to golang package path")
flag.StringVar(&excludeSchemas, "exclude-schemas", "", "A comma separated list of schemas which must be excluded from generation")
flag.StringVar(&flagOutputFile, "o", "", "Where to output generated code, stdout is default")
flag.StringVar(&flagIncludeTags, "include-tags", "", "Only include operations with the given tags. Comma-separated list of tags.")
flag.StringVar(&flagExcludeTags, "exclude-tags", "", "Exclude operations that are tagged with the given tags. Comma-separated list of tags.")
flag.StringVar(&flagTemplatesDir, "templates", "", "Path to directory containing user templates")
flag.StringVar(&flagImportMapping, "import-mapping", "", "A dict from the external reference to golang package path")
flag.StringVar(&flagExcludeSchemas, "exclude-schemas", "", "A comma separated list of schemas which must be excluded from generation")
flag.StringVar(&flagConfigFile, "config", "", "a YAML config file that controls oapi-codegen behavior")
flag.BoolVar(&flagAliasTypes, "alias-types", false, "Alias type declarations of possible")
flag.BoolVar(&flagPrintVersion, "version", false, "when specified, print version and exit")
flag.Parse()
if flagPrintVersion {
bi, ok := debug.ReadBuildInfo()
if !ok {
fmt.Fprintln(os.Stderr, "error reading build info")
os.Exit(1)
}
fmt.Println(bi.Main.Path + "/cmd/oapi-codegen")
fmt.Println(bi.Main.Version)
return
}
if flag.NArg() < 1 {
fmt.Println("Please specify a path to a OpenAPI 3.0 spec file")
os.Exit(1)
}
cfg := configFromFlags()
// If the package name has not been specified, we will use the name of the
// swagger file.
if packageName == "" {
if cfg.PackageName == "" {
path := flag.Arg(0)
baseName := filepath.Base(path)
// Split the base name on '.' to get the first part of the file.
nameParts := strings.Split(baseName, ".")
packageName = codegen.ToCamelCase(nameParts[0])
cfg.PackageName = codegen.ToCamelCase(nameParts[0])
}
opts := codegen.Options{}
for _, g := range splitCSVArg(generate) {
opts := codegen.Options{
AliasTypes: flagAliasTypes,
}
for _, g := range cfg.GenerateTargets {
switch g {
case "client":
opts.GenerateClient = true
@ -92,9 +129,9 @@ func main() {
}
}
opts.IncludeTags = splitCSVArg(includeTags)
opts.ExcludeTags = splitCSVArg(excludeTags)
opts.ExcludeSchemas = splitCSVArg(excludeSchemas)
opts.IncludeTags = cfg.IncludeTags
opts.ExcludeTags = cfg.ExcludeTags
opts.ExcludeSchemas = cfg.ExcludeSchemas
if opts.GenerateEchoServer && opts.GenerateChiServer {
errExit("can not specify both server and chi-server targets simultaneously")
@ -105,26 +142,21 @@ func main() {
errExit("error loading swagger spec\n: %s", err)
}
templates, err := loadTemplateOverrides(templatesDir)
templates, err := loadTemplateOverrides(cfg.TemplatesDir)
if err != nil {
errExit("error loading template overrides: %s\n", err)
}
opts.UserTemplates = templates
if len(importMapping) > 0 {
opts.ImportMapping, err = util.ParseCommandlineMap(importMapping)
if err != nil {
errExit("error parsing import-mapping: %s\n", err)
}
}
opts.ImportMapping = cfg.ImportMapping
code, err := codegen.Generate(swagger, packageName, opts)
code, err := codegen.Generate(swagger, cfg.PackageName, opts)
if err != nil {
errExit("error generating code: %s\n", err)
}
if outputFile != "" {
err = ioutil.WriteFile(outputFile, []byte(code), 0644)
if cfg.OutputFile != "" {
err = ioutil.WriteFile(cfg.OutputFile, []byte(code), 0644)
if err != nil {
errExit("error writing generated code to file: %s", err)
}
@ -133,22 +165,6 @@ func main() {
}
}
func splitCSVArg(input string) []string {
input = strings.TrimSpace(input)
if len(input) == 0 {
return nil
}
splitInput := strings.Split(input, ",")
args := make([]string, 0, len(splitInput))
for _, s := range splitInput {
s = strings.TrimSpace(s)
if len(s) > 0 {
args = append(args, s)
}
}
return args
}
func loadTemplateOverrides(templatesDir string) (map[string]string, error) {
var templates = make(map[string]string)
@ -171,3 +187,53 @@ func loadTemplateOverrides(templatesDir string) (map[string]string, error) {
return templates, nil
}
// configFromFlags parses the flags and the config file. Anything which is
// a zerovalue in the configuration file will be replaced with the flag
// value, this means that the config file overrides flag values.
func configFromFlags() *configuration {
var cfg configuration
// Load the configuration file first.
if flagConfigFile != "" {
f, err := os.Open(flagConfigFile)
if err != nil {
errExit("failed to open config file with error: %v\n", err)
}
defer f.Close()
err = yaml.NewDecoder(f).Decode(&cfg)
if err != nil {
errExit("error parsing config file: %v\n", err)
}
}
if cfg.PackageName == "" {
cfg.PackageName = flagPackageName
}
if cfg.GenerateTargets == nil {
cfg.GenerateTargets = util.ParseCommandLineList(flagGenerate)
}
if cfg.IncludeTags == nil {
cfg.IncludeTags = util.ParseCommandLineList(flagIncludeTags)
}
if cfg.ExcludeTags == nil {
cfg.ExcludeTags = util.ParseCommandLineList(flagExcludeTags)
}
if cfg.TemplatesDir == "" {
cfg.TemplatesDir = flagTemplatesDir
}
if cfg.ImportMapping == nil && flagImportMapping != "" {
var err error
cfg.ImportMapping, err = util.ParseCommandlineMap(flagImportMapping)
if err != nil {
errExit("error parsing import-mapping: %s\n", err)
}
}
if cfg.ExcludeSchemas == nil {
cfg.ExcludeSchemas = util.ParseCommandLineList(flagExcludeSchemas)
}
if cfg.OutputFile == "" {
cfg.OutputFile = flagOutputFile
}
return &cfg
}

View file

@ -18,14 +18,14 @@ import (
"bufio"
"bytes"
"fmt"
"go/format"
"regexp"
"runtime/debug"
"sort"
"strings"
"text/template"
"github.com/getkin/kin-openapi/openapi3"
"github.com/pkg/errors"
"golang.org/x/tools/imports"
"github.com/deepmap/oapi-codegen/pkg/codegen/templates"
)
@ -37,8 +37,9 @@ type Options struct {
GenerateClient bool // GenerateClient specifies whether to generate client boilerplate
GenerateTypes bool // GenerateTypes specifies whether to generate type definitions
EmbedSpec bool // Whether to embed the swagger spec in the generated code
SkipFmt bool // Whether to skip go fmt on the generated code
SkipFmt bool // Whether to skip go imports on the generated code
SkipPrune bool // Whether to skip pruning unused components on the generated code
AliasTypes bool // Whether to alias types if possible
IncludeTags []string // Only include operations that have one of these tags. Ignored when empty.
ExcludeTags []string // Exclude operations that have one of these tags. Ignored when empty.
UserTemplates map[string]string // Override built-in templates from user-provided files
@ -46,54 +47,38 @@ type Options struct {
ExcludeSchemas []string // Exclude from generation schemas with given names. Ignored when empty.
}
// goImport represents a go package to be imported in the generated code
type goImport struct {
lookFor string
alias string
packageName string
Name string // package name
Path string // package path
}
func (i goImport) String() string {
if i.alias != "" {
return fmt.Sprintf("%s %q", i.alias, i.packageName)
// String returns a go import statement
func (gi goImport) String() string {
if gi.Name != "" {
return fmt.Sprintf("%s %q", gi.Name, gi.Path)
}
return fmt.Sprintf("%q", i.packageName)
return fmt.Sprintf("%q", gi.Path)
}
type goImports []goImport
// importMap maps external OpenAPI specifications files/urls to external go packages
type importMap map[string]goImport
var (
allGoImports = goImports{
{lookFor: "base64\\.", packageName: "encoding/base64"},
{lookFor: "bytes\\.", packageName: "bytes"},
{lookFor: "chi\\.", packageName: "github.com/go-chi/chi"},
{lookFor: "context\\.", packageName: "context"},
{lookFor: "echo\\.", packageName: "github.com/labstack/echo/v4"},
{lookFor: "errors\\.", packageName: "github.com/pkg/errors"},
{lookFor: "fmt\\.", packageName: "fmt"},
{lookFor: "gzip\\.", packageName: "compress/gzip"},
{lookFor: "http\\.", packageName: "net/http"},
{lookFor: "io\\.", packageName: "io"},
{lookFor: "ioutil\\.", packageName: "io/ioutil"},
{lookFor: "json\\.", packageName: "encoding/json"},
{lookFor: "openapi3\\.", packageName: "github.com/getkin/kin-openapi/openapi3"},
{lookFor: "openapi_types\\.", alias: "openapi_types", packageName: "github.com/deepmap/oapi-codegen/pkg/types"},
{lookFor: "path\\.", packageName: "path"},
{lookFor: "runtime\\.", packageName: "github.com/deepmap/oapi-codegen/pkg/runtime"},
{lookFor: "strings\\.", packageName: "strings"},
{lookFor: "time\\.Duration", packageName: "time"},
{lookFor: "time\\.Time", packageName: "time"},
{lookFor: "url\\.", packageName: "net/url"},
{lookFor: "xml\\.", packageName: "encoding/xml"},
{lookFor: "yaml\\.", packageName: "gopkg.in/yaml.v2"},
// GoImports returns a slice of go import statements
func (im importMap) GoImports() []string {
goImports := make([]string, 0, len(im))
for _, v := range im {
goImports = append(goImports, v.String())
}
return goImports
}
importMapping = map[string]goImport{}
)
var importMapping importMap
func constructImportMapping(input map[string]string) map[string]goImport {
func constructImportMapping(input map[string]string) importMap {
var (
nameToAlias = map[string]string{}
result = map[string]goImport{}
pathToName = map[string]string{}
result = importMap{}
)
{
@ -103,14 +88,14 @@ func constructImportMapping(input map[string]string) map[string]goImport {
}
sort.Strings(packagePaths)
for _, packageName := range packagePaths {
if _, ok := nameToAlias[packageName]; !ok {
nameToAlias[packageName] = fmt.Sprintf("externalRef%d", len(nameToAlias))
for _, packagePath := range packagePaths {
if _, ok := pathToName[packagePath]; !ok {
pathToName[packagePath] = fmt.Sprintf("externalRef%d", len(pathToName))
}
}
}
for urlOrPath, packageName := range input {
result[urlOrPath] = goImport{alias: nameToAlias[packageName], packageName: packageName}
for specPath, packagePath := range input {
result[specPath] = goImport{Name: pathToName[packagePath], Path: packagePath}
}
return result
}
@ -118,7 +103,7 @@ func constructImportMapping(input map[string]string) map[string]goImport {
// Uses the Go templating engine to generate all of our server wrappers from
// the descriptions we've built up above from the schema objects.
// opts defines
func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (string, error) {
func Generate(swagger *openapi3.T, packageName string, opts Options) (string, error) {
importMapping = constructImportMapping(opts.ImportMapping)
filterOperationsByTag(swagger, opts)
@ -127,6 +112,7 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
}
// This creates the golang templates text package
TemplateFunctions["opts"] = func() Options { return opts }
t := template.New("oapi-codegen").Funcs(TemplateFunctions)
// This parses all of our own template files into the template object
// above
@ -150,12 +136,18 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
return "", errors.Wrap(err, "error creating operation definitions")
}
var typeDefinitions string
var typeDefinitions, constantDefinitions string
if opts.GenerateTypes {
typeDefinitions, err = GenerateTypeDefinitions(t, swagger, ops, opts.ExcludeSchemas)
if err != nil {
return "", errors.Wrap(err, "error generating type definitions")
}
constantDefinitions, err = GenerateConstants(t, ops)
if err != nil {
return "", errors.Wrap(err, "error generating constants")
}
}
var echoServerOut string
@ -192,49 +184,17 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
var inlinedSpec string
if opts.EmbedSpec {
inlinedSpec, err = GenerateInlinedSpec(t, swagger)
inlinedSpec, err = GenerateInlinedSpec(t, importMapping, swagger)
if err != nil {
return "", errors.Wrap(err, "error generating Go handlers for Paths")
}
}
// Imports needed for the generated code to compile
var imports []string
for _, importGo := range importMapping {
imports = append(imports, importGo.String())
}
var buf bytes.Buffer
w := bufio.NewWriter(&buf)
// Based on module prefixes, figure out which optional imports are required.
pkgs := make(map[string]int)
for _, str := range []string{typeDefinitions, chiServerOut, echoServerOut, clientOut, clientWithResponsesOut, inlinedSpec} {
for _, line := range strings.Split(strings.TrimSpace(str), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "//") {
continue
}
for _, goImport := range allGoImports {
match, err := regexp.MatchString(fmt.Sprintf("[^a-zA-Z0-9_]%s", goImport.lookFor), line)
if err != nil {
return "", errors.Wrap(err, "error figuring out imports")
}
if match {
pkgs[goImport.String()]++
}
}
}
}
for k := range pkgs {
imports = append(imports, k)
}
importsOut, err := GenerateImports(t, imports, packageName)
externalImports := importMapping.GoImports()
importsOut, err := GenerateImports(t, externalImports, packageName)
if err != nil {
return "", errors.Wrap(err, "error generating imports")
}
@ -244,6 +204,11 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
return "", errors.Wrap(err, "error writing imports")
}
_, err = w.WriteString(constantDefinitions)
if err != nil {
return "", errors.Wrap(err, "error writing constants")
}
_, err = w.WriteString(typeDefinitions)
if err != nil {
return "", errors.Wrap(err, "error writing type definitions")
@ -290,12 +255,13 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
// remove any byte-order-marks which break Go-Code
goCode := SanitizeCode(buf.String())
// The generation code produces unindented horrors. Use the Go formatter
// The generation code produces unindented horrors. Use the Go Imports
// to make it all pretty.
if opts.SkipFmt {
return goCode, nil
}
outBytes, err := format.Source([]byte(goCode))
outBytes, err := imports.Process(packageName+".go", []byte(goCode), nil)
if err != nil {
fmt.Println(goCode)
return "", errors.Wrap(err, "error formatting Go code")
@ -303,7 +269,7 @@ func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (stri
return string(outBytes), nil
}
func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.Swagger, ops []OperationDefinition, excludeSchemas []string) (string, error) {
func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.T, ops []OperationDefinition, excludeSchemas []string) (string, error) {
schemaTypes, err := GenerateTypesForSchemas(t, swagger.Components.Schemas, excludeSchemas)
if err != nil {
return "", errors.Wrap(err, "error generating Go types for component schemas")
@ -332,6 +298,11 @@ func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.Swagger, op
return "", errors.Wrap(err, "error generating Go types for operation parameters")
}
enumsOut, err := GenerateEnums(t, allTypes)
if err != nil {
return "", errors.Wrap(err, "error generating code for type enums")
}
typesOut, err := GenerateTypes(t, allTypes)
if err != nil {
return "", errors.Wrap(err, "error generating code for type definitions")
@ -342,10 +313,50 @@ func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.Swagger, op
return "", errors.Wrap(err, "error generating allOf boilerplate")
}
typeDefinitions := strings.Join([]string{typesOut, paramTypesOut, allOfBoilerplate}, "")
typeDefinitions := strings.Join([]string{enumsOut, typesOut, paramTypesOut, allOfBoilerplate}, "")
return typeDefinitions, nil
}
// Generates operation ids, context keys, paths, etc. to be exported as constants
func GenerateConstants(t *template.Template, ops []OperationDefinition) (string, error) {
var buf bytes.Buffer
w := bufio.NewWriter(&buf)
constants := Constants{
SecuritySchemeProviderNames: []string{},
}
providerNameMap := map[string]struct{}{}
for _, op := range ops {
for _, def := range op.SecurityDefinitions {
providerName := SanitizeGoIdentity(def.ProviderName)
providerNameMap[providerName] = struct{}{}
}
}
var providerNames []string
for providerName := range providerNameMap {
providerNames = append(providerNames, providerName)
}
sort.Strings(providerNames)
for _, providerName := range providerNames {
constants.SecuritySchemeProviderNames = append(constants.SecuritySchemeProviderNames, providerName)
}
err := t.ExecuteTemplate(w, "constants.tmpl", constants)
if err != nil {
return "", fmt.Errorf("error generating server interface: %s", err)
}
err = w.Flush()
if err != nil {
return "", fmt.Errorf("error flushing output buffer for server interface: %s", err)
}
return buf.String(), nil
}
// Generates type definitions for any custom types defined in the
// components/schemas section of the Swagger spec.
func GenerateTypesForSchemas(t *template.Template, schemas map[string]*openapi3.SchemaRef, excludeSchemas []string) ([]TypeDefinition, error) {
@ -509,18 +520,64 @@ func GenerateTypes(t *template.Template, types []TypeDefinition) (string, error)
return buf.String(), nil
}
// Generate our import statements and package definition.
func GenerateImports(t *template.Template, imports []string, packageName string) (string, error) {
sort.Strings(imports)
func GenerateEnums(t *template.Template, types []TypeDefinition) (string, error) {
var buf bytes.Buffer
w := bufio.NewWriter(&buf)
c := Constants{
EnumDefinitions: []EnumDefinition{},
}
for _, tp := range types {
if len(tp.Schema.EnumValues) > 0 {
wrapper := ""
if tp.Schema.GoType == "string" {
wrapper = `"`
}
c.EnumDefinitions = append(c.EnumDefinitions, EnumDefinition{
Schema: tp.Schema,
TypeName: tp.TypeName,
ValueWrapper: wrapper,
})
}
}
err := t.ExecuteTemplate(w, "constants.tmpl", c)
if err != nil {
return "", errors.Wrap(err, "error generating enums")
}
err = w.Flush()
if err != nil {
return "", errors.Wrap(err, "error flushing output buffer for enums")
}
return buf.String(), nil
}
// Generate our import statements and package definition.
func GenerateImports(t *template.Template, externalImports []string, packageName string) (string, error) {
var buf bytes.Buffer
w := bufio.NewWriter(&buf)
// Read build version for incorporating into generated files
var modulePath string
var moduleVersion string
if bi, ok := debug.ReadBuildInfo(); ok {
modulePath = bi.Main.Path
moduleVersion = bi.Main.Version
} else {
// Unit tests have ok=false, so we'll just use "unknown" for the
// version if we can't read this.
modulePath = "unknown module path"
moduleVersion = "unknown version"
}
context := struct {
Imports []string
PackageName string
ExternalImports []string
PackageName string
ModuleName string
Version string
}{
Imports: imports,
PackageName: packageName,
ExternalImports: externalImports,
PackageName: packageName,
ModuleName: modulePath,
Version: moduleVersion,
}
err := t.ExecuteTemplate(w, "imports.tmpl", context)
if err != nil {

View file

@ -8,7 +8,9 @@ import (
)
const (
extPropGoType = "x-go-type"
extPropGoType = "x-go-type"
extPropOmitEmpty = "x-omitempty"
extPropExtraTags = "x-oapi-codegen-extra-tags"
)
func extTypeName(extPropValue interface{}) (string, error) {
@ -23,3 +25,29 @@ func extTypeName(extPropValue interface{}) (string, error) {
return name, nil
}
func extParseOmitEmpty(extPropValue interface{}) (bool, error) {
raw, ok := extPropValue.(json.RawMessage)
if !ok {
return false, fmt.Errorf("failed to convert type: %T", extPropValue)
}
var omitEmpty bool
if err := json.Unmarshal(raw, &omitEmpty); err != nil {
return false, errors.Wrap(err, "failed to unmarshal json")
}
return omitEmpty, nil
}
func extExtraTags(extPropValue interface{}) (map[string]string, error) {
raw, ok := extPropValue.(json.RawMessage)
if !ok {
return nil, fmt.Errorf("failed to convert type: %T", extPropValue)
}
var tags map[string]string
if err := json.Unmarshal(raw, &tags); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal json")
}
return tags, nil
}

View file

@ -2,7 +2,7 @@ package codegen
import "github.com/getkin/kin-openapi/openapi3"
func filterOperationsByTag(swagger *openapi3.Swagger, opts Options) {
func filterOperationsByTag(swagger *openapi3.T, opts Options) {
if len(opts.ExcludeTags) > 0 {
excludeOperationsWithTags(swagger.Paths, opts.ExcludeTags)
}

View file

@ -26,7 +26,7 @@ import (
// This generates a gzipped, base64 encoded JSON representation of the
// swagger definition, which we embed inside the generated code.
func GenerateInlinedSpec(t *template.Template, swagger *openapi3.Swagger) (string, error) {
func GenerateInlinedSpec(t *template.Template, importMapping importMap, swagger *openapi3.T) (string, error) {
// Marshal to json
encoded, err := swagger.MarshalJSON()
if err != nil {
@ -65,7 +65,10 @@ func GenerateInlinedSpec(t *template.Template, swagger *openapi3.Swagger) (strin
// Generate inline code.
buf.Reset()
w := bufio.NewWriter(&buf)
err = t.ExecuteTemplate(w, "inline.tmpl", parts)
err = t.ExecuteTemplate(w, "inline.tmpl", struct {
SpecParts []string
ImportMapping importMap
}{SpecParts: parts, ImportMapping: importMapping})
if err != nil {
return "", fmt.Errorf("error generating inlined spec: %s", err)
}

View file

@ -120,7 +120,7 @@ func (pd ParameterDefinition) GoVariableName() string {
}
func (pd ParameterDefinition) GoName() string {
return ToCamelCase(pd.ParamName)
return SchemaNameToTypeName(pd.ParamName)
}
func (pd ParameterDefinition) IndirectOptional() bool {
@ -163,7 +163,7 @@ func DescribeParameters(params openapi3.Parameters, path []string) ([]ParameterD
// If this is a reference to a predefined type, simply use the reference
// name as the type. $ref: "#/components/schemas/custom_type" becomes
// "CustomType".
if paramOrRef.Ref != "" {
if IsGoTypeReference(paramOrRef.Ref) {
goType, err := RefPathToGoType(paramOrRef.Ref)
if err != nil {
return nil, fmt.Errorf("error dereferencing (%s) for param (%s): %s",
@ -185,7 +185,8 @@ func DescribeSecurityDefinition(securityRequirements openapi3.SecurityRequiremen
outDefs := make([]SecurityDefinition, 0)
for _, sr := range securityRequirements {
for k, v := range sr {
for _, k := range SortedSecurityRequirementKeys(sr) {
v := sr[k]
outDefs = append(outDefs, SecurityDefinition{ProviderName: k, Scopes: v})
}
}
@ -258,8 +259,8 @@ func (o *OperationDefinition) SummaryAsComment() string {
// types which we know how to parse. These will be turned into fields on a
// response object for automatic deserialization of responses in the generated
// Client code. See "client-with-responses.tmpl".
func (o *OperationDefinition) GetResponseTypeDefinitions() ([]TypeDefinition, error) {
var tds []TypeDefinition
func (o *OperationDefinition) GetResponseTypeDefinitions() ([]ResponseTypeDefinition, error) {
var tds []ResponseTypeDefinition
responses := o.Spec.Responses
sortedResponsesKeys := SortedResponsesKeys(responses)
@ -292,12 +293,15 @@ func (o *OperationDefinition) GetResponseTypeDefinitions() ([]TypeDefinition, er
continue
}
td := TypeDefinition{
TypeName: typeName,
Schema: responseSchema,
ResponseName: responseName,
td := ResponseTypeDefinition{
TypeDefinition: TypeDefinition{
TypeName: typeName,
Schema: responseSchema,
},
ResponseName: responseName,
ContentTypeName: contentTypeName,
}
if contentType.Schema.Ref != "" {
if IsGoTypeReference(contentType.Schema.Ref) {
refType, err := RefPathToGoType(contentType.Schema.Ref)
if err != nil {
return nil, errors.Wrap(err, "error dereferencing response Ref")
@ -333,8 +337,11 @@ type RequestBodyDefinition struct {
}
// Returns the Go type definition for a request body
func (r RequestBodyDefinition) TypeDef() string {
return r.Schema.TypeDecl()
func (r RequestBodyDefinition) TypeDef(opID string) *TypeDefinition {
return &TypeDefinition{
TypeName: fmt.Sprintf("%s%sRequestBody", opID, r.NameTag),
Schema: r.Schema,
}
}
// Returns whether the body is a custom inline type, or pre-defined. This is
@ -368,7 +375,7 @@ func FilterParameterDefinitionByType(params []ParameterDefinition, in string) []
}
// OperationDefinitions returns all operations for a swagger definition.
func OperationDefinitions(swagger *openapi3.Swagger) ([]OperationDefinition, error) {
func OperationDefinitions(swagger *openapi3.T) ([]OperationDefinition, error) {
var operations []OperationDefinition
for _, requestPath := range SortedPathsKeys(swagger.Paths) {
@ -385,6 +392,9 @@ func OperationDefinitions(swagger *openapi3.Swagger) ([]OperationDefinition, err
pathOps := pathItem.Operations()
for _, opName := range SortedOperationsKeys(pathOps) {
op := pathOps[opName]
if pathItem.Servers != nil {
op.Servers = &pathItem.Servers
}
// We rely on OperationID to generate function names, it's required
if op.OperationID == "" {
op.OperationID, err = generateDefaultOperationID(opName, requestPath)
@ -514,7 +524,7 @@ func GenerateBodyDefinitions(operationID string, bodyOrRef *openapi3.RequestBody
}
// If the body is a pre-defined type
if bodyOrRef.Ref != "" {
if IsGoTypeReference(bodyOrRef.Ref) {
// Convert the reference path to Go type
refType, err := RefPathToGoType(bodyOrRef.Ref)
if err != nil {
@ -589,14 +599,16 @@ func GenerateParamsTypes(op OperationDefinition) []TypeDefinition {
})
}
prop := Property{
Description: param.Spec.Description,
JsonFieldName: param.ParamName,
Required: param.Required,
Schema: pSchema,
Description: param.Spec.Description,
JsonFieldName: param.ParamName,
Required: param.Required,
Schema: pSchema,
ExtensionProps: &param.Spec.ExtensionProps,
}
s.Properties = append(s.Properties, prop)
}
s.Description = op.Spec.Description
s.GoType = GenStructFromSchema(s)
td := TypeDefinition{

View file

@ -21,7 +21,7 @@ type RefWrapper struct {
SourceRef interface{}
}
func walkSwagger(swagger *openapi3.Swagger, doFn func(RefWrapper) (bool, error)) error {
func walkSwagger(swagger *openapi3.T, doFn func(RefWrapper) (bool, error)) error {
if swagger == nil {
return nil
}
@ -377,7 +377,7 @@ func walkExampleRef(ref *openapi3.ExampleRef, doFn func(RefWrapper) (bool, error
return nil
}
func findComponentRefs(swagger *openapi3.Swagger) []string {
func findComponentRefs(swagger *openapi3.T) []string {
refs := []string{}
walkSwagger(swagger, func(ref RefWrapper) (bool, error) {
@ -391,7 +391,7 @@ func findComponentRefs(swagger *openapi3.Swagger) []string {
return refs
}
func removeOrphanedComponents(swagger *openapi3.Swagger, refs []string) int {
func removeOrphanedComponents(swagger *openapi3.T, refs []string) int {
countRemoved := 0
for key, _ := range swagger.Components.Schemas {
@ -472,7 +472,7 @@ func removeOrphanedComponents(swagger *openapi3.Swagger, refs []string) int {
return countRemoved
}
func pruneUnusedComponents(swagger *openapi3.Swagger) {
func pruneUnusedComponents(swagger *openapi3.T) {
for {
refs := findComponentRefs(swagger)
countRemoved := removeOrphanedComponents(swagger, refs)

View file

@ -13,6 +13,8 @@ type Schema struct {
GoType string // The Go type needed to represent the schema
RefType string // If the type has a type name, this is set
ArrayType *Schema // The schema of array element
EnumValues map[string]string // Enum values
Properties []Property // For an object, the fields with names
@ -21,6 +23,11 @@ type Schema struct {
AdditionalTypes []TypeDefinition // We may need to generate auxiliary helper types, stored here
SkipOptionalPointer bool // Some types don't need a * in front when they're optional
Description string // The description of the element
// The original OpenAPIv3 Schema.
OAPISchema *openapi3.Schema
}
func (s Schema) IsRef() bool {
@ -55,11 +62,12 @@ func (s Schema) GetAdditionalTypeDefs() []TypeDefinition {
}
type Property struct {
Description string
JsonFieldName string
Schema Schema
Required bool
Nullable bool
Description string
JsonFieldName string
Schema Schema
Required bool
Nullable bool
ExtensionProps *openapi3.ExtensionProps
}
func (p Property) GoFieldName() string {
@ -74,11 +82,56 @@ func (p Property) GoTypeDef() string {
return typeDef
}
type TypeDefinition struct {
TypeName string
JsonName string
ResponseName string
// EnumDefinition holds type information for enum
type EnumDefinition struct {
Schema Schema
TypeName string
ValueWrapper string
}
type Constants struct {
// SecuritySchemeProviderNames holds all provider names for security schemes.
SecuritySchemeProviderNames []string
// EnumDefinitions holds type and value information for all enums
EnumDefinitions []EnumDefinition
}
// TypeDefinition describes a Go type definition in generated code.
//
// Let's use this example schema:
// components:
// schemas:
// Person:
// type: object
// properties:
// name:
// type: string
type TypeDefinition struct {
// The name of the type, eg, type <...> Person
TypeName string
// The name of the corresponding JSON description, as it will sometimes
// differ due to invalid characters.
JsonName string
// This is the Schema wrapper is used to populate the type description
Schema Schema
}
// ResponseTypeDefinition is an extension of TypeDefinition, specifically for
// response unmarshaling in ClientWithResponses.
type ResponseTypeDefinition struct {
TypeDefinition
// The content type name where this is used, eg, application/json
ContentTypeName string
// The type name of a response model.
ResponseName string
}
func (t *TypeDefinition) CanAlias() bool {
return t.Schema.IsRef() || /* actual reference */
(t.Schema.ArrayType != nil && t.Schema.ArrayType.IsRef()) /* array to ref */
}
func PropertiesEqual(a, b Property) bool {
@ -86,39 +139,44 @@ func PropertiesEqual(a, b Property) bool {
}
func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) {
// If Ref is set on the SchemaRef, it means that this type is actually a reference to
// another type. We're not de-referencing, so simply use the referenced type.
var refType string
// Add a fallback value in case the sref is nil.
// i.e. the parent schema defines a type:array, but the array has
// no items defined. Therefore we have at least valid Go-Code.
if sref == nil {
return Schema{GoType: "interface{}", RefType: refType}, nil
return Schema{GoType: "interface{}"}, nil
}
schema := sref.Value
if sref.Ref != "" {
var err error
// If Ref is set on the SchemaRef, it means that this type is actually a reference to
// another type. We're not de-referencing, so simply use the referenced type.
if IsGoTypeReference(sref.Ref) {
// Convert the reference path to Go type
refType, err = RefPathToGoType(sref.Ref)
refType, err := RefPathToGoType(sref.Ref)
if err != nil {
return Schema{}, fmt.Errorf("error turning reference (%s) into a Go type: %s",
sref.Ref, err)
}
return Schema{
GoType: refType,
GoType: refType,
Description: StringToGoComment(schema.Description),
}, nil
}
outSchema := Schema{
Description: StringToGoComment(schema.Description),
OAPISchema: schema,
}
// We can't support this in any meaningful way
if schema.AnyOf != nil {
return Schema{GoType: "interface{}", RefType: refType}, nil
outSchema.GoType = "interface{}"
return outSchema, nil
}
// We can't support this in any meaningful way
if schema.OneOf != nil {
return Schema{GoType: "interface{}", RefType: refType}, nil
outSchema.GoType = "interface{}"
return outSchema, nil
}
// AllOf is interesting, and useful. It's the union of a number of other
@ -130,14 +188,10 @@ func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) {
if err != nil {
return Schema{}, errors.Wrap(err, "error merging schemas")
}
mergedSchema.RefType = refType
mergedSchema.OAPISchema = schema
return mergedSchema, nil
}
outSchema := Schema{
RefType: refType,
}
// Check for custom Go type extension
if extension, ok := schema.Extensions[extPropGoType]; ok {
typeName, err := extTypeName(extension)
@ -200,11 +254,12 @@ func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) {
description = p.Value.Description
}
prop := Property{
JsonFieldName: pName,
Schema: pSchema,
Required: required,
Description: description,
Nullable: p.Value.Nullable,
JsonFieldName: pName,
Schema: pSchema,
Required: required,
Description: description,
Nullable: p.Value.Nullable,
ExtensionProps: &p.Value.ExtensionProps,
}
outSchema.Properties = append(outSchema.Properties, prop)
}
@ -224,76 +279,133 @@ func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) {
outSchema.GoType = GenStructFromSchema(outSchema)
}
return outSchema, nil
} else if len(schema.Enum) > 0 {
err := resolveType(schema, path, &outSchema)
if err != nil {
return Schema{}, errors.Wrap(err, "error resolving primitive type")
}
enumValues := make([]string, len(schema.Enum))
for i, enumValue := range schema.Enum {
enumValues[i] = fmt.Sprintf("%v", enumValue)
}
sanitizedValues := SanitizeEnumNames(enumValues)
outSchema.EnumValues = make(map[string]string, len(sanitizedValues))
var constNamePath []string
for k, v := range sanitizedValues {
if v == "" {
constNamePath = append(path, "Empty")
} else {
constNamePath = append(path, k)
}
outSchema.EnumValues[SchemaNameToTypeName(PathToTypeName(constNamePath))] = v
}
if len(path) > 1 { // handle additional type only on non-toplevel types
typeName := SchemaNameToTypeName(PathToTypeName(path))
typeDef := TypeDefinition{
TypeName: typeName,
JsonName: strings.Join(path, "."),
Schema: outSchema,
}
outSchema.AdditionalTypes = append(outSchema.AdditionalTypes, typeDef)
outSchema.RefType = typeName
}
//outSchema.RefType = typeName
} else {
f := schema.Format
switch t {
case "array":
// For arrays, we'll get the type of the Items and throw a
// [] in front of it.
arrayType, err := GenerateGoSchema(schema.Items, path)
if err != nil {
return Schema{}, errors.Wrap(err, "error generating type for array")
}
outSchema.GoType = "[]" + arrayType.TypeDecl()
outSchema.Properties = arrayType.Properties
case "integer":
// We default to int if format doesn't ask for something else.
if f == "int64" {
outSchema.GoType = "int64"
} else if f == "int32" {
outSchema.GoType = "int32"
} else if f == "" {
outSchema.GoType = "int"
} else {
return Schema{}, fmt.Errorf("invalid integer format: %s", f)
}
case "number":
// We default to float for "number"
if f == "double" {
outSchema.GoType = "float64"
} else if f == "float" || f == "" {
outSchema.GoType = "float32"
} else {
return Schema{}, fmt.Errorf("invalid number format: %s", f)
}
case "boolean":
if f != "" {
return Schema{}, fmt.Errorf("invalid format (%s) for boolean", f)
}
outSchema.GoType = "bool"
case "string":
enumValues := make([]string, len(schema.Enum))
for i, enumValue := range schema.Enum {
enumValues[i] = enumValue.(string)
}
outSchema.EnumValues = SanitizeEnumNames(enumValues)
// Special case string formats here.
switch f {
case "byte":
outSchema.GoType = "[]byte"
case "email":
outSchema.GoType = "openapi_types.Email"
case "date":
outSchema.GoType = "openapi_types.Date"
case "date-time":
outSchema.GoType = "time.Time"
case "json":
outSchema.GoType = "json.RawMessage"
outSchema.SkipOptionalPointer = true
default:
// All unrecognized formats are simply a regular string.
outSchema.GoType = "string"
}
default:
return Schema{}, fmt.Errorf("unhandled Schema type: %s", t)
err := resolveType(schema, path, &outSchema)
if err != nil {
return Schema{}, errors.Wrap(err, "error resolving primitive type")
}
}
return outSchema, nil
}
// resolveType resolves primitive type or array for schema
func resolveType(schema *openapi3.Schema, path []string, outSchema *Schema) error {
f := schema.Format
t := schema.Type
switch t {
case "array":
// For arrays, we'll get the type of the Items and throw a
// [] in front of it.
arrayType, err := GenerateGoSchema(schema.Items, path)
if err != nil {
return errors.Wrap(err, "error generating type for array")
}
outSchema.ArrayType = &arrayType
outSchema.GoType = "[]" + arrayType.TypeDecl()
additionalTypes := arrayType.GetAdditionalTypeDefs()
// Check also types defined in array item
if len(additionalTypes) > 0 {
outSchema.AdditionalTypes = append(outSchema.AdditionalTypes, additionalTypes...)
}
outSchema.Properties = arrayType.Properties
case "integer":
// We default to int if format doesn't ask for something else.
if f == "int64" {
outSchema.GoType = "int64"
} else if f == "int32" {
outSchema.GoType = "int32"
} else if f == "int16" {
outSchema.GoType = "int16"
} else if f == "int8" {
outSchema.GoType = "int8"
} else if f == "int" {
outSchema.GoType = "int"
} else if f == "uint64" {
outSchema.GoType = "uint64"
} else if f == "uint32" {
outSchema.GoType = "uint32"
} else if f == "uint16" {
outSchema.GoType = "uint16"
} else if f == "uint8" {
outSchema.GoType = "uint8"
} else if f == "uint" {
outSchema.GoType = "uint"
} else if f == "" {
outSchema.GoType = "int"
} else {
return fmt.Errorf("invalid integer format: %s", f)
}
case "number":
// We default to float for "number"
if f == "double" {
outSchema.GoType = "float64"
} else if f == "float" || f == "" {
outSchema.GoType = "float32"
} else {
return fmt.Errorf("invalid number format: %s", f)
}
case "boolean":
if f != "" {
return fmt.Errorf("invalid format (%s) for boolean", f)
}
outSchema.GoType = "bool"
case "string":
// Special case string formats here.
switch f {
case "byte":
outSchema.GoType = "[]byte"
case "email":
outSchema.GoType = "openapi_types.Email"
case "date":
outSchema.GoType = "openapi_types.Date"
case "date-time":
outSchema.GoType = "time.Time"
case "json":
outSchema.GoType = "json.RawMessage"
outSchema.SkipOptionalPointer = true
default:
// All unrecognized formats are simply a regular string.
outSchema.GoType = "string"
}
default:
return fmt.Errorf("unhandled Schema type: %s", t)
}
return nil
}
// This describes a Schema, a type definition.
type SchemaDescriptor struct {
Fields []FieldDescriptor
@ -313,20 +425,49 @@ type FieldDescriptor struct {
// JSON annotations
func GenFieldsFromProperties(props []Property) []string {
var fields []string
for _, p := range props {
for i, p := range props {
field := ""
// Add a comment to a field in case we have one, otherwise skip.
if p.Description != "" {
// Separate the comment from a previous-defined, unrelated field.
// Make sure the actual field is separated by a newline.
field += fmt.Sprintf("\n%s\n", StringToGoComment(p.Description))
if i != 0 {
field += "\n"
}
field += fmt.Sprintf("%s\n", StringToGoComment(p.Description))
}
field += fmt.Sprintf(" %s %s", p.GoFieldName(), p.GoTypeDef())
if p.Required || p.Nullable {
field += fmt.Sprintf(" `json:\"%s\"`", p.JsonFieldName)
} else {
field += fmt.Sprintf(" `json:\"%s,omitempty\"`", p.JsonFieldName)
// Support x-omitempty
omitEmpty := true
if _, ok := p.ExtensionProps.Extensions[extPropOmitEmpty]; ok {
if extOmitEmpty, err := extParseOmitEmpty(p.ExtensionProps.Extensions[extPropOmitEmpty]); err == nil {
omitEmpty = extOmitEmpty
}
}
fieldTags := make(map[string]string)
if p.Required || p.Nullable || !omitEmpty {
fieldTags["json"] = p.JsonFieldName
} else {
fieldTags["json"] = p.JsonFieldName + ",omitempty"
}
if extension, ok := p.ExtensionProps.Extensions[extPropExtraTags]; ok {
if tags, err := extExtraTags(extension); err == nil {
keys := SortedStringKeys(tags)
for _, k := range keys {
fieldTags[k] = tags[k]
}
}
}
// Convert the fieldTags map into Go field annotations.
keys := SortedStringKeys(fieldTags)
tags := make([]string, len(keys))
for i, k := range keys {
tags[i] = fmt.Sprintf(`%s:"%s"`, k, fieldTags[k])
}
field += "`" + strings.Join(tags, " ") + "`"
fields = append(fields, field)
}
return fields
@ -359,7 +500,7 @@ func MergeSchemas(allOf []*openapi3.SchemaRef, path []string) (Schema, error) {
var refType string
var err error
if ref != "" {
if IsGoTypeReference(ref) {
refType, err = RefPathToGoType(ref)
if err != nil {
return Schema{}, errors.Wrap(err, "error converting reference path to a go type")
@ -412,7 +553,7 @@ func GenStructFromAllOf(allOf []*openapi3.SchemaRef, path []string) (string, err
objectParts := []string{"struct {"}
for _, schemaOrRef := range allOf {
ref := schemaOrRef.Ref
if ref != "" {
if IsGoTypeReference(ref) {
// We have a referenced type, we will generate an inlined struct
// member.
// struct {
@ -426,7 +567,7 @@ func GenStructFromAllOf(allOf []*openapi3.SchemaRef, path []string) (string, err
objectParts = append(objectParts,
fmt.Sprintf(" // Embedded struct due to allOf(%s)", ref))
objectParts = append(objectParts,
fmt.Sprintf(" %s", goType))
fmt.Sprintf(" %s `yaml:\",inline\"`", goType))
} else {
// Inline all the fields from the schema into the output struct,
// just like in the simple case of generating an object.
@ -471,7 +612,8 @@ func paramToGoType(param *openapi3.Parameter, path []string) (Schema, error) {
// so we'll return the parameter as a string, not bothering to decode it.
if len(param.Content) > 1 {
return Schema{
GoType: "string",
GoType: "string",
Description: StringToGoComment(param.Description),
}, nil
}
@ -480,7 +622,8 @@ func paramToGoType(param *openapi3.Parameter, path []string) (Schema, error) {
if !found {
// If we don't have json, it's a string
return Schema{
GoType: "string",
GoType: "string",
Description: StringToGoComment(param.Description),
}, nil
}

View file

@ -80,10 +80,6 @@ func genParamNames(params []ParameterDefinition) string {
return ", " + strings.Join(parts, ", ")
}
func genParamFmtString(path string) string {
return ReplacePathParamsWithStr(path)
}
// genResponsePayload generates the payload returned at the end of each client request function
func genResponsePayload(operationID string) string {
var buffer = bytes.NewBufferString("")
@ -99,8 +95,6 @@ func genResponsePayload(operationID string) string {
// genResponseUnmarshal generates unmarshaling steps for structured response payloads
func genResponseUnmarshal(op *OperationDefinition) string {
var buffer = bytes.NewBufferString("")
var handledCaseClauses = make(map[string]string)
var unhandledCaseClauses = make(map[string]string)
@ -110,7 +104,13 @@ func genResponseUnmarshal(op *OperationDefinition) string {
panic(err)
}
if len(typeDefinitions) == 0 {
// No types.
return ""
}
// Add a case for each possible response:
buffer := new(bytes.Buffer)
responses := op.Spec.Responses
for _, typeDefinition := range typeDefinitions {
@ -128,13 +128,8 @@ func genResponseUnmarshal(op *OperationDefinition) string {
// If there is no content-type then we have no unmarshaling to do:
if len(responseRef.Value.Content) == 0 {
caseAction := "break // No content-type"
if typeDefinition.ResponseName == "default" {
caseClauseKey := "default:"
unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
} else {
caseClauseKey := fmt.Sprintf("case rsp.StatusCode == %s:", typeDefinition.ResponseName)
unhandledCaseClauses[prefixLessSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
}
caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":"
unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
continue
}
@ -153,59 +148,65 @@ func genResponseUnmarshal(op *OperationDefinition) string {
// JSON:
case StringInArray(contentTypeName, contentTypesJSON):
var caseAction string
if typeDefinition.ContentTypeName == contentTypeName {
var caseAction string
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := json.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := json.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "json")
handledCaseClauses[caseKey] = caseClause
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "json")
handledCaseClauses[caseKey] = caseClause
}
// YAML:
case StringInArray(contentTypeName, contentTypesYAML):
var caseAction string
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := yaml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "yaml")
handledCaseClauses[caseKey] = caseClause
if typeDefinition.ContentTypeName == contentTypeName {
var caseAction string
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := yaml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "yaml")
handledCaseClauses[caseKey] = caseClause
}
// XML:
case StringInArray(contentTypeName, contentTypesXML):
var caseAction string
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := xml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "xml")
handledCaseClauses[caseKey] = caseClause
if typeDefinition.ContentTypeName == contentTypeName {
var caseAction string
caseAction = fmt.Sprintf("var dest %s\n"+
"if err := xml.Unmarshal(bodyBytes, &dest); err != nil { \n"+
" return nil, err \n"+
"}\n"+
"response.%s = &dest",
typeDefinition.Schema.TypeDecl(),
typeDefinition.TypeName)
caseKey, caseClause := buildUnmarshalCase(typeDefinition, caseAction, "xml")
handledCaseClauses[caseKey] = caseClause
}
// Everything else:
default:
caseAction := fmt.Sprintf("// Content-type (%s) unsupported", contentTypeName)
if typeDefinition.ResponseName == "default" {
caseClauseKey := "default:"
unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
} else {
caseClauseKey := fmt.Sprintf("case rsp.StatusCode == %s:", typeDefinition.ResponseName)
unhandledCaseClauses[prefixLessSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
}
caseClauseKey := "case " + getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName) + ":"
unhandledCaseClauses[prefixLeastSpecific+caseClauseKey] = fmt.Sprintf("%s\n%s\n", caseClauseKey, caseAction)
}
}
}
if len(handledCaseClauses)+len(unhandledCaseClauses) == 0 {
// switch would be empty.
return ""
}
// Now build the switch statement in order of most-to-least specific:
// See: https://github.com/deepmap/oapi-codegen/issues/127 for why we handle this in two separate
// groups.
@ -224,13 +225,10 @@ func genResponseUnmarshal(op *OperationDefinition) string {
}
// buildUnmarshalCase builds an unmarshalling case clause for different content-types:
func buildUnmarshalCase(typeDefinition TypeDefinition, caseAction string, contentType string) (caseKey string, caseClause string) {
func buildUnmarshalCase(typeDefinition ResponseTypeDefinition, caseAction string, contentType string) (caseKey string, caseClause string) {
caseKey = fmt.Sprintf("%s.%s.%s", prefixLeastSpecific, contentType, typeDefinition.ResponseName)
if typeDefinition.ResponseName == "default" {
caseClause = fmt.Sprintf("case strings.Contains(rsp.Header.Get(\"%s\"), \"%s\"):\n%s\n", echo.HeaderContentType, contentType, caseAction)
} else {
caseClause = fmt.Sprintf("case strings.Contains(rsp.Header.Get(\"%s\"), \"%s\") && rsp.StatusCode == %s:\n%s\n", echo.HeaderContentType, contentType, typeDefinition.ResponseName, caseAction)
}
caseClauseKey := getConditionOfResponseName("rsp.StatusCode", typeDefinition.ResponseName)
caseClause = fmt.Sprintf("case strings.Contains(rsp.Header.Get(\"%s\"), \"%s\") && %s:\n%s\n", echo.HeaderContentType, contentType, caseClauseKey, caseAction)
return caseKey, caseClause
}
@ -239,7 +237,7 @@ func genResponseTypeName(operationID string) string {
return fmt.Sprintf("%s%s", UppercaseFirstCharacter(operationID), responseTypeSuffix)
}
func getResponseTypeDefinitions(op *OperationDefinition) []TypeDefinition {
func getResponseTypeDefinitions(op *OperationDefinition) []ResponseTypeDefinition {
td, err := op.GetResponseTypeDefinitions()
if err != nil {
panic(err)
@ -247,6 +245,18 @@ func getResponseTypeDefinitions(op *OperationDefinition) []TypeDefinition {
return td
}
// Return the statusCode comparison clause from the response name.
func getConditionOfResponseName(statusCodeVar, responseName string) string {
switch responseName {
case "default":
return "true"
case "1XX", "2XX", "3XX", "4XX", "5XX":
return fmt.Sprintf("%s / 100 == %s", statusCodeVar, responseName[:1])
default:
return fmt.Sprintf("%s == %s", statusCodeVar, responseName)
}
}
// This outputs a string array
func toStringArray(sarr []string) string {
return `[]string{"` + strings.Join(sarr, `","`) + `"}`
@ -263,7 +273,7 @@ var TemplateFunctions = template.FuncMap{
"genParamArgs": genParamArgs,
"genParamTypes": genParamTypes,
"genParamNames": genParamNames,
"genParamFmtString": genParamFmtString,
"genParamFmtString": ReplacePathParamsWithStr,
"swaggerUriToEchoUri": SwaggerUriToEchoUri,
"swaggerUriToChiUri": SwaggerUriToChiUri,
"lcFirst": LowercaseFirstCharacter,
@ -277,4 +287,5 @@ var TemplateFunctions = template.FuncMap{
"lower": strings.ToLower,
"title": strings.Title,
"stripNewLines": stripNewLines,
"sanitizeGoIdentity": SanitizeGoIdentity,
}

View file

@ -1,17 +1,43 @@
// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
return HandlerFromMux(si, chi.NewRouter())
return HandlerWithOptions(si, ChiServerOptions{})
}
type ChiServerOptions struct {
BaseURL string
BaseRouter chi.Router
Middlewares []MiddlewareFunc
}
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
return HandlerWithOptions(si, ChiServerOptions {
BaseRouter: r,
})
}
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
return HandlerWithOptions(si, ChiServerOptions {
BaseURL: baseURL,
BaseRouter: r,
})
}
// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
r := options.BaseRouter
if r == nil {
r = chi.NewRouter()
}
{{if .}}wrapper := ServerInterfaceWrapper{
Handler: si,
}
Handler: si,
HandlerMiddlewares: options.Middlewares,
}
{{end}}
{{range .}}r.Group(func(r chi.Router) {
r.{{.Method | lower | title }}("{{.Path | swaggerUriToChiUri}}", wrapper.{{.OperationId}})
r.{{.Method | lower | title }}(options.BaseURL+"{{.Path | swaggerUriToChiUri}}", wrapper.{{.OperationId}})
})
{{end}}
return r
return r
}

View file

@ -1,8 +1,11 @@
// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
Handler ServerInterface
HandlerMiddlewares []MiddlewareFunc
}
type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
{{range .}}{{$opid := .OperationId}}
// {{$opid}} operation middleware
@ -36,7 +39,7 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{range .SecurityDefinitions}}
ctx = context.WithValue(ctx, "{{.ProviderName}}.Scopes", {{toStringArray .Scopes}})
ctx = context.WithValue(ctx, {{.ProviderName | ucFirst}}Scopes, {{toStringArray .Scopes}})
{{end}}
{{if .RequiresParamObject}}
@ -98,7 +101,7 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", valueList[0], &{{.GoName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}})
if err != nil {
http.Error(w, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err), http.StatusBadRequest)
return
@ -116,7 +119,9 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{range .CookieParams}}
if cookie, err := r.Cookie("{{.ParamName}}"); err == nil {
var cookie *http.Cookie
if cookie, err = r.Cookie("{{.ParamName}}"); err == nil {
{{- if .IsPassThrough}}
params.{{.GoName}} = {{if not .Required}}&{{end}}cookie.Value
@ -159,7 +164,16 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{- end}}
{{end}}
{{end}}
siw.Handler.{{.OperationId}}(w, r.WithContext(ctx){{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}})
var handler = func(w http.ResponseWriter, r *http.Request) {
siw.Handler.{{.OperationId}}(w, r{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}})
}
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler(w, r.WithContext(ctx))
}
{{end}}

View file

@ -31,10 +31,10 @@ type ClientWithResponsesInterface interface {
{{$hasParams := .RequiresParamObject -}}
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
// {{$opid}} request {{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*{{genResponseTypeName $opid}}, error)
// {{$opid}} request{{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error)
{{range .Bodies}}
{{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*{{genResponseTypeName $opid}}, error)
{{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error)
{{end}}{{/* range .Bodies */}}
{{end}}{{/* range . $opid := .OperationId */}}
}
@ -71,8 +71,8 @@ func (r {{$opid | ucFirst}}Response) StatusCode() int {
{{/* Generate client methods (with responses)*/}}
// {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse request{{if .HasBody}} with arbitrary body{{end}} returning *{{$opid}}Response
func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*{{genResponseTypeName $opid}}, error){
rsp, err := c.{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}}{{if .HasBody}}, contentType, body{{end}})
func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error){
rsp, err := c.{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}}{{if .HasBody}}, contentType, body{{end}}, reqEditors...)
if err != nil {
return nil, err
}
@ -83,8 +83,8 @@ func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithRespons
{{$pathParams := .PathParams -}}
{{$bodyRequired := .BodyRequired -}}
{{range .Bodies}}
func (c *ClientWithResponses) {{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*{{genResponseTypeName $opid}}, error) {
rsp, err := c.{{$opid}}{{.Suffix}}(ctx{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body)
func (c *ClientWithResponses) {{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error) {
rsp, err := c.{{$opid}}{{.Suffix}}(ctx{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body, reqEditors...)
if err != nil {
return nil, err
}

View file

@ -11,16 +11,18 @@ type HttpRequestDoer interface {
// Client which conforms to the OpenAPI3 specification for this service.
type Client struct {
// The endpoint of the server conforming to this interface, with scheme,
// https://api.deepmap.com for example.
// https://api.deepmap.com for example. This can contain a path relative
// to the server, such as https://api.deepmap.com/dev-test, and all the
// paths in the swagger spec will be appended to the server.
Server string
// Doer for performing requests, typically a *http.Client with any
// customized settings, such as certificate chains.
Client HttpRequestDoer
// A callback for modifying requests which are generated before sending over
// A list of callbacks for modifying requests which are generated before sending over
// the network.
RequestEditor RequestEditorFn
RequestEditors []RequestEditorFn
}
// ClientOption allows setting custom parameters during construction
@ -44,7 +46,7 @@ func NewClient(server string, opts ...ClientOption) (*Client, error) {
}
// create httpClient, if not already present
if client.Client == nil {
client.Client = http.DefaultClient
client.Client = &http.Client{}
}
return &client, nil
}
@ -62,7 +64,7 @@ func WithHTTPClient(doer HttpRequestDoer) ClientOption {
// called right before sending the request. This can be used to mutate the request.
func WithRequestEditorFn(fn RequestEditorFn) ClientOption {
return func(c *Client) error {
c.RequestEditor = fn
c.RequestEditors = append(c.RequestEditors, fn)
return nil
}
}
@ -73,10 +75,10 @@ type ClientInterface interface {
{{$hasParams := .RequiresParamObject -}}
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
// {{$opid}} request {{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*http.Response, error)
// {{$opid}} request{{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*http.Response, error)
{{range .Bodies}}
{{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*http.Response, error)
{{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*http.Response, error)
{{end}}{{/* range .Bodies */}}
{{end}}{{/* range . $opid := .OperationId */}}
}
@ -88,33 +90,27 @@ type ClientInterface interface {
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
func (c *Client) {{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*http.Response, error) {
func (c *Client) {{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*http.Response, error) {
req, err := New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(c.Server{{genParamNames .PathParams}}{{if $hasParams}}, params{{end}}{{if .HasBody}}, contentType, body{{end}})
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if c.RequestEditor != nil {
err = c.RequestEditor(ctx, req)
if err != nil {
return nil, err
}
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
{{range .Bodies}}
func (c *Client) {{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*http.Response, error) {
func (c *Client) {{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*http.Response, error) {
req, err := New{{$opid}}{{.Suffix}}Request(c.Server{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if c.RequestEditor != nil {
err = c.RequestEditor(ctx, req)
if err != nil {
return nil, err
}
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
@ -147,39 +143,40 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{range $paramIdx, $param := .PathParams}}
var pathParam{{$paramIdx}} string
{{if .IsPassThrough}}
pathParam{{$paramIdx}} = {{.ParamName}}
pathParam{{$paramIdx}} = {{.GoVariableName}}
{{end}}
{{if .IsJson}}
var pathParamBuf{{$paramIdx}} []byte
pathParamBuf{{$paramIdx}}, err = json.Marshal({{.ParamName}})
pathParamBuf{{$paramIdx}}, err = json.Marshal({{.GoVariableName}})
if err != nil {
return nil, err
}
pathParam{{$paramIdx}} = string(pathParamBuf{{$paramIdx}})
{{end}}
{{if .IsStyled}}
pathParam{{$paramIdx}}, err = runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{.GoVariableName}})
pathParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationPath, {{.GoVariableName}})
if err != nil {
return nil, err
}
{{end}}
{{end}}
queryUrl, err := url.Parse(server)
serverURL, err := url.Parse(server)
if err != nil {
return nil, err
}
basePath := fmt.Sprintf("{{genParamFmtString .Path}}"{{range $paramIdx, $param := .PathParams}}, pathParam{{$paramIdx}}{{end}})
if basePath[0] == '/' {
basePath = basePath[1:]
operationPath := fmt.Sprintf("{{genParamFmtString .Path}}"{{range $paramIdx, $param := .PathParams}}, pathParam{{$paramIdx}}{{end}})
if operationPath[0] == '/' {
operationPath = "." + operationPath
}
queryUrl, err = queryUrl.Parse(basePath)
queryURL, err := serverURL.Parse(operationPath)
if err != nil {
return nil, err
}
{{if .QueryParams}}
queryValues := queryUrl.Query()
queryValues := queryURL.Query()
{{range $paramIdx, $param := .QueryParams}}
{{if not .Required}} if params.{{.GoName}} != nil { {{end}}
{{if .IsPassThrough}}
@ -194,7 +191,7 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{end}}
{{if .IsStyled}}
if queryFrag, err := runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}}); err != nil {
if queryFrag, err := runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationQuery, {{if not .Required}}*{{end}}params.{{.GoName}}); err != nil {
return nil, err
} else if parsed, err := url.ParseQuery(queryFrag); err != nil {
return nil, err
@ -208,13 +205,14 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{end}}
{{if not .Required}}}{{end}}
{{end}}
queryUrl.RawQuery = queryValues.Encode()
queryURL.RawQuery = queryValues.Encode()
{{end}}{{/* if .QueryParams */}}
req, err := http.NewRequest("{{.Method}}", queryUrl.String(), {{if .HasBody}}body{{else}}nil{{end}})
req, err := http.NewRequest("{{.Method}}", queryURL.String(), {{if .HasBody}}body{{else}}nil{{end}})
if err != nil {
return nil, err
}
{{if .HasBody}}req.Header.Add("Content-Type", contentType){{end}}
{{range $paramIdx, $param := .HeaderParams}}
{{if not .Required}} if params.{{.GoName}} != nil { {{end}}
var headerParam{{$paramIdx}} string
@ -230,12 +228,12 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
headerParam{{$paramIdx}} = string(headerParamBuf{{$paramIdx}})
{{end}}
{{if .IsStyled}}
headerParam{{$paramIdx}}, err = runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}})
headerParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, {{if not .Required}}*{{end}}params.{{.GoName}})
if err != nil {
return nil, err
}
{{end}}
req.Header.Add("{{.ParamName}}", headerParam{{$paramIdx}})
req.Header.Set("{{.ParamName}}", headerParam{{$paramIdx}})
{{if not .Required}}}{{end}}
{{end}}
@ -254,7 +252,7 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
cookieParam{{$paramIdx}} = url.QueryEscape(string(cookieParamBuf{{$paramIdx}}))
{{end}}
{{if .IsStyled}}
cookieParam{{$paramIdx}}, err = runtime.StyleParam("simple", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}})
cookieParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("simple", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationCookie, {{if not .Required}}*{{end}}params.{{.GoName}})
if err != nil {
return nil, err
}
@ -266,8 +264,21 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
req.AddCookie(cookie{{$paramIdx}})
{{if not .Required}}}{{end}}
{{end}}
{{if .HasBody}}req.Header.Add("Content-Type", contentType){{end}}
return req, nil
}
{{end}}{{/* Range */}}
func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error {
for _, r := range c.RequestEditors {
if err := r(ctx, req); err != nil {
return err
}
}
for _, r := range additionalEditors {
if err := r(ctx, req); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,17 @@
{{- if gt (len .SecuritySchemeProviderNames) 0 }}
const (
{{range $ProviderName := .SecuritySchemeProviderNames}}
{{- $ProviderName | ucFirst}}Scopes = "{{$ProviderName}}.Scopes"
{{end}}
)
{{end}}
{{if gt (len .EnumDefinitions) 0 }}
{{range $Enum := .EnumDefinitions}}
// Defines values for {{$Enum.TypeName}}.
const (
{{range $index, $value := $Enum.Schema.EnumValues}}
{{$index}} {{$Enum.TypeName}} = {{$Enum.ValueWrapper}}{{$value}}{{$Enum.ValueWrapper}}
{{end}}
)
{{end}}
{{end}}

View file

@ -1,10 +1,32 @@
// Package {{.PackageName}} provides primitives to interact the openapi HTTP API.
// Package {{.PackageName}} provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen DO NOT EDIT.
// Code generated by {{.ModuleName}} version {{.Version}} DO NOT EDIT.
package {{.PackageName}}
{{if .Imports}}
import (
{{range .Imports}} {{ . }}
{{end}})
{{end}}
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"encoding/xml"
"fmt"
"gopkg.in/yaml.v2"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/deepmap/oapi-codegen/pkg/runtime"
openapi_types "github.com/deepmap/oapi-codegen/pkg/types"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-chi/chi/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
{{- range .ExternalImports}}
{{ . }}
{{- end}}
)

View file

@ -1,12 +1,12 @@
// Base64 encoded, gzipped, json marshaled Swagger object
var swaggerSpec = []string{
{{range .}}
{{range .SpecParts}}
"{{.}}",{{end}}
}
// GetSwagger returns the Swagger specification corresponding to the generated code
// in this file.
func GetSwagger() (*openapi3.Swagger, error) {
// GetSwagger returns the content of the embedded swagger specification file
// or error if failed to decode
func decodeSpec() ([]byte, error) {
zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, ""))
if err != nil {
return nil, fmt.Errorf("error base64 decoding spec: %s", err)
@ -21,9 +21,67 @@ func GetSwagger() (*openapi3.Swagger, error) {
return nil, fmt.Errorf("error decompressing spec: %s", err)
}
swagger, err := openapi3.NewSwaggerLoader().LoadSwaggerFromData(buf.Bytes())
if err != nil {
return nil, fmt.Errorf("error loading Swagger: %s", err)
}
return swagger, nil
return buf.Bytes(), nil
}
var rawSpec = decodeSpecCached()
// a naive cached of a decoded swagger spec
func decodeSpecCached() func() ([]byte, error) {
data, err := decodeSpec()
return func() ([]byte, error) {
return data, err
}
}
// Constructs a synthetic filesystem for resolving external references when loading openapi specifications.
func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) {
var res = make(map[string]func() ([]byte, error))
if len(pathToFile) > 0 {
res[pathToFile] = rawSpec
}
{{ if .ImportMapping }}
pathPrefix := path.Dir(pathToFile)
{{ end }}
{{ range $key, $value := .ImportMapping }}
for rawPath, rawFunc := range {{ $value.Name }}.PathToRawSpec(path.Join(pathPrefix, "{{ $key }}")) {
if _, ok := res[rawPath]; ok {
// it is not possible to compare functions in golang, so always overwrite the old value
}
res[rawPath] = rawFunc
}
{{- end }}
return res
}
// GetSwagger returns the Swagger specification corresponding to the generated code
// in this file. The external references of Swagger specification are resolved.
// The logic of resolving external references is tightly connected to "import-mapping" feature.
// Externally referenced files must be embedded in the corresponding golang packages.
// Urls can be supported but this task was out of the scope.
func GetSwagger() (swagger *openapi3.T, err error) {
var resolvePath = PathToRawSpec("")
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) {
var pathToFile = url.String()
pathToFile = path.Clean(pathToFile)
getSpec, ok := resolvePath[pathToFile]
if !ok {
err1 := fmt.Errorf("path not found: %s", pathToFile)
return nil, err1
}
return getSpec()
}
var specData []byte
specData, err = rawSpec()
if err != nil {
return
}
swagger, err = loader.LoadFromData(specData)
if err != nil {
return
}
return
}

View file

@ -1,6 +1,6 @@
{{range .}}{{$opid := .OperationId}}
{{range .TypeDefinitions}}
// {{.TypeName}} defines parameters for {{$opid}}.
type {{.TypeName}} {{.Schema.TypeDecl}}
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}
{{end}}

View file

@ -17,11 +17,17 @@ type EchoRouter interface {
// RegisterHandlers adds each server route to the EchoRouter.
func RegisterHandlers(router EchoRouter, si ServerInterface) {
RegisterHandlersWithBaseURL(router, si, "")
}
// Registers handlers, and prepends BaseURL to the paths, so that the paths
// can be served under a prefix.
func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) {
{{if .}}
wrapper := ServerInterfaceWrapper{
Handler: si,
}
{{end}}
{{range .}}router.{{.Method}}("{{.Path | swaggerUriToEchoUri}}", wrapper.{{.OperationId}})
{{range .}}router.{{.Method}}(baseURL + "{{.Path | swaggerUriToEchoUri}}", wrapper.{{.OperationId}})
{{end}}
}

View file

@ -1,6 +1,8 @@
{{range .}}{{$opid := .OperationId}}
{{range .Bodies}}
// {{$opid}}RequestBody defines body for {{$opid}} for application/json ContentType.
type {{$opid}}{{.NameTag}}RequestBody {{.TypeDef}}
{{with .TypeDef $opid}}
// {{.TypeName}} defines body for {{$opid}} for application/json ContentType.
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}
{{end}}
{{end}}

View file

@ -75,20 +75,46 @@ func (a {{.TypeName}}) MarshalJSON() ([]byte, error) {
`,
"chi-handler.tmpl": `// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
return HandlerFromMux(si, chi.NewRouter())
return HandlerWithOptions(si, ChiServerOptions{})
}
type ChiServerOptions struct {
BaseURL string
BaseRouter chi.Router
Middlewares []MiddlewareFunc
}
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
return HandlerWithOptions(si, ChiServerOptions {
BaseRouter: r,
})
}
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
return HandlerWithOptions(si, ChiServerOptions {
BaseURL: baseURL,
BaseRouter: r,
})
}
// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
r := options.BaseRouter
if r == nil {
r = chi.NewRouter()
}
{{if .}}wrapper := ServerInterfaceWrapper{
Handler: si,
}
Handler: si,
HandlerMiddlewares: options.Middlewares,
}
{{end}}
{{range .}}r.Group(func(r chi.Router) {
r.{{.Method | lower | title }}("{{.Path | swaggerUriToChiUri}}", wrapper.{{.OperationId}})
r.{{.Method | lower | title }}(options.BaseURL+"{{.Path | swaggerUriToChiUri}}", wrapper.{{.OperationId}})
})
{{end}}
return r
return r
}
`,
"chi-interface.tmpl": `// ServerInterface represents all server handlers.
@ -102,8 +128,11 @@ type ServerInterface interface {
"chi-middleware.tmpl": `// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
Handler ServerInterface
HandlerMiddlewares []MiddlewareFunc
}
type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
{{range .}}{{$opid := .OperationId}}
// {{$opid}} operation middleware
@ -137,7 +166,7 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{range .SecurityDefinitions}}
ctx = context.WithValue(ctx, "{{.ProviderName}}.Scopes", {{toStringArray .Scopes}})
ctx = context.WithValue(ctx, {{.ProviderName | ucFirst}}Scopes, {{toStringArray .Scopes}})
{{end}}
{{if .RequiresParamObject}}
@ -199,7 +228,7 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", valueList[0], &{{.GoName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}})
if err != nil {
http.Error(w, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err), http.StatusBadRequest)
return
@ -217,7 +246,9 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{end}}
{{range .CookieParams}}
if cookie, err := r.Cookie("{{.ParamName}}"); err == nil {
var cookie *http.Cookie
if cookie, err = r.Cookie("{{.ParamName}}"); err == nil {
{{- if .IsPassThrough}}
params.{{.GoName}} = {{if not .Required}}&{{end}}cookie.Value
@ -260,7 +291,16 @@ func (siw *ServerInterfaceWrapper) {{$opid}}(w http.ResponseWriter, r *http.Requ
{{- end}}
{{end}}
{{end}}
siw.Handler.{{.OperationId}}(w, r.WithContext(ctx){{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}})
var handler = func(w http.ResponseWriter, r *http.Request) {
siw.Handler.{{.OperationId}}(w, r{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}})
}
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler(w, r.WithContext(ctx))
}
{{end}}
@ -300,10 +340,10 @@ type ClientWithResponsesInterface interface {
{{$hasParams := .RequiresParamObject -}}
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
// {{$opid}} request {{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*{{genResponseTypeName $opid}}, error)
// {{$opid}} request{{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error)
{{range .Bodies}}
{{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*{{genResponseTypeName $opid}}, error)
{{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error)
{{end}}{{/* range .Bodies */}}
{{end}}{{/* range . $opid := .OperationId */}}
}
@ -340,8 +380,8 @@ func (r {{$opid | ucFirst}}Response) StatusCode() int {
{{/* Generate client methods (with responses)*/}}
// {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse request{{if .HasBody}} with arbitrary body{{end}} returning *{{$opid}}Response
func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*{{genResponseTypeName $opid}}, error){
rsp, err := c.{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}}{{if .HasBody}}, contentType, body{{end}})
func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithResponse(ctx context.Context{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error){
rsp, err := c.{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx{{genParamNames .PathParams}}{{if .RequiresParamObject}}, params{{end}}{{if .HasBody}}, contentType, body{{end}}, reqEditors...)
if err != nil {
return nil, err
}
@ -352,8 +392,8 @@ func (c *ClientWithResponses) {{$opid}}{{if .HasBody}}WithBody{{end}}WithRespons
{{$pathParams := .PathParams -}}
{{$bodyRequired := .BodyRequired -}}
{{range .Bodies}}
func (c *ClientWithResponses) {{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*{{genResponseTypeName $opid}}, error) {
rsp, err := c.{{$opid}}{{.Suffix}}(ctx{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body)
func (c *ClientWithResponses) {{$opid}}{{.Suffix}}WithResponse(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*{{genResponseTypeName $opid}}, error) {
rsp, err := c.{{$opid}}{{.Suffix}}(ctx{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body, reqEditors...)
if err != nil {
return nil, err
}
@ -396,16 +436,18 @@ type HttpRequestDoer interface {
// Client which conforms to the OpenAPI3 specification for this service.
type Client struct {
// The endpoint of the server conforming to this interface, with scheme,
// https://api.deepmap.com for example.
// https://api.deepmap.com for example. This can contain a path relative
// to the server, such as https://api.deepmap.com/dev-test, and all the
// paths in the swagger spec will be appended to the server.
Server string
// Doer for performing requests, typically a *http.Client with any
// customized settings, such as certificate chains.
Client HttpRequestDoer
// A callback for modifying requests which are generated before sending over
// A list of callbacks for modifying requests which are generated before sending over
// the network.
RequestEditor RequestEditorFn
RequestEditors []RequestEditorFn
}
// ClientOption allows setting custom parameters during construction
@ -429,7 +471,7 @@ func NewClient(server string, opts ...ClientOption) (*Client, error) {
}
// create httpClient, if not already present
if client.Client == nil {
client.Client = http.DefaultClient
client.Client = &http.Client{}
}
return &client, nil
}
@ -447,7 +489,7 @@ func WithHTTPClient(doer HttpRequestDoer) ClientOption {
// called right before sending the request. This can be used to mutate the request.
func WithRequestEditorFn(fn RequestEditorFn) ClientOption {
return func(c *Client) error {
c.RequestEditor = fn
c.RequestEditors = append(c.RequestEditors, fn)
return nil
}
}
@ -458,10 +500,10 @@ type ClientInterface interface {
{{$hasParams := .RequiresParamObject -}}
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
// {{$opid}} request {{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*http.Response, error)
// {{$opid}} request{{if .HasBody}} with any body{{end}}
{{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*http.Response, error)
{{range .Bodies}}
{{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*http.Response, error)
{{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*http.Response, error)
{{end}}{{/* range .Bodies */}}
{{end}}{{/* range . $opid := .OperationId */}}
}
@ -473,33 +515,27 @@ type ClientInterface interface {
{{$pathParams := .PathParams -}}
{{$opid := .OperationId -}}
func (c *Client) {{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}) (*http.Response, error) {
func (c *Client) {{$opid}}{{if .HasBody}}WithBody{{end}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}{{if .HasBody}}, contentType string, body io.Reader{{end}}, reqEditors... RequestEditorFn) (*http.Response, error) {
req, err := New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(c.Server{{genParamNames .PathParams}}{{if $hasParams}}, params{{end}}{{if .HasBody}}, contentType, body{{end}})
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if c.RequestEditor != nil {
err = c.RequestEditor(ctx, req)
if err != nil {
return nil, err
}
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
{{range .Bodies}}
func (c *Client) {{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody) (*http.Response, error) {
func (c *Client) {{$opid}}{{.Suffix}}(ctx context.Context{{genParamArgs $pathParams}}{{if $hasParams}}, params *{{$opid}}Params{{end}}, body {{$opid}}{{.NameTag}}RequestBody, reqEditors... RequestEditorFn) (*http.Response, error) {
req, err := New{{$opid}}{{.Suffix}}Request(c.Server{{genParamNames $pathParams}}{{if $hasParams}}, params{{end}}, body)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if c.RequestEditor != nil {
err = c.RequestEditor(ctx, req)
if err != nil {
return nil, err
}
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
@ -532,39 +568,40 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{range $paramIdx, $param := .PathParams}}
var pathParam{{$paramIdx}} string
{{if .IsPassThrough}}
pathParam{{$paramIdx}} = {{.ParamName}}
pathParam{{$paramIdx}} = {{.GoVariableName}}
{{end}}
{{if .IsJson}}
var pathParamBuf{{$paramIdx}} []byte
pathParamBuf{{$paramIdx}}, err = json.Marshal({{.ParamName}})
pathParamBuf{{$paramIdx}}, err = json.Marshal({{.GoVariableName}})
if err != nil {
return nil, err
}
pathParam{{$paramIdx}} = string(pathParamBuf{{$paramIdx}})
{{end}}
{{if .IsStyled}}
pathParam{{$paramIdx}}, err = runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{.GoVariableName}})
pathParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationPath, {{.GoVariableName}})
if err != nil {
return nil, err
}
{{end}}
{{end}}
queryUrl, err := url.Parse(server)
serverURL, err := url.Parse(server)
if err != nil {
return nil, err
}
basePath := fmt.Sprintf("{{genParamFmtString .Path}}"{{range $paramIdx, $param := .PathParams}}, pathParam{{$paramIdx}}{{end}})
if basePath[0] == '/' {
basePath = basePath[1:]
operationPath := fmt.Sprintf("{{genParamFmtString .Path}}"{{range $paramIdx, $param := .PathParams}}, pathParam{{$paramIdx}}{{end}})
if operationPath[0] == '/' {
operationPath = "." + operationPath
}
queryUrl, err = queryUrl.Parse(basePath)
queryURL, err := serverURL.Parse(operationPath)
if err != nil {
return nil, err
}
{{if .QueryParams}}
queryValues := queryUrl.Query()
queryValues := queryURL.Query()
{{range $paramIdx, $param := .QueryParams}}
{{if not .Required}} if params.{{.GoName}} != nil { {{end}}
{{if .IsPassThrough}}
@ -579,7 +616,7 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{end}}
{{if .IsStyled}}
if queryFrag, err := runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}}); err != nil {
if queryFrag, err := runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationQuery, {{if not .Required}}*{{end}}params.{{.GoName}}); err != nil {
return nil, err
} else if parsed, err := url.ParseQuery(queryFrag); err != nil {
return nil, err
@ -593,13 +630,14 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
{{end}}
{{if not .Required}}}{{end}}
{{end}}
queryUrl.RawQuery = queryValues.Encode()
queryURL.RawQuery = queryValues.Encode()
{{end}}{{/* if .QueryParams */}}
req, err := http.NewRequest("{{.Method}}", queryUrl.String(), {{if .HasBody}}body{{else}}nil{{end}})
req, err := http.NewRequest("{{.Method}}", queryURL.String(), {{if .HasBody}}body{{else}}nil{{end}})
if err != nil {
return nil, err
}
{{if .HasBody}}req.Header.Add("Content-Type", contentType){{end}}
{{range $paramIdx, $param := .HeaderParams}}
{{if not .Required}} if params.{{.GoName}} != nil { {{end}}
var headerParam{{$paramIdx}} string
@ -615,12 +653,12 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
headerParam{{$paramIdx}} = string(headerParamBuf{{$paramIdx}})
{{end}}
{{if .IsStyled}}
headerParam{{$paramIdx}}, err = runtime.StyleParam("{{.Style}}", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}})
headerParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("{{.Style}}", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, {{if not .Required}}*{{end}}params.{{.GoName}})
if err != nil {
return nil, err
}
{{end}}
req.Header.Add("{{.ParamName}}", headerParam{{$paramIdx}})
req.Header.Set("{{.ParamName}}", headerParam{{$paramIdx}})
{{if not .Required}}}{{end}}
{{end}}
@ -639,7 +677,7 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
cookieParam{{$paramIdx}} = url.QueryEscape(string(cookieParamBuf{{$paramIdx}}))
{{end}}
{{if .IsStyled}}
cookieParam{{$paramIdx}}, err = runtime.StyleParam("simple", {{.Explode}}, "{{.ParamName}}", {{if not .Required}}*{{end}}params.{{.GoName}})
cookieParam{{$paramIdx}}, err = runtime.StyleParamWithLocation("simple", {{.Explode}}, "{{.ParamName}}", runtime.ParamLocationCookie, {{if not .Required}}*{{end}}params.{{.GoName}})
if err != nil {
return nil, err
}
@ -651,32 +689,85 @@ func New{{$opid}}Request{{if .HasBody}}WithBody{{end}}(server string{{genParamAr
req.AddCookie(cookie{{$paramIdx}})
{{if not .Required}}}{{end}}
{{end}}
{{if .HasBody}}req.Header.Add("Content-Type", contentType){{end}}
return req, nil
}
{{end}}{{/* Range */}}
func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error {
for _, r := range c.RequestEditors {
if err := r(ctx, req); err != nil {
return err
}
}
for _, r := range additionalEditors {
if err := r(ctx, req); err != nil {
return err
}
}
return nil
}
`,
"imports.tmpl": `// Package {{.PackageName}} provides primitives to interact the openapi HTTP API.
"constants.tmpl": `{{- if gt (len .SecuritySchemeProviderNames) 0 }}
const (
{{range $ProviderName := .SecuritySchemeProviderNames}}
{{- $ProviderName | ucFirst}}Scopes = "{{$ProviderName}}.Scopes"
{{end}}
)
{{end}}
{{if gt (len .EnumDefinitions) 0 }}
{{range $Enum := .EnumDefinitions}}
// Defines values for {{$Enum.TypeName}}.
const (
{{range $index, $value := $Enum.Schema.EnumValues}}
{{$index}} {{$Enum.TypeName}} = {{$Enum.ValueWrapper}}{{$value}}{{$Enum.ValueWrapper}}
{{end}}
)
{{end}}
{{end}}
`,
"imports.tmpl": `// Package {{.PackageName}} provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen DO NOT EDIT.
// Code generated by {{.ModuleName}} version {{.Version}} DO NOT EDIT.
package {{.PackageName}}
{{if .Imports}}
import (
{{range .Imports}} {{ . }}
{{end}})
{{end}}
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"encoding/xml"
"fmt"
"gopkg.in/yaml.v2"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/deepmap/oapi-codegen/pkg/runtime"
openapi_types "github.com/deepmap/oapi-codegen/pkg/types"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-chi/chi/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
{{- range .ExternalImports}}
{{ . }}
{{- end}}
)
`,
"inline.tmpl": `// Base64 encoded, gzipped, json marshaled Swagger object
var swaggerSpec = []string{
{{range .}}
{{range .SpecParts}}
"{{.}}",{{end}}
}
// GetSwagger returns the Swagger specification corresponding to the generated code
// in this file.
func GetSwagger() (*openapi3.Swagger, error) {
// GetSwagger returns the content of the embedded swagger specification file
// or error if failed to decode
func decodeSpec() ([]byte, error) {
zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, ""))
if err != nil {
return nil, fmt.Errorf("error base64 decoding spec: %s", err)
@ -691,17 +782,75 @@ func GetSwagger() (*openapi3.Swagger, error) {
return nil, fmt.Errorf("error decompressing spec: %s", err)
}
swagger, err := openapi3.NewSwaggerLoader().LoadSwaggerFromData(buf.Bytes())
if err != nil {
return nil, fmt.Errorf("error loading Swagger: %s", err)
return buf.Bytes(), nil
}
var rawSpec = decodeSpecCached()
// a naive cached of a decoded swagger spec
func decodeSpecCached() func() ([]byte, error) {
data, err := decodeSpec()
return func() ([]byte, error) {
return data, err
}
}
// Constructs a synthetic filesystem for resolving external references when loading openapi specifications.
func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) {
var res = make(map[string]func() ([]byte, error))
if len(pathToFile) > 0 {
res[pathToFile] = rawSpec
}
return swagger, nil
{{ if .ImportMapping }}
pathPrefix := path.Dir(pathToFile)
{{ end }}
{{ range $key, $value := .ImportMapping }}
for rawPath, rawFunc := range {{ $value.Name }}.PathToRawSpec(path.Join(pathPrefix, "{{ $key }}")) {
if _, ok := res[rawPath]; ok {
// it is not possible to compare functions in golang, so always overwrite the old value
}
res[rawPath] = rawFunc
}
{{- end }}
return res
}
// GetSwagger returns the Swagger specification corresponding to the generated code
// in this file. The external references of Swagger specification are resolved.
// The logic of resolving external references is tightly connected to "import-mapping" feature.
// Externally referenced files must be embedded in the corresponding golang packages.
// Urls can be supported but this task was out of the scope.
func GetSwagger() (swagger *openapi3.T, err error) {
var resolvePath = PathToRawSpec("")
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) {
var pathToFile = url.String()
pathToFile = path.Clean(pathToFile)
getSpec, ok := resolvePath[pathToFile]
if !ok {
err1 := fmt.Errorf("path not found: %s", pathToFile)
return nil, err1
}
return getSpec()
}
var specData []byte
specData, err = rawSpec()
if err != nil {
return
}
swagger, err = loader.LoadFromData(specData)
if err != nil {
return
}
return
}
`,
"param-types.tmpl": `{{range .}}{{$opid := .OperationId}}
{{range .TypeDefinitions}}
// {{.TypeName}} defines parameters for {{$opid}}.
type {{.TypeName}} {{.Schema.TypeDecl}}
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}
{{end}}
`,
@ -724,19 +873,27 @@ type EchoRouter interface {
// RegisterHandlers adds each server route to the EchoRouter.
func RegisterHandlers(router EchoRouter, si ServerInterface) {
RegisterHandlersWithBaseURL(router, si, "")
}
// Registers handlers, and prepends BaseURL to the paths, so that the paths
// can be served under a prefix.
func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) {
{{if .}}
wrapper := ServerInterfaceWrapper{
Handler: si,
}
{{end}}
{{range .}}router.{{.Method}}("{{.Path | swaggerUriToEchoUri}}", wrapper.{{.OperationId}})
{{range .}}router.{{.Method}}(baseURL + "{{.Path | swaggerUriToEchoUri}}", wrapper.{{.OperationId}})
{{end}}
}
`,
"request-bodies.tmpl": `{{range .}}{{$opid := .OperationId}}
{{range .Bodies}}
// {{$opid}}RequestBody defines body for {{$opid}} for application/json ContentType.
type {{$opid}}{{.NameTag}}RequestBody {{.TypeDef}}
{{with .TypeDef $opid}}
// {{.TypeName}} defines body for {{$opid}} for application/json ContentType.
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}
{{end}}
{{end}}
`,
@ -749,17 +906,8 @@ type ServerInterface interface {
}
`,
"typedef.tmpl": `{{range .Types}}
// {{.TypeName}} defines model for {{.JsonName}}.
type {{.TypeName}} {{.Schema.TypeDecl}}
{{- if gt (len .Schema.EnumValues) 0 }}
// List of {{ .TypeName }}
const (
{{- $typeName := .TypeName }}
{{- range $key, $value := .Schema.EnumValues }}
{{ $typeName }}_{{ $key }} {{ $typeName }} = "{{ $value }}"
{{- end }}
)
{{- end }}
{{ with .Schema.Description }}{{ . }}{{ else }}// {{.TypeName}} defines model for {{.JsonName}}.{{ end }}
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}
`,
"wrappers.tmpl": `// ServerInterfaceWrapper converts echo contexts to parameters.
@ -782,7 +930,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
}
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", ctx.Param("{{.ParamName}}"), &{{$varName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationPath, ctx.Param("{{.ParamName}}"), &{{$varName}})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}
@ -790,7 +938,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
{{end}}
{{range .SecurityDefinitions}}
ctx.Set("{{.ProviderName}}.Scopes", {{toStringArray .Scopes}})
ctx.Set({{.ProviderName | sanitizeGoIdentity | ucFirst}}Scopes, {{toStringArray .Scopes}})
{{end}}
{{if .RequiresParamObject}}
@ -840,7 +988,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
}
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", valueList[0], &{{.GoName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}
@ -872,7 +1020,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
{{end}}
{{if .IsStyled}}
var value {{.TypeDef}}
err = runtime.BindStyledParameter("simple",{{.Explode}}, "{{.ParamName}}", cookie.Value, &value)
err = runtime.BindStyledParameterWithLocation("simple",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationCookie, cookie.Value, &value)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}

View file

@ -1,13 +1,4 @@
{{range .Types}}
// {{.TypeName}} defines model for {{.JsonName}}.
type {{.TypeName}} {{.Schema.TypeDecl}}
{{- if gt (len .Schema.EnumValues) 0 }}
// List of {{ .TypeName }}
const (
{{- $typeName := .TypeName }}
{{- range $key, $value := .Schema.EnumValues }}
{{ $typeName }}_{{ $key }} {{ $typeName }} = "{{ $value }}"
{{- end }}
)
{{- end }}
{{ with .Schema.Description }}{{ . }}{{ else }}// {{.TypeName}} defines model for {{.JsonName}}.{{ end }}
type {{.TypeName}} {{if and (opts.AliasTypes) (.CanAlias)}}={{end}} {{.Schema.TypeDecl}}
{{end}}

View file

@ -18,7 +18,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
}
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", ctx.Param("{{.ParamName}}"), &{{$varName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationPath, ctx.Param("{{.ParamName}}"), &{{$varName}})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}
@ -26,7 +26,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
{{end}}
{{range .SecurityDefinitions}}
ctx.Set("{{.ProviderName}}.Scopes", {{toStringArray .Scopes}})
ctx.Set({{.ProviderName | sanitizeGoIdentity | ucFirst}}Scopes, {{toStringArray .Scopes}})
{{end}}
{{if .RequiresParamObject}}
@ -76,7 +76,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
}
{{end}}
{{if .IsStyled}}
err = runtime.BindStyledParameter("{{.Style}}",{{.Explode}}, "{{.ParamName}}", valueList[0], &{{.GoName}})
err = runtime.BindStyledParameterWithLocation("{{.Style}}",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationHeader, valueList[0], &{{.GoName}})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}
@ -108,7 +108,7 @@ func (w *ServerInterfaceWrapper) {{.OperationId}} (ctx echo.Context) error {
{{end}}
{{if .IsStyled}}
var value {{.TypeDef}}
err = runtime.BindStyledParameter("simple",{{.Explode}}, "{{.ParamName}}", cookie.Value, &value)
err = runtime.BindStyledParameterWithLocation("simple",{{.Explode}}, "{{.ParamName}}", runtime.ParamLocationCookie, cookie.Value, &value)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter {{.ParamName}}: %s", err))
}

View file

@ -15,6 +15,7 @@ package codegen
import (
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
@ -182,6 +183,17 @@ func SortedRequestBodyKeys(dict map[string]*openapi3.RequestBodyRef) []string {
return keys
}
func SortedSecurityRequirementKeys(sr openapi3.SecurityRequirement) []string {
keys := make([]string, len(sr))
i := 0
for key := range sr {
keys[i] = key
i++
}
sort.Strings(keys)
return keys
}
// This function checks whether the specified string is present in an array
// of strings
func StringInArray(str string, array []string) bool {
@ -198,15 +210,25 @@ func StringInArray(str string, array []string) bool {
// #/components/parameters/Bar -> Bar
// #/components/responses/Baz -> Baz
// Remote components (document.json#/Foo) are supported if they present in --import-mapping
// URL components (http://deepmap.com/schemas/document.json#Foo) are supported if they present in --import-mapping
//
// URL components (http://deepmap.com/schemas/document.json#/Foo) are supported if they present in --import-mapping
// Remote and URL also support standard local paths even though the spec doesn't mention them.
func RefPathToGoType(refPath string) (string, error) {
return refPathToGoType(refPath, true)
}
// refPathToGoType returns the Go typename for refPath given its
func refPathToGoType(refPath string, local bool) (string, error) {
if refPath[0] == '#' {
pathParts := strings.Split(refPath, "/")
if depth := len(pathParts); depth != 4 {
return "", fmt.Errorf("Parameter nesting is deeper than supported: %s has %d", refPath, depth)
depth := len(pathParts)
if local {
if depth != 4 {
return "", fmt.Errorf("unexpected reference depth: %d for ref: %s local: %t", depth, refPath, local)
}
} else if depth != 4 && depth != 2 {
return "", fmt.Errorf("unexpected reference depth: %d for ref: %s local: %t", depth, refPath, local)
}
return SchemaNameToTypeName(pathParts[3]), nil
return SchemaNameToTypeName(pathParts[len(pathParts)-1]), nil
}
pathParts := strings.Split(refPath, "#")
if len(pathParts) != 2 {
@ -216,14 +238,35 @@ func RefPathToGoType(refPath string) (string, error) {
if goImport, ok := importMapping[remoteComponent]; !ok {
return "", fmt.Errorf("unrecognized external reference '%s'; please provide the known import for this reference using option --import-mapping", remoteComponent)
} else {
goType, err := RefPathToGoType("#" + flatComponent)
goType, err := refPathToGoType("#"+flatComponent, false)
if err != nil {
return "", err
}
return fmt.Sprintf("%s.%s", goImport.alias, goType), nil
return fmt.Sprintf("%s.%s", goImport.Name, goType), nil
}
}
// This function takes a $ref value and checks if it has link to go type.
// #/components/schemas/Foo -> true
// ./local/file.yml#/components/parameters/Bar -> true
// ./local/file.yml -> false
// The function can be used to check whether RefPathToGoType($ref) is possible.
//
func IsGoTypeReference(ref string) bool {
return ref != "" && !IsWholeDocumentReference(ref)
}
// This function takes a $ref value and checks if it is whole document reference.
// #/components/schemas/Foo -> false
// ./local/file.yml#/components/parameters/Bar -> false
// ./local/file.yml -> true
// http://deepmap.com/schemas/document.json -> true
// http://deepmap.com/schemas/document.json#/Foo -> false
//
func IsWholeDocumentReference(ref string) bool {
return ref != "" && !strings.ContainsAny(ref, "#")
}
// This function converts a swagger style path URI with parameters to a
// Echo compatible path URI. We need to replace all of Swagger parameters with
// ":param". Valid input parameters are:
@ -265,7 +308,7 @@ func OrderedParamsFromUri(uri string) []string {
return result
}
// Replaces path parameters with %s
// Replaces path parameters of the form {param} with %s
func ReplacePathParamsWithStr(uri string) string {
return pathParamRE.ReplaceAllString(uri, "%s")
}
@ -456,7 +499,6 @@ func SanitizeEnumNames(enumNames []string) map[string]string {
if _, dup := dupCheck[n]; !dup {
deDup = append(deDup, n)
}
dupCheck[n] = 0
}
@ -464,14 +506,14 @@ func SanitizeEnumNames(enumNames []string) map[string]string {
sanitizedDeDup := make(map[string]string, len(deDup))
for _, n := range deDup {
sanitized := SanitizeGoIdentity(n)
sanitized := SanitizeGoIdentity(SchemaNameToTypeName(n))
if _, dup := dupCheck[sanitized]; !dup {
sanitizedDeDup[sanitized] = n
dupCheck[sanitized]++
} else {
sanitizedDeDup[sanitized+strconv.Itoa(dupCheck[sanitized])] = n
}
dupCheck[sanitized]++
}
return sanitizedDeDup
@ -480,10 +522,14 @@ func SanitizeEnumNames(enumNames []string) map[string]string {
// Converts a Schema name to a valid Go type name. It converts to camel case, and makes sure the name is
// valid in Go
func SchemaNameToTypeName(name string) string {
name = ToCamelCase(name)
// Prepend "N" to schemas starting with a number
if name != "" && unicode.IsDigit([]rune(name)[0]) {
name = "N" + name
if name == "$" {
name = "DollarSign"
} else {
name = ToCamelCase(name)
// Prepend "N" to schemas starting with a number
if name != "" && unicode.IsDigit([]rune(name)[0]) {
name = "N" + name
}
}
return name
}
@ -495,9 +541,10 @@ func SchemaNameToTypeName(name string) string {
// you must specify an additionalProperties type
// If additionalProperties it true/false, this field will be non-nil.
func SchemaHasAdditionalProperties(schema *openapi3.Schema) bool {
if schema.AdditionalPropertiesAllowed != nil {
return *schema.AdditionalPropertiesAllowed
if schema.AdditionalPropertiesAllowed != nil && *schema.AdditionalPropertiesAllowed {
return true
}
if schema.AdditionalProperties != nil {
return true
}
@ -516,6 +563,10 @@ func PathToTypeName(path []string) string {
// StringToGoComment renders a possible multi-line string as a valid Go-Comment.
// Each line is prefixed as a comment.
func StringToGoComment(in string) string {
if len(in) == 0 || len(strings.TrimSpace(in)) == 0 { // ignore empty comment
return ""
}
// Normalize newlines from Windows/Mac to Linux
in = strings.Replace(in, "\r\n", "\n", -1)
in = strings.Replace(in, "\r", "\n", -1)
@ -532,3 +583,17 @@ func StringToGoComment(in string) string {
in = strings.TrimSuffix(in, "\n// ")
return in
}
// This function breaks apart a path, and looks at each element. If it's
// not a path parameter, eg, {param}, it will URL-escape the element.
func EscapePathElements(path string) string {
elems := strings.Split(path, "/")
for i, e := range elems {
if strings.HasPrefix(e, "{") && strings.HasSuffix(e, "}") {
// This is a path parameter, we don't want to mess with its value
continue
}
elems[i] = url.QueryEscape(e)
}
return strings.Join(elems, "/")
}

View file

@ -0,0 +1,24 @@
// Copyright 2021 DeepMap, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package runtime
// Binder is the interface implemented by types that can be bound to a query string or a parameter string
// The input can be assumed to be a valid string. If you define a Bind method you are responsible for all
// data being completely bound to the type.
//
// By convention, to approximate the behavior of Bind functions themselves,
// Binder implements Bind("") as a no-op.
type Binder interface {
Bind(src string) error
}

View file

@ -14,6 +14,7 @@
package runtime
import (
"encoding"
"encoding/json"
"fmt"
"net/url"
@ -29,13 +30,51 @@ import (
// This function binds a parameter as described in the Path Parameters
// section here to a Go object:
// https://swagger.io/docs/specification/serialization/
// It is a backward compatible function to clients generated with codegen
// up to version v1.5.5. v1.5.6+ calls the function below.
func BindStyledParameter(style string, explode bool, paramName string,
value string, dest interface{}) error {
return BindStyledParameterWithLocation(style, explode, paramName, ParamLocationUndefined, value, dest)
}
// This function binds a parameter as described in the Path Parameters
// section here to a Go object:
// https://swagger.io/docs/specification/serialization/
func BindStyledParameterWithLocation(style string, explode bool, paramName string,
paramLocation ParamLocation, value string, dest interface{}) error {
if value == "" {
return fmt.Errorf("parameter '%s' is empty, can't bind its value", paramName)
}
// Based on the location of the parameter, we need to unescape it properly.
var err error
switch paramLocation {
case ParamLocationQuery, ParamLocationUndefined:
// We unescape undefined parameter locations here for older generated code,
// since prior to this refactoring, they always query unescaped.
value, err = url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("error unescaping query parameter '%s': %v", paramName, err)
}
case ParamLocationPath:
value, err = url.PathUnescape(value)
if err != nil {
return fmt.Errorf("error unescaping path parameter '%s': %v", paramName, err)
}
default:
// Headers and cookies aren't escaped.
}
// If the destination implements encoding.TextUnmarshaler we use it for binding
if tu, ok := dest.(encoding.TextUnmarshaler); ok {
if err := tu.UnmarshalText([]byte(value)); err != nil {
return fmt.Errorf("error unmarshaling '%s' text as %T: %s", value, dest, err)
}
return nil
}
// Everything comes in by pointer, dereference it
v := reflect.Indirect(reflect.ValueOf(dest))
@ -395,22 +434,15 @@ func BindQueryParameter(style string, explode bool, required bool, paramName str
// We don't try to be smart here, if the field exists as a query argument,
// set its value.
func bindParamsToExplodedObject(paramName string, values url.Values, dest interface{}) error {
// special handling for custom types
switch dest.(type) {
case *types.Date:
// Dereference pointers to their destination values
binder, v, t := indirect(dest)
if binder != nil {
return BindStringToObject(values.Get(paramName), dest)
case *time.Time:
return BindStringToObject(values.Get(paramName), dest)
}
v := reflect.Indirect(reflect.ValueOf(dest))
if v.Type().Kind() != reflect.Struct {
if t.Kind() != reflect.Struct {
return fmt.Errorf("unmarshaling query arg '%s' into wrong type", paramName)
}
t := v.Type()
for i := 0; i < t.NumField(); i++ {
fieldT := t.Field(i)
@ -445,3 +477,26 @@ func bindParamsToExplodedObject(paramName string, values url.Values, dest interf
}
return nil
}
// indirect
func indirect(dest interface{}) (interface{}, reflect.Value, reflect.Type) {
v := reflect.ValueOf(dest)
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Binder); ok {
return u, reflect.Value{}, nil
}
}
v = reflect.Indirect(v)
t := v.Type()
// special handling for custom types which might look like an object. We
// don't want to use object binding on them, but rather treat them as
// primitive types. time.Time{} is a unique case since we can't add a Binder
// to it without changing the underlying generated code.
if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
return dest, reflect.Value{}, nil
}
if t.ConvertibleTo(reflect.TypeOf(types.Date{})) {
return dest, reflect.Value{}, nil
}
return nil, v, t
}

View file

@ -76,8 +76,12 @@ func BindStringToObject(src string, dst interface{}) error {
v.SetBool(val)
}
case reflect.Struct:
switch dstType := dst.(type) {
case *time.Time:
// if this is not of type Time or of type Date look to see if this is of type Binder.
if dstType, ok := dst.(Binder); ok {
return dstType.Bind(src)
}
if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
// Don't fail on empty string.
if src == "" {
return nil
@ -90,9 +94,20 @@ func BindStringToObject(src string, dst interface{}) error {
return fmt.Errorf("error parsing '%s' as RFC3339 or 2006-01-02 time: %s", src, err)
}
}
*dstType = parsedTime
// So, assigning this gets a little fun. We have a value to the
// dereference destination. We can't do a conversion to
// time.Time because the result isn't assignable, so we need to
// convert pointers.
if t != reflect.TypeOf(time.Time{}) {
vPtr := v.Addr()
vtPtr := vPtr.Convert(reflect.TypeOf(&time.Time{}))
v = reflect.Indirect(vtPtr)
}
v.Set(reflect.ValueOf(parsedTime))
return nil
case *types.Date:
}
if t.ConvertibleTo(reflect.TypeOf(types.Date{})) {
// Don't fail on empty string.
if src == "" {
return nil
@ -101,9 +116,21 @@ func BindStringToObject(src string, dst interface{}) error {
if err != nil {
return fmt.Errorf("error parsing '%s' as date: %s", src, err)
}
dstType.Time = parsedTime
parsedDate := types.Date{Time: parsedTime}
// We have to do the same dance here to assign, just like with times
// above.
if t != reflect.TypeOf(types.Date{}) {
vPtr := v.Addr()
vtPtr := vPtr.Convert(reflect.TypeOf(&types.Date{}))
v = reflect.Indirect(vtPtr)
}
v.Set(reflect.ValueOf(parsedDate))
return nil
}
// We fall through to the error case below if we haven't handled the
// destination type above.
fallthrough
default:
// We've got a bunch of types unimplemented, don't fail silently.

View file

@ -3,6 +3,7 @@ package runtime
import (
"encoding/json"
"fmt"
"github.com/deepmap/oapi-codegen/pkg/types"
"net/url"
"reflect"
"sort"
@ -11,8 +12,6 @@ import (
"time"
"github.com/pkg/errors"
"github.com/deepmap/oapi-codegen/pkg/types"
)
func marshalDeepObject(in interface{}, path []string) ([]string, error) {
@ -212,26 +211,52 @@ func assignPathValues(dst interface{}, pathValues fieldOrValue) error {
return nil
case reflect.Struct:
// Some special types we care about are structs. Handle them
// here.
if _, isDate := iv.Interface().(types.Date); isDate {
// here. They may be redefined, so we need to do some hoop
// jumping. If the types are aliased, we need to type convert
// the pointer, then set the value of the dereference pointer.
// We check to see if the object implements the Binder interface first.
if dst, isBinder := v.Interface().(Binder); isBinder {
return dst.Bind(pathValues.value)
}
// Then check the legacy types
if it.ConvertibleTo(reflect.TypeOf(types.Date{})) {
var date types.Date
var err error
date.Time, err = time.Parse(types.DateFormat, pathValues.value)
if err != nil {
return errors.Wrap(err, "invalid date format")
}
iv.Set(reflect.ValueOf(date))
dst := iv
if it != reflect.TypeOf(types.Date{}) {
// Types are aliased, convert the pointers.
ivPtr := iv.Addr()
aPtr := ivPtr.Convert(reflect.TypeOf(&types.Date{}))
dst = reflect.Indirect(aPtr)
}
dst.Set(reflect.ValueOf(date))
}
if _, isTime := iv.Interface().(time.Time); isTime {
if it.ConvertibleTo(reflect.TypeOf(time.Time{})) {
var tm time.Time
var err error
tm, err = time.Parse(types.DateFormat, pathValues.value)
tm, err = time.Parse(time.RFC3339Nano, pathValues.value)
if err != nil {
// Fall back to parsing it as a date.
tm, err = time.Parse(types.DateFormat, pathValues.value)
if err != nil {
return fmt.Errorf("error parsing tim as RFC3339 or 2006-01-02 time: %s", err)
}
return errors.Wrap(err, "invalid date format")
}
iv.Set(reflect.ValueOf(tm))
dst := iv
if it != reflect.TypeOf(time.Time{}) {
// Types are aliased, convert the pointers.
ivPtr := iv.Addr()
aPtr := ivPtr.Convert(reflect.TypeOf(&time.Time{}))
dst = reflect.Indirect(aPtr)
}
dst.Set(reflect.ValueOf(tm))
}
fieldMap, err := fieldIndicesByJsonTag(iv.Interface())
if err != nil {
return errors.Wrap(err, "failed enumerating fields")
@ -311,9 +336,9 @@ func assignSlice(dst reflect.Value, pathValues fieldOrValue) error {
// This could be cleaner, but we can call into assignPathValues to
// avoid recreating this logic.
for i:=0; i < nValues; i++ {
for i := 0; i < nValues; i++ {
dstElem := dst.Index(i).Addr()
err := assignPathValues(dstElem.Interface(), fieldOrValue{value:values[i]})
err := assignPathValues(dstElem.Interface(), fieldOrValue{value: values[i]})
if err != nil {
return errors.Wrap(err, "error binding array")
}

View file

@ -14,18 +14,42 @@
package runtime
import (
"errors"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/deepmap/oapi-codegen/pkg/types"
)
// Given an input value, such as a primitive type, array or object, turn it
// into a parameter based on style/explode definition.
// Parameter escaping works differently based on where a header is found
type ParamLocation int
const (
ParamLocationUndefined ParamLocation = iota
ParamLocationQuery
ParamLocationPath
ParamLocationHeader
ParamLocationCookie
)
// This function is used by older generated code, and must remain compatible
// with that code. It is not to be used in new templates. Please see the
// function below, which can specialize its output based on the location of
// the parameter.
func StyleParam(style string, explode bool, paramName string, value interface{}) (string, error) {
return StyleParamWithLocation(style, explode, paramName, ParamLocationUndefined, value)
}
// Given an input value, such as a primitive type, array or object, turn it
// into a parameter based on style/explode definition, performing whatever
// escaping is necessary based on parameter location
func StyleParamWithLocation(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
t := reflect.TypeOf(value)
v := reflect.ValueOf(value)
@ -46,17 +70,17 @@ func StyleParam(style string, explode bool, paramName string, value interface{})
for i := 0; i < n; i++ {
sliceVal[i] = v.Index(i).Interface()
}
return styleSlice(style, explode, paramName, sliceVal)
return styleSlice(style, explode, paramName, paramLocation, sliceVal)
case reflect.Struct:
return styleStruct(style, explode, paramName, value)
return styleStruct(style, explode, paramName, paramLocation, value)
case reflect.Map:
return styleMap(style, explode, paramName, value)
return styleMap(style, explode, paramName, paramLocation, value)
default:
return stylePrimitive(style, explode, paramName, value)
return stylePrimitive(style, explode, paramName, paramLocation, value)
}
}
func styleSlice(style string, explode bool, paramName string, values []interface{}) (string, error) {
func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, values []interface{}) (string, error) {
if style == "deepObject" {
if !explode {
return "", errors.New("deepObjects must be exploded")
@ -111,9 +135,12 @@ func styleSlice(style string, explode bool, paramName string, values []interface
// We're going to assume here that the array is one of simple types.
var err error
var part string
parts := make([]string, len(values))
for i, v := range values {
parts[i], err = primitiveToString(v)
part, err = primitiveToString(v)
part = escapeParameterString(part, paramLocation)
parts[i] = part
if err != nil {
return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
}
@ -132,24 +159,35 @@ func sortedKeys(strMap map[string]string) []string {
return keys
}
// This is a special case. The struct may be a date or time, in
// which case, marshal it in correct format.
func marshalDateTimeValue(value interface{}) (string, bool) {
v := reflect.Indirect(reflect.ValueOf(value))
t := v.Type()
// This is a special case. The struct may be a time, in which case, marshal
// it in RFC3339 format.
func marshalTimeValue(value interface{}) (string, bool) {
if timeVal, ok := value.(*time.Time); ok {
if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
tt := v.Convert(reflect.TypeOf(time.Time{}))
timeVal := tt.Interface().(time.Time)
return timeVal.Format(time.RFC3339Nano), true
}
if timeVal, ok := value.(time.Time); ok {
return timeVal.Format(time.RFC3339Nano), true
if t.ConvertibleTo(reflect.TypeOf(types.Date{})) {
d := v.Convert(reflect.TypeOf(types.Date{}))
dateVal := d.Interface().(types.Date)
return dateVal.Format(types.DateFormat), true
}
return "", false
}
func styleStruct(style string, explode bool, paramName string, value interface{}) (string, error) {
if timeVal, ok := marshalTimeValue(value); ok {
return stylePrimitive(style, explode, paramName, timeVal)
func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
if timeVal, ok := marshalDateTimeValue(value); ok {
styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, timeVal)
if err != nil {
return "", errors.Wrap(err, "failed to style time")
}
return styledVal, nil
}
if style == "deepObject" {
@ -191,10 +229,10 @@ func styleStruct(style string, explode bool, paramName string, value interface{}
fieldDict[fieldName] = str
}
return processFieldDict(style, explode, paramName, fieldDict)
return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
}
func styleMap(style string, explode bool, paramName string, value interface{}) (string, error) {
func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
if style == "deepObject" {
if !explode {
return "", errors.New("deepObjects must be exploded")
@ -215,11 +253,10 @@ func styleMap(style string, explode bool, paramName string, value interface{}) (
}
fieldDict[fieldName] = str
}
return processFieldDict(style, explode, paramName, fieldDict)
return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
}
func processFieldDict(style string, explode bool, paramName string, fieldDict map[string]string) (string, error) {
func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, fieldDict map[string]string) (string, error) {
var parts []string
// This works for everything except deepObject. We'll handle that one
@ -227,12 +264,12 @@ func processFieldDict(style string, explode bool, paramName string, fieldDict ma
if style != "deepObject" {
if explode {
for _, k := range sortedKeys(fieldDict) {
v := fieldDict[k]
v := escapeParameterString(fieldDict[k], paramLocation)
parts = append(parts, k+"="+v)
}
} else {
for _, k := range sortedKeys(fieldDict) {
v := fieldDict[k]
v := escapeParameterString(fieldDict[k], paramLocation)
parts = append(parts, k)
parts = append(parts, v)
}
@ -286,7 +323,7 @@ func processFieldDict(style string, explode bool, paramName string, fieldDict ma
return prefix + strings.Join(parts, separator), nil
}
func stylePrimitive(style string, explode bool, paramName string, value interface{}) (string, error) {
func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
strVal, err := primitiveToString(value)
if err != nil {
return "", err
@ -304,7 +341,7 @@ func stylePrimitive(style string, explode bool, paramName string, value interfac
default:
return "", fmt.Errorf("unsupported style '%s'", style)
}
return prefix + strVal, nil
return prefix + escapeParameterString(strVal, paramLocation), nil
}
// Converts a primitive value to a string. We need to do this based on the
@ -320,8 +357,10 @@ func primitiveToString(value interface{}) (string, error) {
switch kind {
case reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
output = strconv.FormatInt(v.Int(), 10)
case reflect.Float32, reflect.Float64:
case reflect.Float64:
output = strconv.FormatFloat(v.Float(), 'f', -1, 64)
case reflect.Float32:
output = strconv.FormatFloat(v.Float(), 'f', -1, 32)
case reflect.Bool:
if v.Bool() {
output = "true"
@ -335,3 +374,17 @@ func primitiveToString(value interface{}) (string, error) {
}
return output, nil
}
// This function escapes a parameter value bas on the location of that parameter.
// Query params and path params need different kinds of escaping, while header
// and cookie params seem not to need escaping.
func escapeParameterString(value string, paramLocation ParamLocation) string {
switch paramLocation {
case ParamLocationQuery:
return url.QueryEscape(value)
case ParamLocationPath:
return url.PathEscape(value)
default:
return value
}
}

View file

@ -28,10 +28,29 @@ func ParseCommandlineMap(src string) (map[string]string, error) {
return result, nil
}
// ParseCommandLineList parses comma separated string lists which are passed
// in on the command line. Spaces are trimmed off both sides of result
// strings.
func ParseCommandLineList(input string) []string {
input = strings.TrimSpace(input)
if len(input) == 0 {
return nil
}
splitInput := strings.Split(input, ",")
args := make([]string, 0, len(splitInput))
for _, s := range splitInput {
s = strings.TrimSpace(s)
if len(s) > 0 {
args = append(args, s)
}
}
return args
}
// This function splits a string along the specifed separator, but it
// ignores anything between double quotes for splitting. We do simple
// inside/outside quote counting. Quotes are not stripped from output.
func splitString(s string, sep rune) ([]string) {
func splitString(s string, sep rune) []string {
const escapeChar rune = '"'
var parts []string
@ -39,7 +58,7 @@ func splitString(s string, sep rune) ([]string) {
inQuotes := false
for _, c := range s {
if c == escapeChar{
if c == escapeChar {
if inQuotes {
inQuotes = false
} else {

View file

@ -6,15 +6,15 @@ import (
"github.com/getkin/kin-openapi/openapi3"
)
func LoadSwagger(filePath string) (swagger *openapi3.Swagger, err error) {
func LoadSwagger(filePath string) (swagger *openapi3.T, err error) {
loader := openapi3.NewSwaggerLoader()
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
u, err := url.Parse(filePath)
if err == nil && u.Scheme != "" && u.Host != "" {
return loader.LoadSwaggerFromURI(u)
return loader.LoadFromURI(u)
} else {
return loader.LoadSwaggerFromFile(filePath)
return loader.LoadFromFile(filePath)
}
}