209 lines
4.9 KiB
Go
209 lines
4.9 KiB
Go
package backends
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
jwtGo "github.com/golang-jwt/jwt"
|
|
"github.com/iegomez/mosquitto-go-auth/hashing"
|
|
"github.com/pkg/errors"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type JWT struct {
|
|
mode string
|
|
name string
|
|
checker jwtChecker
|
|
}
|
|
|
|
// tokenOptions groups the token-related settings that NewJWT reads from the
// auth options and hands to the mode-specific checkers.
type tokenOptions struct {
	// Flags enabled when the matching option (jwt_parse_token,
	// jwt_skip_user_expiration, jwt_skip_acl_expiration) is exactly "true".
	parseToken, skipUserExpiration, skipACLExpiration bool

	// secret is the key used to verify token signatures; userFieldKey is the
	// claim name the username is read from.
	secret, userFieldKey string
}
|
|
|
|
// jwtChecker is the contract each JWT mode implementation fulfils. The JWT
// backend forwards its calls to whichever implementation NewJWT selected.
type jwtChecker interface {
	// GetUser reports whether the given user is authenticated.
	GetUser(username string) (bool, error)

	// GetSuperuser reports whether the given user is a superuser.
	GetSuperuser(username string) (bool, error)

	// CheckAcl reports whether the user may access topic with access acc.
	CheckAcl(username string, topic string, clientid string, acc int32) (bool, error)

	// Halt lets the implementation shut down.
	Halt()
}
|
|
|
|
// Values accepted for the jwt_mode option.
const (
	remoteMode = "remote"
	localMode  = "local"
	jsMode     = "js"
	filesMode  = "files"
)

// Claim names looked up in parsed tokens.
const (
	claimsSubjectKey  = "sub"
	claimsUsernameKey = "username"
	claimsIssKey      = "iss"
)
|
|
|
|
func NewJWT(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, version string) (*JWT, error) {
|
|
log.SetLevel(logLevel)
|
|
|
|
jwt := &JWT{}
|
|
|
|
var err error
|
|
var checker jwtChecker
|
|
|
|
var options tokenOptions
|
|
|
|
if parseToken, ok := authOpts["jwt_parse_token"]; ok && parseToken == "true" {
|
|
options.parseToken = true
|
|
}
|
|
|
|
if skipUserExpiration, ok := authOpts["jwt_skip_user_expiration"]; ok && skipUserExpiration == "true" {
|
|
options.skipUserExpiration = true
|
|
}
|
|
|
|
if skipACLExpiration, ok := authOpts["jwt_skip_acl_expiration"]; ok && skipACLExpiration == "true" {
|
|
options.skipACLExpiration = true
|
|
}
|
|
|
|
if secret, ok := authOpts["jwt_secret"]; ok {
|
|
options.secret = secret
|
|
}
|
|
|
|
if userField, ok := authOpts["jwt_userfield"]; ok && userField == "Username" {
|
|
options.userFieldKey = claimsUsernameKey
|
|
} else {
|
|
options.userFieldKey = claimsSubjectKey
|
|
}
|
|
|
|
switch authOpts["jwt_mode"] {
|
|
case jsMode:
|
|
jwt.mode = jsMode
|
|
checker, err = NewJsJWTChecker(authOpts, options)
|
|
case localMode:
|
|
jwt.mode = localMode
|
|
checker, err = NewLocalJWTChecker(authOpts, logLevel, hasher, options)
|
|
case remoteMode:
|
|
jwt.mode = remoteMode
|
|
checker, err = NewRemoteJWTChecker(authOpts, options, version)
|
|
case filesMode:
|
|
jwt.mode = filesMode
|
|
checker, err = NewFilesJWTChecker(authOpts, logLevel, hasher, options)
|
|
default:
|
|
err = errors.New("unknown JWT mode")
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
jwt.name = fmt.Sprintf("JWT %s", authOpts["jwt_mode"])
|
|
|
|
jwt.checker = checker
|
|
|
|
return jwt, nil
|
|
}
|
|
|
|
// GetUser authenticates a given user.
|
|
func (o *JWT) GetUser(token, password, clientid string) (bool, error) {
|
|
return o.checker.GetUser(token)
|
|
}
|
|
|
|
// GetSuperuser checks if the given user is a superuser.
|
|
func (o *JWT) GetSuperuser(token string) (bool, error) {
|
|
return o.checker.GetSuperuser(token)
|
|
}
|
|
|
|
// CheckAcl checks user authorization.
|
|
func (o *JWT) CheckAcl(token, topic, clientid string, acc int32) (bool, error) {
|
|
return o.checker.CheckAcl(token, topic, clientid, acc)
|
|
}
|
|
|
|
// GetName returns the backend's name
|
|
func (o *JWT) GetName() string {
|
|
return o.name
|
|
}
|
|
|
|
// Halt closes any db connection.
|
|
func (o *JWT) Halt() {
|
|
o.checker.Halt()
|
|
}
|
|
|
|
func getJWTClaims(secret string, tokenStr string, skipExpiration bool) (*jwtGo.MapClaims, error) {
|
|
|
|
jwtToken, err := jwtGo.ParseWithClaims(tokenStr, &jwtGo.MapClaims{}, func(token *jwtGo.Token) (interface{}, error) {
|
|
return []byte(secret), nil
|
|
})
|
|
|
|
expirationError := false
|
|
if err != nil {
|
|
if !skipExpiration {
|
|
log.Debugf("jwt parse error: %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
if v, ok := err.(*jwtGo.ValidationError); ok && v.Errors == jwtGo.ValidationErrorExpired {
|
|
expirationError = true
|
|
}
|
|
}
|
|
|
|
if !jwtToken.Valid && !expirationError {
|
|
return nil, errors.New("jwt invalid token")
|
|
}
|
|
|
|
claims, ok := jwtToken.Claims.(*jwtGo.MapClaims)
|
|
if !ok {
|
|
log.Debugf("jwt error: expected *MapClaims, got %T", jwtToken.Claims)
|
|
return nil, errors.New("got strange claims")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
func getUsernameForToken(options tokenOptions, tokenStr string, skipExpiration bool) (string, error) {
|
|
claims, err := getJWTClaims(options.secret, tokenStr, skipExpiration)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
username, found := (*claims)[options.userFieldKey]
|
|
if !found {
|
|
return "", nil
|
|
}
|
|
|
|
usernameString, ok := username.(string)
|
|
if !ok {
|
|
log.Debugf("jwt error: username expected to be string, got %T", username)
|
|
return "", errors.New("got strange username")
|
|
}
|
|
|
|
return usernameString, nil
|
|
}
|
|
|
|
func getClaimsForToken(options tokenOptions, tokenStr string, skipExpiration bool) (map[string]interface{}, error) {
|
|
claims, err := getJWTClaims(options.secret, tokenStr, skipExpiration)
|
|
if err != nil {
|
|
return make(map[string]interface{}), err
|
|
}
|
|
|
|
return map[string]interface{}(*claims), nil
|
|
}
|
|
|
|
func getIssForToken(options tokenOptions, tokenStr string, skipExpiration bool) (string, error) {
|
|
claims, err := getJWTClaims(options.secret, tokenStr, skipExpiration)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
iss, found := (*claims)[claimsIssKey]
|
|
if !found {
|
|
return "", nil
|
|
}
|
|
|
|
issString, ok := iss.(string)
|
|
if !ok {
|
|
log.Debugf("jwt error: iss expected to be string, got %T", iss)
|
|
return "", errors.New("got strange iss")
|
|
}
|
|
|
|
return issString, nil
|
|
}
|