internal/auth: add TenantChannelMiddleware

Extracts the tenant from the JWT and sets it in the request context.
This commit is contained in:
Sanne Raymaekers 2023-06-17 17:39:27 +02:00
parent bec17b6d47
commit 0f946e1c9e
4 changed files with 50 additions and 11 deletions

View file

@ -0,0 +1,29 @@
package auth
import (
"errors"
"fmt"
"github.com/labstack/echo/v4"
)
const TenantCtxKey string = "tenant"
func TenantChannelMiddleware(tenantProviderFields []string, onFail error) func(next echo.HandlerFunc) echo.HandlerFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {
tenant, err := GetFromClaims(ctx.Request().Context(), tenantProviderFields)
// Allowlisted paths won't have a token
if err != nil && !errors.Is(err, NoJWTError) {
return onFail
}
// prefix the tenant to prevent collisions if support for specifying channels in a request is ever added
if tenant != "" {
ctx.Set(TenantCtxKey, fmt.Sprintf("org-%s", tenant))
}
return next(ctx)
}
}
}

View file

@ -68,6 +68,7 @@ const (
ErrorGettingOSBuildJobStatus ServiceErrorCode = 1017
ErrorGettingAWSEC2JobStatus ServiceErrorCode = 1018
ErrorGettingJobType ServiceErrorCode = 1019
ErrorTenantNotInContext ServiceErrorCode = 1020
// Errors contained within this file
ErrorUnspecified ServiceErrorCode = 10000
@ -143,6 +144,7 @@ func getServiceErrors() serviceErrors {
serviceError{ErrorGettingOSBuildJobStatus, http.StatusInternalServerError, "Unable to get osbuild job status"},
serviceError{ErrorGettingAWSEC2JobStatus, http.StatusInternalServerError, "Unable to get ec2 job status"},
serviceError{ErrorGettingJobType, http.StatusInternalServerError, "Unable to get job type of existing job"},
serviceError{ErrorTenantNotInContext, http.StatusInternalServerError, "Unable to retrieve tenant from request context"},
serviceError{ErrorUnspecified, http.StatusInternalServerError, "Unspecified internal error "},
serviceError{ErrorNotHTTPError, http.StatusInternalServerError, "Error is not an instance of HTTPError"},

View file

@ -6,22 +6,23 @@ import (
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/osbuild/osbuild-composer/internal/auth"
)
const TenantCtxKey string = "tenant"
// getTenantChannel returns the tenant channel for the provided request context
func (s *Server) getTenantChannel(ctx echo.Context) (string, error) {
// channel is empty if JWT is not enabled
var channel string
if s.config.JWTEnabled {
tenant, err := auth.GetFromClaims(ctx.Request().Context(), s.config.TenantProviderFields)
if err != nil {
return "", err
tenant, ok := ctx.Get(auth.TenantCtxKey).(string)
if !ok {
return "", HTTPError(ErrorTenantNotInContext)
}
// prefix the tenant to prevent collisions if support for specifying channels in a request is ever added
channel = "org-" + tenant
return tenant, nil
}
return channel, nil
// channel is empty if JWT is not enabled
return "", nil
}
type ComposeHandlerFunc func(ctx echo.Context, id string) error
@ -36,7 +37,7 @@ func (s *Server) EnsureJobChannel(next ComposeHandlerFunc) ComposeHandlerFunc {
ctxChannel, err := s.getTenantChannel(c)
if err != nil {
return HTTPErrorWithInternal(ErrorTenantNotFound, err)
return err
}
jobChannel, err := s.workers.JobChannel(jobId)

View file

@ -19,6 +19,7 @@ import (
"github.com/osbuild/osbuild-composer/pkg/jobqueue"
"github.com/osbuild/osbuild-composer/internal/auth"
"github.com/osbuild/osbuild-composer/internal/blueprint"
"github.com/osbuild/osbuild-composer/internal/common"
"github.com/osbuild/osbuild-composer/internal/container"
@ -90,8 +91,14 @@ func (s *Server) Handler(path string) http.Handler {
server: s,
}
statusMW := prometheus.StatusMiddleware(prometheus.ComposerSubsystem)
RegisterHandlers(e.Group(path, prometheus.MetricsMiddleware, s.ValidateRequest, statusMW), &handler)
mws := []echo.MiddlewareFunc{
prometheus.StatusMiddleware(prometheus.ComposerSubsystem),
}
if s.config.JWTEnabled {
mws = append(mws, auth.TenantChannelMiddleware(s.config.TenantProviderFields, HTTPError(ErrorTenantNotFound)))
}
mws = append(mws, prometheus.MetricsMiddleware, s.ValidateRequest)
RegisterHandlers(e.Group(path, mws...), &handler)
return e
}