mosquitto-go-auth/backends/jwt_local.go

186 lines
4.2 KiB
Go

package backends
import (
"database/sql"
"strings"
"github.com/iegomez/mosquitto-go-auth/hashing"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
// localJWTChecker is a JWT checker that resolves users, superusers and ACLs
// against a local database, through either the Postgres or the Mysql backend.
type localJWTChecker struct {
// db names the selected backend. Methods compare it against mysqlDB and
// fall back to the Postgres backend for any other value.
db string
// postgres is the backend used when db is not mysqlDB.
postgres Postgres
// mysql is the backend used when db is mysqlDB.
mysql Mysql
// userQuery is the query run by getLocalUser with the username as argument.
userQuery string
// hasher is the hash comparer received at construction time.
hasher hashing.HashComparer
// options holds the token options (secret, expiration skipping flags, ...)
// handed to getUsernameForToken.
options tokenOptions
}
// Names of the supported local databases. The checker's db field is set from
// the "jwt_db" option and compared against mysqlDB; postgresDB is the default.
const (
mysqlDB = "mysql"
postgresDB = "postgres"
)
// NewLocalJWTChecker initializes a JWT checker backed by a local DB.
//
// It reads the following entries from authOpts:
//   - "jwt_db": optional backend selector, defaulting to postgresDB. Only the
//     value mysqlDB selects the Mysql backend; anything else uses Postgres.
//   - "jwt_userquery": required query used to check that a user exists.
//   - "jwt_pg*" / "jwt_mysql*": backend options, forwarded to NewPostgres or
//     NewMysql with the "jwt_" prefix removed (see extractOpts).
//
// An error is returned when options.secret is empty, when "jwt_userquery" is
// missing, or when the underlying DB backend cannot be created. Backend
// creation errors are wrapped so the original cause stays in the error chain.
func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, options tokenOptions) (jwtChecker, error) {
	// Fail before building anything: without a secret no token can be verified.
	if options.secret == "" {
		return nil, errors.New("JWT backend error: missing jwt secret")
	}

	checker := &localJWTChecker{
		hasher:  hasher,
		db:      postgresDB,
		options: options,
	}

	if db, ok := authOpts["jwt_db"]; ok {
		checker.db = db
	}

	// Collect the names of missing required options so they are all reported
	// in a single error message.
	missingOpts := ""
	if userQuery, ok := authOpts["jwt_userquery"]; ok {
		checker.userQuery = userQuery
	} else {
		missingOpts += " jwt_userquery"
	}
	if missingOpts != "" {
		return nil, errors.Errorf("JWT backend error: missing local options: %s", missingOpts)
	}

	// Extract DB specific opts (e.g., host, port, etc.) to construct the underlying DB backend.
	dbAuthOpts := extractOpts(authOpts, checker.db)

	if checker.db == mysqlDB {
		mysql, err := NewMysql(dbAuthOpts, logLevel, hasher)
		if err != nil {
			return nil, errors.Wrap(err, "JWT backend error: couldn't create mysql connector for local jwt")
		}
		checker.mysql = mysql
		return checker, nil
	}

	// Any value other than mysqlDB is served by the Postgres backend.
	postgres, err := NewPostgres(dbAuthOpts, logLevel, hasher)
	if err != nil {
		return nil, errors.Wrap(err, "JWT backend error: couldn't create postgres connector for local jwt")
	}
	checker.postgres = postgres
	return checker, nil
}
// GetUser resolves the username carried by token and reports whether the
// local user query finds that user. The username is obtained from
// getUsernameForToken, called with the skipUserExpiration option; a failure
// there is logged and returned with a false result.
func (o *localJWTChecker) GetUser(token string) (bool, error) {
	skipExpiration := o.options.skipUserExpiration

	name, tokenErr := getUsernameForToken(o.options, token, skipExpiration)
	if tokenErr != nil {
		log.Printf("jwt local get user error: %s", tokenErr)
		return false, tokenErr
	}

	return o.getLocalUser(name)
}
// GetSuperuser resolves the username carried by token and delegates the
// superuser check to the configured DB backend (Mysql when db is mysqlDB,
// Postgres otherwise). A failure to resolve the username is logged and
// returned with a false result.
func (o *localJWTChecker) GetSuperuser(token string) (bool, error) {
	skipExpiration := o.options.skipUserExpiration

	name, tokenErr := getUsernameForToken(o.options, token, skipExpiration)
	if tokenErr != nil {
		log.Printf("jwt local get superuser error: %s", tokenErr)
		return false, tokenErr
	}

	switch o.db {
	case mysqlDB:
		return o.mysql.GetSuperuser(name)
	default:
		return o.postgres.GetSuperuser(name)
	}
}
// CheckAcl resolves the username carried by token and delegates the ACL check
// for topic, clientid and acc to the configured DB backend (Mysql when db is
// mysqlDB, Postgres otherwise). The username is obtained from
// getUsernameForToken, called with the skipACLExpiration option; a failure
// there is logged and returned with a false result.
func (o *localJWTChecker) CheckAcl(token, topic, clientid string, acc int32) (bool, error) {
	skipExpiration := o.options.skipACLExpiration

	name, tokenErr := getUsernameForToken(o.options, token, skipExpiration)
	if tokenErr != nil {
		log.Printf("jwt local check acl error: %s", tokenErr)
		return false, tokenErr
	}

	switch o.db {
	case mysqlDB:
		return o.mysql.CheckAcl(name, topic, clientid, acc)
	default:
		return o.postgres.CheckAcl(name, topic, clientid, acc)
	}
}
// Halt closes the DB handle of the backend that was initialized. Postgres is
// checked first and Mysql only when Postgres is not set; a close error is
// logged rather than returned.
func (o *localJWTChecker) Halt() {
	switch {
	case o.postgres != (Postgres{}) && o.postgres.DB != nil:
		if closeErr := o.postgres.DB.Close(); closeErr != nil {
			log.Errorf("JWT cleanup error: %s", closeErr)
		}
	case o.mysql != (Mysql{}) && o.mysql.DB != nil:
		if closeErr := o.mysql.DB.Close(); closeErr != nil {
			log.Errorf("JWT cleanup error: %s", closeErr)
		}
	}
}
// getLocalUser runs the configured user query with username as its only
// argument and reports whether the resulting count is greater than zero.
//
// It returns false without querying when no user query is set, false when the
// query yields NULL, and the query error (logged at debug level) when the
// lookup fails.
func (o *localJWTChecker) getLocalUser(username string) (bool, error) {
	if o.userQuery == "" {
		return false, nil
	}

	var (
		found    sql.NullInt64
		queryErr error
	)

	// Only mysqlDB selects the Mysql handle; every other value uses Postgres.
	switch o.db {
	case mysqlDB:
		queryErr = o.mysql.DB.Get(&found, o.userQuery, username)
	default:
		queryErr = o.postgres.DB.Get(&found, o.userQuery, username)
	}

	if queryErr != nil {
		log.Debugf("local JWT get user error: %s", queryErr)
		return false, queryErr
	}

	if !found.Valid {
		log.Debugf("local JWT get user error: user %s not found", username)
		return false, nil
	}

	return found.Int64 > 0, nil
}
// extractOpts builds the option map handed to the underlying DB backend.
//
// Every key of authOpts starting with "jwt_pg" (or "jwt_mysql" when db is
// mysqlDB) is copied with its leading "jwt_" removed, so "jwt_pg_host"
// becomes "pg_host". The backend's "<prefix>_userquery" entry is always set
// to a placeholder, overriding any value supplied by the caller.
func extractOpts(authOpts map[string]string, db string) map[string]string {
	backendPrefix := "pg"
	if db == mysqlDB {
		backendPrefix = mysqlDB
	}
	jwtPrefix := "jwt_" + backendPrefix

	extracted := make(map[string]string)
	for key, value := range authOpts {
		if !strings.HasPrefix(key, jwtPrefix) {
			continue
		}
		extracted[strings.TrimPrefix(key, "jwt_")] = value
	}

	// Set a dummy query for user check since it won't be checked with the DB backend's method.
	extracted[backendPrefix+"_userquery"] = "dummy"

	return extracted
}