diff --git a/admin.go b/admin.go index 1f11fb44..ed2a5880 100644 --- a/admin.go +++ b/admin.go @@ -154,7 +154,7 @@ func toSnakeCase(str string) string { // JSONMarshal Generates JSON format from an object func JSONMarshal(v interface{}, safeEncoding bool) ([]byte, error) { // b, err := json.Marshal(v) - b, err := json.MarshalIndent(v, "", " ") + b, err := jsonMarshal(v) if safeEncoding { b = bytes.Replace(b, []byte("\\u003c"), []byte("<"), -1) @@ -164,19 +164,73 @@ func JSONMarshal(v interface{}, safeEncoding bool) ([]byte, error) { return b, err } +func nullZeroValueStructs(record map[string]interface{}) map[string]interface{} { + for k := range record { + switch v := record[k].(type) { + case map[string]interface{}: + if id, ok := v["ID"].(float64); ok && id == 0 { + record[k] = nil + } else if id, ok := v["id"].(float64); ok && id == 0 { + record[k] = nil + } else { + record[k] = nullZeroValueStructs(v) + } + } + } + return record +} + +func removeZeroValueStructs(buf []byte) []byte { + response := map[string]interface{}{} + json.Unmarshal(buf, &response) + if val, ok := response["result"].(map[string]interface{}); ok { + val = nullZeroValueStructs(val) + buf, _ = json.Marshal(val) + return buf + } + if _, ok := response["result"].([]interface{}); !ok { + return buf + } + val := response["result"].([]interface{}) + var record map[string]interface{} + for i := range val { + record = val[i].(map[string]interface{}) + record = nullZeroValueStructs(record) + val[i] = record + } + response["result"] = val + buf, _ = json.Marshal(response) + return buf +} + +func jsonMarshal(v interface{}) ([]byte, error) { + var buf []byte + var err error + if CompressJSON { + buf, err = json.Marshal(v) + if err == nil && RemoveZeroValueJSON { + buf = removeZeroValueStructs(buf) + } + } else { + buf, err = json.MarshalIndent(v, "", " ") + } + + return buf, err +} + // ReturnJSON returns json to the client func ReturnJSON(w http.ResponseWriter, r *http.Request, v interface{}) { // Set content type in header w.Header().Set("Content-Type", "application/json") // Marshal content - b, err := json.MarshalIndent(v, "", " ") + b, err := jsonMarshal(v) if err != nil { response := map[string]interface{}{ "status": "error", "error_msg": fmt.Sprintf("unable to encode JSON. %s", err), } - b, _ = json.MarshalIndent(response, "", " ") + b, _ = jsonMarshal(response) w.Write(b) return } diff --git a/admin_test.go b/admin_test.go index 84bbe5dd..041741dc 100644 --- a/admin_test.go +++ b/admin_test.go @@ -204,12 +204,12 @@ func (t *UAdminTests) TestReturnJSON() { out string }{ {map[string]interface{}{"ID": 1, "Name": "Test"}, `{ - "ID": 1, - "Name": "Test" + "ID": 1, + "Name": "Test" }`}, {math.NaN(), `{ - "error_msg": "unable to encode JSON. json: unsupported value: NaN", - "status": "error" + "error_msg": "unable to encode JSON. json: unsupported value: NaN", + "status": "error" }`}, } diff --git a/auth.go b/auth.go index 9c093efb..ac9e5120 100644 --- a/auth.go +++ b/auth.go @@ -4,12 +4,17 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" + "io" "math/big" "net" + "os" "path" + "sync" "crypto/hmac" "crypto/rand" + "crypto/rsa" "crypto/sha256" "math" "net/http" @@ -17,6 +22,7 @@ import ( "strings" "time" + "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" ) @@ -36,6 +42,8 @@ var JWT = "" // used to identify the as JWT audience. var JWTIssuer = "" +var JWTAlgo = "HS256" //"RS256" + // AcceptedJWTIssuers is a list of accepted JWT issuers. By default the // local JWTIssuer is accepted. To accept other issuers, add them to // this list @@ -47,6 +55,9 @@ var bcryptDiff = 12 // cachedSessions is variable for keeping active sessions var cachedSessions map[string]Session +// Need to have a lock to protect it from race conditions during concurrent writes. +var cachedSessionsMutex sync.RWMutex + // invalidAttempts keeps track of invalid password attempts // per IP address var invalidAttempts = map[string]int{} @@ -84,6 +95,9 @@ func GenerateBase32(length int) string { // hashPass Generates a hash from a password and salt func hashPass(pass string) string { password := []byte(pass + Salt) + if len(password) > 72 { + password = password[:72] + } hash, err := bcrypt.GenerateFromPassword(password, bcryptDiff) if err != nil { Trail(ERROR, "uadmin.auth.hashPass.GenerateFromPassword: %s", err) @@ -154,15 +168,22 @@ func createJWT(r *http.Request, s *Session) string { if !isValidSession(r, s) { return "" } + alg := JWTAlgo + aud := JWTIssuer + SSO := false + if r.Context().Value(CKey("aud")) != nil { + aud = r.Context().Value(CKey("aud")).(string) + SSO = true + } header := map[string]interface{}{ - "alg": "HS256", + "alg": alg, "typ": "JWT", } payload := map[string]interface{}{ "sub": s.User.Username, "iat": s.LastLogin.Unix(), "iss": JWTIssuer, - "aud": JWTIssuer, + "aud": aud, } if s.ExpiresOn != nil { payload["exp"] = s.ExpiresOn.Unix() @@ -173,16 +194,98 @@ func createJWT(r *http.Request, s *Session) string { payload = CustomJWT(r, s, payload) } - jHeader, _ := json.Marshal(header) - jPayload, _ := json.Marshal(payload) - b64Header := base64.RawURLEncoding.EncodeToString(jHeader) - b64Payload := base64.RawURLEncoding.EncodeToString(jPayload) + // TODO: Add custom handler to customize JWT + // This custom function show have parameters for: + // JWT Object + // SSO boolean + // Algorithm + // User + // *Session + + if alg == "HS256" { + jHeader, _ := json.Marshal(header) + jPayload, _ := json.Marshal(payload) + b64Header := base64.RawURLEncoding.EncodeToString(jHeader) + b64Payload := base64.RawURLEncoding.EncodeToString(jPayload) + + hash := hmac.New(sha256.New, []byte(JWT+s.Key)) + hash.Write([]byte(b64Header + "." + b64Payload)) + signature := hash.Sum(nil) + b64Signature := base64.RawURLEncoding.EncodeToString(signature) + return b64Header + "." + b64Payload + "." + b64Signature + } else if alg == "RS256" { + buf, err := os.ReadFile(".jwt-rsa-private.pem") + if err != nil { + return "" + } + key, err := jwt.ParseRSAPrivateKeyFromPEM(buf) + if err != nil { + return "" + } + + // Customize JWT Data + header["kid"] = "1" + + // Extra customization for SSO + if SSO { + payload["name"] = s.User.String() + payload["given_name"] = s.User.FirstName + payload["family_name"] = s.User.LastName + payload["email"] = s.User.Email + if s.User.Photo != "" { + payload["picture"] = JWTIssuer + strings.TrimSuffix(RootURL, "/") + s.User.Photo + "?token=" + strings.TrimPrefix(hashPass(s.User.Photo), "$2a$12$") + } + + groups := []map[string]interface{}{} + + if s.User.UserGroupID != 0 { + Preload(&s.User, "UserGroup") + groups = append(groups, map[string]interface{}{ + "displayName": s.User.UserGroup.GroupName, + "id": s.User.UserGroupID, + }) + } + if s.User.Admin { + groups = append(groups, map[string]interface{}{ + "displayName": "$admin", + "id": 0, + }) + } + payload["groups"] = groups + + entitlements := []map[string]interface{}{} + for k := range models { + perm := s.User.GetAccess(k) + entitlements = append(entitlements, map[string]interface{}{ + "modelName": k, + "read": perm.Read, + "add": perm.Add, + "edit": perm.Edit, + "delete": perm.Delete, + "approval": perm.Approval, + }) + } + + payload["entitlements"] = entitlements + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(payload)) + + for k, v := range header { + token.Header[k] = v + } + + tokenRaw, err := token.SignedString(key) + + if err != nil { + return "" + } + return tokenRaw + } else { + Trail(ERROR, "Unknown algorithm for JWT (%s)", alg) + return "" + } - hash := hmac.New(sha256.New, []byte(JWT+s.Key)) - hash.Write([]byte(b64Header + "." + b64Payload)) - signature := hash.Sum(nil) - b64Signature := base64.RawURLEncoding.EncodeToString(signature) - return b64Header + "." + b64Payload + "." + b64Signature } func isValidSession(r *http.Request, s *Session) bool { @@ -236,6 +339,9 @@ func getSessionFromRequest(r *http.Request) *Session { // Login return *User and a bool for Is OTP Required func Login(r *http.Request, username string, password string) (*Session, bool) { + if PreLoginHandler != nil { + PreLoginHandler(r, username, password) + } // Get the user from DB user := User{} Get(&user, "username = ?", username) @@ -354,6 +460,8 @@ func Logout(r *http.Request) { // Delete the cookie from memory if we sessions are cached if CacheSessions { + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits delete(cachedSessions, s.Key) } @@ -559,6 +667,8 @@ func getNetSize(r *http.Request, net string) int { func getSessionByKey(key string) *Session { s := Session{} if CacheSessions { + cachedSessionsMutex.RLock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.RUnlock() // Ensure the mutex is unlocked when the function exits s = cachedSessions[key] } else { Get(&s, "`key` = ?", key) @@ -569,140 +679,353 @@ func getSessionByKey(key string) *Session { return &s } -func getSession(r *http.Request) string { - key, err := r.Cookie("session") - if err == nil && key != nil { - return key.Value +func getJWT(r *http.Request) string { + // JWT + if r.Header.Get("Authorization") == "" { + return "" } - if r.Method == "GET" && r.FormValue("session") != "" { - return r.FormValue("session") + if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") { + return "" } - if r.Method == "POST" { - err := r.ParseMultipartForm(2 << 10) - if err != nil { - r.ParseForm() - } - if r.FormValue("session") != "" { - return r.FormValue("session") - } + + jwtToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + jwtParts := strings.Split(jwtToken, ".") + + if len(jwtParts) != 3 { + return "" } - // JWT - if r.Header.Get("Authorization") != "" { - if strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") { - jwt := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - jwtParts := strings.Split(jwt, ".") - if len(jwtParts) != 3 { - return "" - } + jHeader, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[0]) + if err != nil { + return "" + } + jPayload, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1]) + if err != nil { + return "" + } - jHeader, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[0]) - if err != nil { - return "" + header := map[string]interface{}{} + err = json.Unmarshal(jHeader, &header) + if err != nil { + return "" + } + + // Get data from payload + payload := map[string]interface{}{} + err = json.Unmarshal(jPayload, &payload) + if err != nil { + return "" + } + + // Verify issuer + SSOLogin := false + if iss, ok := payload["iss"].(string); ok { + if iss != JWTIssuer { + accepted := false + for _, fiss := range AcceptedJWTIssuers { + if fiss == iss { + accepted = true + break + } } - jPayload, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1]) - if err != nil { + if !accepted { return "" } + SSOLogin = true + } + } else { + return "" + } - header := map[string]interface{}{} - err = json.Unmarshal(jHeader, &header) - if err != nil { - return "" + // verify audience + if aud, ok := payload["aud"].(string); ok { + if aud != JWTIssuer { + return "" + } + } else if aud, ok := payload["aud"].([]string); ok { + accepted := false + for _, audItem := range aud { + if audItem == JWTIssuer { + accepted = true + break } + } + if !accepted { + return "" + } + } else { + return "" + } - // Get data from payload - payload := map[string]interface{}{} - err = json.Unmarshal(jPayload, &payload) - if err != nil { - return "" - } + // if there is no subject, return empty session + if _, ok := payload["sub"].(string); !ok { + return "" + } - // Verify issuer - if iss, ok := payload["iss"].(string); ok { - if iss != JWTIssuer { - accepted := false - for _, fiss := range AcceptedJWTIssuers { - if fiss == iss { - accepted = true - break - } - } - if !accepted { - return "" + sub := payload["sub"].(string) + user := User{} + Get(&user, "username = ?", sub) + + if user.ID == 0 && SSOLogin { + now := time.Now() + user := User{ + Username: sub, + FirstName: payload["given_name"].(string), + LastName: payload["family_name"].(string), + Active: true, + Admin: func() bool { + for _, group := range payload["groups"].([]interface{}) { + if group.(map[string]interface{})["id"].(float64) == 0 { + return true } } - } else { - return "" - } + return false + }(), + LastLogin: &now, + RemoteAccess: true, //TODO: add remote access in JWT + Password: GenerateBase64(64), + } - // verify audience - if aud, ok := payload["aud"].(string); ok { - if aud != JWTIssuer { - return "" - } - } else if aud, ok := payload["aud"].([]string); ok { - accepted := false - for _, audItem := range aud { - if audItem == JWTIssuer { - accepted = true - break - } - } - if !accepted { - return "" - } - } else { - return "" - } + // TODO: Add custom function to customize the user before saving + // this function will receive the following parameters: + // payload, *user - // if there is no subject, return empty session - if _, ok := payload["sub"].(string); !ok { - return "" - } + user.Save() - sub := payload["sub"].(string) - user := User{} - Get(&user, "username = ?", sub) + // process entitlements + // TODO: find a way to refresh entitlements every login + } else if user.ID == 0 { + return "" + } - if user.ID == 0 { - return "" - } + session := user.GetActiveSession() + if session == nil && SSOLogin { + session = &Session{ + UserID: user.ID, + Active: true, + LoginTime: time.Now(), + IP: GetRemoteIP(r), + } + session.GenerateKey() - session := user.GetActiveSession() - if session == nil { - return "" - } + // TODO: Add custom function to customize the user session + // this function will receive the following parameters: + // payload, user, *session - // TODO: verify exp + session.Save() + } else if session == nil { + return "" + } - // Verify the signature - alg := "HS256" - if v, ok := header["alg"].(string); ok { - alg = v - } - if _, ok := header["typ"]; ok { - if v, ok := header["typ"].(string); !ok || v != "JWT" { - return "" - } - } - switch alg { - case "HS256": - // TODO: allow third party JWT signature authentication - hash := hmac.New(sha256.New, []byte(JWT+session.Key)) - hash.Write([]byte(jwtParts[0] + "." + jwtParts[1])) - token := hash.Sum(nil) - b64Token := base64.RawURLEncoding.EncodeToString(token) - if b64Token != jwtParts[2] { - return "" - } - default: - // For now, only support HMAC-SHA256 - return "" - } - return session.Key + // TODO: verify exp + + // Verify the signature + alg := "HS256" + if v, ok := header["alg"].(string); ok { + alg = v + } + if _, ok := header["typ"]; ok { + if v, ok := header["typ"].(string); !ok || v != "JWT" { + return "" } } + // verify signature + switch alg { + case "HS256": + // TODO: allow third party JWT signature authentication + hash := hmac.New(sha256.New, []byte(JWT+session.Key)) + hash.Write([]byte(jwtParts[0] + "." + jwtParts[1])) + token := hash.Sum(nil) + b64Token := base64.RawURLEncoding.EncodeToString(token) + if b64Token != jwtParts[2] { + return "" + } + case "RS256": + if !verifyRSA(jwtToken, SSOLogin) { + return "" + } + default: + // For now, only support HMAC-SHA256 + return "" + } + + return session.Key + +} + +var jwtIssuerCerts = map[[2]string][]byte{} + +func getJWTRSAPublicKeySSO(jwtToken *jwt.Token) *rsa.PublicKey { + iss, err := jwtToken.Claims.GetIssuer() + if err != nil { + return nil + } + + kid, _ := jwtToken.Header["kid"].(string) + if kid == "" { + return nil + } + + if val, ok := jwtIssuerCerts[[2]string{iss, kid}]; ok { + cert, _ := jwt.ParseRSAPublicKeyFromPEM(val) + return cert + } + + res, err := http.Get(iss + "/.well-known/openid-configuration") + if err != nil { + return nil + } + + if res.StatusCode != 200 { + return nil + } + + buf, err := io.ReadAll(res.Body) + if err != nil { + return nil + } + + obj := map[string]interface{}{} + err = json.Unmarshal(buf, &obj) + if err != nil { + return nil + } + + crtURL := "" + if val, ok := obj["jwks_uri"].(string); !ok || val == "" { + return nil + } else { + crtURL = val + } + + res, err = http.Get(crtURL) + if err != nil { + return nil + } + + if res.StatusCode != 200 { + return nil + } + + buf, err = io.ReadAll(res.Body) + if err != nil { + return nil + } + + certObj := map[string][]map[string]string{} + err = json.Unmarshal(buf, &certObj) + if err != nil { + return nil + } + + if val, ok := certObj["keys"]; !ok || len(val) == 0 { + return nil + } + + var cert map[string]string + for i := range certObj["keys"] { + if certObj["keys"][i]["kid"] == kid { + cert = certObj["keys"][i] + break + } + } + + if cert == nil { + return nil + } + + N := new(big.Int) + buf, _ = base64.RawURLEncoding.DecodeString(cert["n"]) + N = N.SetBytes(buf) + + E := new(big.Int) + buf, _ = base64.RawURLEncoding.DecodeString(cert["e"]) + E = E.SetBytes(buf) + publicCert := rsa.PublicKey{ + N: N, + E: int(E.Int64()), + } + + return &publicCert +} + +func getJWTRSAPublicKeyLocal(jwtToken *jwt.Token) *rsa.PublicKey { + pubKeyPEM, err := os.ReadFile(".jwt-rsa-public.pem") + if err != nil { + return nil + } + + pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pubKeyPEM) + if err != nil { + return nil + } + + return pubKey +} + +func verifyRSA(token string, SSOLogin bool) bool { + tok, err := jwt.Parse(token, func(jwtToken *jwt.Token) (interface{}, error) { + if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected method: %s", jwtToken.Header["alg"]) + } + + var pubKey *rsa.PublicKey + + if SSOLogin { + pubKey = getJWTRSAPublicKeySSO(jwtToken) + } else { + pubKey = getJWTRSAPublicKeyLocal(jwtToken) + } + + if pubKey == nil { + return nil, fmt.Errorf("Unable to load local public key") + } + + return pubKey, nil + }) + if err != nil { + return false + } + + _, ok := tok.Claims.(jwt.MapClaims) + if !ok || !tok.Valid { + return false + } + + return true +} + +func getSession(r *http.Request) string { + // First, try JWT + if val := getJWT(r); val != "" { + return val + } + + if r.URL.Query().Get("access-token") != "" { + r.Header.Add("Authorization", "Bearer "+r.URL.Query().Get("access-token")) + if val := getJWT(r); val != "" { + return val + } + } + + // Then try session + key, err := r.Cookie("session") + if err == nil && key != nil { + return key.Value + } + if r.Method == "GET" && r.FormValue("session") != "" { + return r.FormValue("session") + } + if r.Method != "GET" { + err := r.ParseMultipartForm(2 << 10) + if err != nil { + r.ParseForm() + } + if r.FormValue("session") != "" { + return r.FormValue("session") + } + } + return "" } @@ -772,6 +1095,12 @@ func GetSchema(r *http.Request) string { func verifyPassword(hash string, plain string) error { password := []byte(plain + Salt) hashedPassword := []byte(hash) + if len(hashedPassword) > 72 { + hashedPassword = hashedPassword[:72] + } + if len(password) > 72 { + password = password[:72] + } return bcrypt.CompareHashAndPassword(hashedPassword, password) } diff --git a/check_csrf.go b/check_csrf.go index 64f9c57c..5036244f 100644 --- a/check_csrf.go +++ b/check_csrf.go @@ -40,13 +40,16 @@ If you you call this API: It will return an error message and the system will create a CRITICAL level log with details about the possible attack. To make the request -work, `x-csrf-token` paramtere should be added. +work, `x-csrf-token` parameter should be added. http://0.0.0.0:8080/myapi/?x-csrf-token=MY_SESSION_KEY Where you replace `MY_SESSION_KEY` with the session key. */ func CheckCSRF(r *http.Request) bool { + if getJWT(r) != "" { + return false + } token := getCSRFToken(r) if token != "" && token == getSession(r) { return false diff --git a/cmd/uadmin/copy.go b/cmd/uadmin/copy.go index a86087b7..dabf2c10 100644 --- a/cmd/uadmin/copy.go +++ b/cmd/uadmin/copy.go @@ -1,7 +1,7 @@ /* The MIT License (MIT) -Copyright (c) 2018 otiai10 +# Copyright (c) 2018 otiai10 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,7 @@ func copy(src, dest string, info os.FileInfo) error { // and file permission. func fcopy(src, dest string) error { - if err := os.MkdirAll(filepath.Dir(dest), os.ModePerm); err != nil { + if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { return err } @@ -86,7 +86,7 @@ func fcopy(src, dest string) error { // and pass everything to "copy" recursively. func dcopy(srcdir, destdir string) error { - if err := os.MkdirAll(destdir, os.FileMode(0744)); err != nil { + if err := os.MkdirAll(destdir, 0755); err != nil { return err } diff --git a/cmd/uadmin/main.go b/cmd/uadmin/main.go index 7be1c3f7..9b75b726 100644 --- a/cmd/uadmin/main.go +++ b/cmd/uadmin/main.go @@ -88,7 +88,7 @@ func main() { for _, v := range folderList { dst = filepath.Join(ex, v) if _, err = os.Stat(dst); os.IsNotExist(err) { - err = os.MkdirAll(dst, os.FileMode(0744)) + err = os.MkdirAll(dst, 0755) if err != nil { uadmin.Trail(uadmin.WARNING, "Unable to create \"%s\" folder: %s", v, err) } else { diff --git a/crop_image_handler_test.go b/crop_image_handler_test.go index eea1541f..5ccfe0a8 100644 --- a/crop_image_handler_test.go +++ b/crop_image_handler_test.go @@ -31,7 +31,7 @@ func (t *UAdminTests) TestCropImageHandler() { } } - os.MkdirAll("./media/user", 0744) + os.MkdirAll("./media/user", 0755) // Save to iamge.png f1, _ := os.OpenFile("./media/user/image_raw.png", os.O_WRONLY|os.O_CREATE, 0600) diff --git a/d_api.go b/d_api.go index be8ba959..537e5d04 100644 --- a/d_api.go +++ b/d_api.go @@ -238,6 +238,9 @@ func dAPIHandler(w http.ResponseWriter, r *http.Request, s *Session) { // Route the request to the correct handler based on the command if command == "read" { // check if there is a prequery + if APIPreQueryReadHandler != nil && !APIPreQueryReadHandler(w, r) { + return + } if preQuery, ok := model.(APIPreQueryReader); ok && !preQuery.APIPreQueryRead(w, r) { } else { dAPIReadHandler(w, r, s) @@ -245,6 +248,10 @@ func dAPIHandler(w http.ResponseWriter, r *http.Request, s *Session) { return } if command == "add" { + // check if there is a prequery + if APIPreQueryAddHandler != nil && !APIPreQueryAddHandler(w, r) { + return + } if preQuery, ok := model.(APIPreQueryAdder); ok && !preQuery.APIPreQueryAdd(w, r) { } else { dAPIAddHandler(w, r, s) @@ -253,6 +260,9 @@ func dAPIHandler(w http.ResponseWriter, r *http.Request, s *Session) { } if command == "edit" { // check if there is a prequery + if APIPreQueryEditHandler != nil && !APIPreQueryEditHandler(w, r) { + return + } if preQuery, ok := model.(APIPreQueryEditor); ok && !preQuery.APIPreQueryEdit(w, r) { } else { dAPIEditHandler(w, r, s) @@ -261,6 +271,9 @@ func dAPIHandler(w http.ResponseWriter, r *http.Request, s *Session) { } if command == "delete" { // check if there is a prequery + if APIPreQueryDeleteHandler != nil && !APIPreQueryDeleteHandler(w, r) { + return + } if preQuery, ok := model.(APIPreQueryDeleter); ok && !preQuery.APIPreQueryDelete(w, r) { } else { dAPIDeleteHandler(w, r, s) diff --git a/d_api_add.go b/d_api_add.go index b5b682bc..e8ef728a 100644 --- a/d_api_add.go +++ b/d_api_add.go @@ -20,6 +20,7 @@ func dAPIAddHandler(w http.ResponseWriter, r *http.Request, s *Session) { // Check CSRF if CheckCSRF(r) { + w.WriteHeader(http.StatusForbidden) ReturnJSON(w, r, map[string]interface{}{ "status": "error", "err_msg": "Failed CSRF protection.", diff --git a/d_api_auth.go b/d_api_auth.go index 6481486c..02c2af2d 100644 --- a/d_api_auth.go +++ b/d_api_auth.go @@ -31,6 +31,10 @@ func dAPIAuthHandler(w http.ResponseWriter, r *http.Request, s *Session) { dAPIResetPasswordHandler(w, r, s) case "changepassword": dAPIChangePasswordHandler(w, r, s) + case "openidlogin": + dAPIOpenIDLoginHandler(w, r, s) + case "certs": + dAPIOpenIDCertHandler(w, r) default: w.WriteHeader(http.StatusNotFound) ReturnJSON(w, r, map[string]interface{}{ diff --git a/d_api_delete.go b/d_api_delete.go index a69676c4..da766790 100644 --- a/d_api_delete.go +++ b/d_api_delete.go @@ -18,6 +18,7 @@ func dAPIDeleteHandler(w http.ResponseWriter, r *http.Request, s *Session) { // Check CSRF if CheckCSRF(r) { + w.WriteHeader(http.StatusForbidden) ReturnJSON(w, r, map[string]interface{}{ "status": "error", "err_msg": "Failed CSRF protection.", diff --git a/d_api_edit.go b/d_api_edit.go index 05db97ff..1cfe012e 100644 --- a/d_api_edit.go +++ b/d_api_edit.go @@ -16,6 +16,7 @@ func dAPIEditHandler(w http.ResponseWriter, r *http.Request, s *Session) { // Check CSRF if CheckCSRF(r) { + w.WriteHeader(http.StatusForbidden) ReturnJSON(w, r, map[string]interface{}{ "status": "error", "err_msg": "Failed CSRF protection.", diff --git a/d_api_helper.go b/d_api_helper.go index caf6005e..9bcb7b59 100644 --- a/d_api_helper.go +++ b/d_api_helper.go @@ -650,6 +650,9 @@ func returnDAPIJSON(w http.ResponseWriter, r *http.Request, a map[string]interfa if model != nil { if command == "read" { + if APIPostQueryReadHandler != nil && !APIPostQueryReadHandler(w, r, a) { + return nil + } if postQuery, ok := model.(APIPostQueryReader); ok { if !postQuery.APIPostQueryRead(w, r, a) { return nil @@ -657,6 +660,9 @@ func returnDAPIJSON(w http.ResponseWriter, r *http.Request, a map[string]interfa } } if command == "add" { + if APIPostQueryAddHandler != nil && !APIPostQueryAddHandler(w, r, a) { + return nil + } if postQuery, ok := model.(APIPostQueryAdder); ok { if !postQuery.APIPostQueryAdd(w, r, a) { return nil @@ -664,6 +670,9 @@ func returnDAPIJSON(w http.ResponseWriter, r *http.Request, a map[string]interfa } } if command == "edit" { + if APIPostQueryEditHandler != nil && !APIPostQueryEditHandler(w, r, a) { + return nil + } if postQuery, ok := model.(APIPostQueryEditor); ok { if !postQuery.APIPostQueryEdit(w, r, a) { return nil @@ -671,6 +680,9 @@ func returnDAPIJSON(w http.ResponseWriter, r *http.Request, a map[string]interfa } } if command == "delete" { + if APIPostQueryDeleteHandler != nil && !APIPostQueryDeleteHandler(w, r, a) { + return nil + } if postQuery, ok := model.(APIPostQueryDeleter); ok { if !postQuery.APIPostQueryDelete(w, r, a) { return nil @@ -684,7 +696,7 @@ func returnDAPIJSON(w http.ResponseWriter, r *http.Request, a map[string]interfa } } } - // if command == "schema" { + // if command == "method" { /* TODO: Add post query for methods if postQuery, ok := model.(APIPostQueryMethoder); ok { diff --git a/d_api_logout.go b/d_api_logout.go index 30eb7506..1dd276c0 100644 --- a/d_api_logout.go +++ b/d_api_logout.go @@ -13,7 +13,7 @@ func dAPILogoutHandler(w http.ResponseWriter, r *http.Request, s *Session) { } if CheckCSRF(r) { - w.WriteHeader(http.StatusUnauthorized) + w.WriteHeader(http.StatusForbidden) ReturnJSON(w, r, map[string]interface{}{ "status": "error", "err_msg": "Missing CSRF token", diff --git a/d_api_method.go b/d_api_method.go index 7f60a544..48d9b18e 100644 --- a/d_api_method.go +++ b/d_api_method.go @@ -24,6 +24,7 @@ func dAPIMethodHandler(w http.ResponseWriter, r *http.Request, s *Session) { } if CheckCSRF(r) { + w.WriteHeader(http.StatusForbidden) ReturnJSON(w, r, map[string]interface{}{ "status": "error", "err_msg": "Failed CSRF protection.", diff --git a/d_api_openid_cert_handler.go b/d_api_openid_cert_handler.go new file mode 100644 index 00000000..9b82e788 --- /dev/null +++ b/d_api_openid_cert_handler.go @@ -0,0 +1,45 @@ +package uadmin + +import ( + "encoding/base64" + "math/big" + "net/http" + "os" + + "github.com/golang-jwt/jwt/v5" +) + +func dAPIOpenIDCertHandler(w http.ResponseWriter, r *http.Request) { + buf, err := os.ReadFile(".jwt-rsa-public.pem") + if err != nil { + w.WriteHeader(404) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Unable to load public certificate", + }) + return + } + cert, err := jwt.ParseRSAPublicKeyFromPEM(buf) + if err != nil { + w.WriteHeader(404) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Unable to parse public certificate", + }) + return + } + obj := map[string][]map[string]string{ + "keys": { + { + "kid": "1", + "use": "sig", + "kty": "RSA", + "alg": "RS256", + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(cert.E)).Bytes()), + "n": base64.RawURLEncoding.EncodeToString(cert.N.Bytes()), + }, + }, + } + + ReturnJSON(w, r, obj) +} diff --git a/d_api_openid_login.go b/d_api_openid_login.go new file mode 100644 index 00000000..fa2e5f1d --- /dev/null +++ b/d_api_openid_login.go @@ -0,0 +1,73 @@ +package uadmin + +import ( + "context" + "net/http" + "strings" +) + +func dAPIOpenIDLoginHandler(w http.ResponseWriter, r *http.Request, s *Session) { + _ = s + redirectURI := r.FormValue("redirect_uri") + + Trail(DEBUG, "HERE") + + if r.Method == "GET" { + if session := IsAuthenticated(r); session != nil { + Preload(session, "User") + c := map[string]interface{}{ + "SiteName": SiteName, + "Language": getLanguage(r), + "RootURL": RootURL, + "Logo": Logo, + "user": s.User, + "OpenIDWebsiteURL": redirectURI, + } + RenderHTML(w, r, "./templates/uadmin/"+Theme+"/openid_concent.html", c) + return + } + + http.Redirect(w, r, RootURL+"login/?next="+RootURL+"api/d/auth/openidlogin?"+r.URL.Query().Encode(), 303) + return + } + + if s == nil { + w.WriteHeader(http.StatusUnauthorized) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Invalid credentials", + }) + return + } + + // Preload the user to get the group name + Preload(&s.User) + + ctx := context.WithValue(r.Context(), CKey("aud"), getAUD(redirectURI)) + r = r.WithContext(ctx) + jwt := createJWT(r, s) + + http.Redirect(w, r, redirectURI+"?access-token="+jwt, 303) + +} + +func getAUD(URL string) string { + aud := "" + + if strings.HasPrefix(URL, "https://") { + aud = "https://" + URL = strings.TrimPrefix(URL, "https://") + } + + if strings.HasPrefix(URL, "http://") { + aud = "http://" + URL = strings.TrimPrefix(URL, "http://") + } + + if strings.Contains(URL, "/") { + URL = URL[:strings.Index(URL, "/")] + aud += URL + } + + return aud +} diff --git a/d_api_read.go b/d_api_read.go index 3392c075..8ab946e5 100644 --- a/d_api_read.go +++ b/d_api_read.go @@ -69,7 +69,7 @@ func dAPIReadHandler(w http.ResponseWriter, r *http.Request, s *Session) { if f != "" { SQL = strings.Replace(SQL, "{FIELDS}", f, -1) } else { - SQL = strings.Replace(SQL, "{FIELDS}", "*", -1) + SQL = strings.Replace(SQL, "{FIELDS}", tableName+".*", -1) } join := getQueryJoin(r, params, tableName) @@ -94,6 +94,12 @@ func dAPIReadHandler(w http.ResponseWriter, r *http.Request, s *Session) { args = append(args, lmArgs...) } } + if r.Context().Value(CKey("WHERE")) != nil { + if q != "" { + q += " AND " + } + q += r.Context().Value(CKey("WHERE")).(string) + } if q != "" { SQL += " WHERE " + q } @@ -228,13 +234,19 @@ func dAPIReadHandler(w http.ResponseWriter, r *http.Request, s *Session) { } else if len(urlParts) == 1 { // Read One m, _ := NewModel(modelName, true) - Get(m.Interface(), "id = ?", urlParts[0]) + q := "id = ?" + if r.Context().Value(CKey("WHERE")) != nil { + q += " AND " + r.Context().Value(CKey("WHERE")).(string) + } + Get(m.Interface(), q, urlParts[0]) rowsCount = 0 var i interface{} if int(GetID(m)) != 0 { i = m.Interface() rowsCount = 1 + } else { + w.WriteHeader(404) } if params["$preload"] == "1" || params["$preload"] == "true" { diff --git a/d_api_reset_password.go b/d_api_reset_password.go index 567e7f02..6d1c6dd9 100644 --- a/d_api_reset_password.go +++ b/d_api_reset_password.go @@ -1,6 +1,7 @@ package uadmin import ( + "fmt" "net/http" "time" ) @@ -52,9 +53,15 @@ func dAPIResetPasswordHandler(w http.ResponseWriter, r *http.Request, s *Session // check if the user exists and active if user.ID == 0 || (user.ExpiresOn != nil && user.ExpiresOn.After(time.Now())) { w.WriteHeader(404) + identifier := "email" + identifierVal := email + if uid != "" { + identifier = "uid" + identifierVal = uid + } ReturnJSON(w, r, map[string]interface{}{ "status": "error", - "err_msg": "email or uid do not match any active user", + "err_msg": fmt.Sprintf("%s: '%s' do not match any active user", identifier, identifierVal), }) // log the request go func() { diff --git a/d_api_signup.go b/d_api_signup.go index 1a03570a..77b29a0b 100644 --- a/d_api_signup.go +++ b/d_api_signup.go @@ -71,6 +71,7 @@ func dAPISignupHandler(w http.ResponseWriter, r *http.Request, s *Session) { "status": "error", "err_msg": "username taken", }) + return } // if the user is active, then login in diff --git a/db.go b/db.go index 8241c466..e0dae6d7 100644 --- a/db.go +++ b/db.go @@ -73,6 +73,10 @@ type DBSettings struct { Timezone string `json:"timezone"` } +type AutoMigrater interface { + AutoMigrate() bool +} + // initializeDB opens the connection the DB func initializeDB(a ...interface{}) { // Open the connection the the DB @@ -80,6 +84,11 @@ func initializeDB(a ...interface{}) { // Migrate schema for i, model := range a { + if autoMigrate, ok := model.(AutoMigrater); ok { + if !autoMigrate.AutoMigrate() { + continue + } + } Trail(INFO, "Initializing DB: [%s%d/%d%s]", colors.FGGreenB, i+1, len(a), colors.FGNormal) err := db.AutoMigrate(model) if err != nil { @@ -238,7 +247,7 @@ func GetDB() *gorm.DB { }) // Check if the error is DB doesn't exist and create it - if err != nil && err.Error() == "Error 1049: Unknown database '"+Database.Name+"'" { + if err != nil && strings.Contains(err.Error(), "Unknown database '"+Database.Name+"'") { err = createDB() if err == nil { @@ -440,6 +449,9 @@ func All(a interface{}) (err error) { // Save saves the object in the database func Save(a interface{}) (err error) { encryptRecord(a) + if Database.Type == "mysql" { + a = fixDates(a) + } TimeMetric("uadmin/db/duration", 1000, func() { err = db.Save(a).Error for fmt.Sprint(err) == "database is locked" { @@ -459,6 +471,30 @@ func Save(a interface{}) (err error) { return nil } +func fixDates(a interface{}) interface{} { + value := reflect.ValueOf(a).Elem() + now := time.Now() + timeType := reflect.TypeOf(now) + timePointerType := reflect.TypeOf(&now) + timeValue := reflect.ValueOf(now) + timePointerValue := reflect.ValueOf(&now) + for i := 0; i < value.NumField(); i++ { + if value.Field(i).Type() == timeType { + if value.Field(i).Interface().(time.Time).IsZero() { + value.Field(i).Set(timeValue) + } + } else if value.Field(i).Type() == timePointerType { + if !value.Field(i).IsNil() { + if value.Field(i).Interface().(*time.Time).IsZero() { + value.Field(i).Set(timePointerValue) + } + } + } + + } + return value.Addr().Interface() +} + func customSave(m interface{}) (err error) { a := m t := reflect.TypeOf(a) diff --git a/generate_translation.go b/generate_translation.go index 7aa860a8..33c4d6c8 100644 --- a/generate_translation.go +++ b/generate_translation.go @@ -65,7 +65,7 @@ func syncCustomTranslation(path string) map[string]int { group := pathParts[0] name := pathParts[1] - os.MkdirAll("./static/i18n/"+group+"/", 0744) + os.MkdirAll("./static/i18n/"+group+"/", 0755) fileName := "./static/i18n/" + group + "/" + name + ".en.json" langMap := map[string]string{} if _, err = os.Stat(fileName); os.IsNotExist(err) { @@ -159,7 +159,7 @@ func syncModelTranslation(m ModelSchema) map[string]int { pkgName = strings.Split(pkgName, ".")[0] // Get the model's original language file - err = os.MkdirAll("./static/i18n/"+pkgName+"/", 0744) + err = os.MkdirAll("./static/i18n/"+pkgName+"/", 0755) if err != nil { Trail(ERROR, "generateTranslation error creating folder (./static/i18n/"+pkgName+"/). %v", err) @@ -168,7 +168,7 @@ func syncModelTranslation(m ModelSchema) map[string]int { fileName := "./static/i18n/" + pkgName + "/" + m.ModelName + ".en.json" - // Check if the fist doesn't exist and create it + // Check if the first doesn't exist and create it if _, err = os.Stat(fileName); os.IsNotExist(err) { buf, _ = json.MarshalIndent(structLang, "", " ") err = ioutil.WriteFile(fileName, buf, 0644) diff --git a/global.go b/global.go index c1209cd5..e923b301 100644 --- a/global.go +++ b/global.go @@ -81,7 +81,7 @@ const cEMAIL = "email" const cM2M = "m2m" // Version number as per Semantic Versioning 2.0.0 (semver.org) -const Version = "0.9.2" +const Version = "0.10.1" // VersionCodeName is the cool name we give to versions with significant changes. // This name should always be a bug's name starting from A-Z them revolving back. @@ -90,7 +90,8 @@ const Version = "0.9.2" // 0.7.0 Catterpiller // 0.8.0 Dragonfly // 0.9.0 Gnat -const VersionCodeName = "Gnat" +// 0.10.0 Gnat +const VersionCodeName = "Housefly" // Public Global Variables @@ -450,6 +451,45 @@ var FullMediaURL = false // MaskPasswordInAPI will replace any password fields with an asterisk mask var MaskPasswordInAPI = true +// APIPreQueryReadHandler is a function that runs before all dAPI read requests +var APIPreQueryReadHandler func(http.ResponseWriter, *http.Request) bool + +// APIPostQueryReadHandler is a function that runs after all dAPI read requests +var APIPostQueryReadHandler func(http.ResponseWriter, *http.Request, map[string]interface{}) bool + +// APIPreQueryAddHandler is a function that runs before all dAPI add requests +var APIPreQueryAddHandler func(http.ResponseWriter, *http.Request) bool + +// APIPostQueryAddHandler is a function that runs after all dAPI add requests +var APIPostQueryAddHandler func(http.ResponseWriter, *http.Request, map[string]interface{}) bool + +// APIPreQueryEditHandler is a function that runs before all dAPI edit requests +var APIPreQueryEditHandler func(http.ResponseWriter, *http.Request) bool + +// APIPostQueryEditHandler is a function that runs after all dAPI edit requests +var APIPostQueryEditHandler func(http.ResponseWriter, *http.Request, map[string]interface{}) bool + +// APIPreQueryDeleteHandler is a function that runs before all dAPI delete requests +var APIPreQueryDeleteHandler func(http.ResponseWriter, *http.Request) bool + +// APIPostQueryDeleteHandler is a function that runs after all dAPI delete requests +var APIPostQueryDeleteHandler func(http.ResponseWriter, *http.Request, map[string]interface{}) bool + +// PreLoginHandler is a function that runs after all dAPI delete requests +var PreLoginHandler func(r *http.Request, username string, password string) + +// PreLoginHandler is a function that runs after all dAPI delete requests +var PostUploadHandler func(filePath string, modelName string, f *F) string + +// CompressJSON is a variable that allows the user to reduce the size of JSON responses +var CompressJSON = false + +// CompressJSON is a variable that allows the user to reduce the size of JSON responses +var RemoveZeroValueJSON = false + +// SSOURL enables SSO using OpenID Connect +var SSOURL = "" + // Private Global Variables // Regex var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") diff --git a/go.mod b/go.mod index 4c2ac87f..f9aa324a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( require ( github.com/boombuler/barcode v1.0.1 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.3.0 // indirect diff --git a/go.sum b/go.sum index 9a1dc4dc..5dea4bf6 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= diff --git a/load_initial_data.go b/load_initial_data.go index d8deb0a0..f81e7dc5 100644 --- a/load_initial_data.go +++ b/load_initial_data.go @@ -103,6 +103,8 @@ func loadInitialData() error { // Save records for i := 0; i < modelArray.Elem().Len(); i++ { + Get(modelArray.Elem().Index(i).Addr().Interface(), "id = ?", GetID(modelArray.Elem().Index(i))) + json.Unmarshal(buf, modelArray.Interface()) err = Save(modelArray.Elem().Index(i).Addr().Interface()) if err != nil { return fmt.Errorf("loadInitialData.Save: Error in %s[%d]. %s", table, i, err) diff --git a/login_handler.go b/login_handler.go index 19053791..f87dadfc 100644 --- a/login_handler.go +++ b/login_handler.go @@ -19,6 +19,7 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { Password string Logo string FavIcon string + SSOURL string } c := Context{} @@ -27,6 +28,15 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { c.Language = getLanguage(r) c.Logo = Logo c.FavIcon = FavIcon + c.SSOURL = SSOURL + + if session := IsAuthenticated(r); session != nil { + session = session.User.GetActiveSession() + SetSessionCookie(w, r, session) + if r.URL.Query().Get("next") != "" { + http.Redirect(w, r, r.URL.Query().Get("next"), 303) + } + } if r.Method == cPOST { if r.FormValue("save") == "Send Request" { diff --git a/main_handler.go b/main_handler.go index e17bef0e..e9c4095a 100644 --- a/main_handler.go +++ b/main_handler.go @@ -80,6 +80,10 @@ func mainHandler(w http.ResponseWriter, r *http.Request) { settingsHandler(w, r, session) return } + if URLParts[0] == "login" { + loginHandler(w, r) + return + } listHandler(w, r, session) return } else if len(URLParts) == 2 { diff --git a/media_handler.go b/media_handler.go index 719e1c38..785cfa5a 100644 --- a/media_handler.go +++ b/media_handler.go @@ -1,7 +1,6 @@ package uadmin import ( - "io" "net/http" "os" "path" @@ -10,23 +9,55 @@ import ( func mediaHandler(w http.ResponseWriter, r *http.Request) { session := IsAuthenticated(r) - if session == nil && !PublicMedia { + token := r.URL.Query().Get("token") + if session == nil && !PublicMedia && token == "" { + w.WriteHeader(401) loginHandler(w, r) return } - r.URL.Path = strings.TrimPrefix(r.URL.Path, "/media/") - file, err := os.Open("./media/" + path.Clean(r.URL.Path)) + // r.URL.Path = strings.TrimPrefix(r.URL.Path, "/media/") + // file, err := os.Open("./media/" + path.Clean(r.URL.Path)) + // if err != nil { + // pageErrorHandler(w, r, session) + // return + // } + // io.Copy(w, file) + // file.Close() + + fName := path.Clean(r.URL.Path) + + if session == nil && !PublicMedia && token != "" { + // this request for a limited request for one resource + if verifyPassword("$2a$12$"+token, fName) != nil { + w.WriteHeader(401) + return + } + } + + f, err := os.Open("." + fName) if err != nil { - pageErrorHandler(w, r, session) + w.WriteHeader(404) return } - io.Copy(w, file) - file.Close() + defer f.Close() + stat, err := os.Stat("." + fName) + if err != nil || stat.IsDir() { + w.WriteHeader(404) + return + } + modTime := stat.ModTime() + if RetainMediaVersions { + w.Header().Add("Cache-Control", "private, max-age=604800") + } else { + w.Header().Add("Cache-Control", "private, max-age=3600") + } + + http.ServeContent(w, r, "."+fName, modTime, f) // Delete the file if exported to excel - if strings.HasPrefix(r.URL.Path, "export/") { - filePart := strings.TrimPrefix(r.URL.Path, "export/") + if strings.HasPrefix(fName, "/media/export/") { + filePart := strings.TrimPrefix(fName, "/media/export/") filePart = path.Clean(filePart) if filePart != "" && !strings.HasSuffix(filePart, "index.html") { os.Remove("./media/export/" + filePart) diff --git a/model.go b/model.go index 7f886063..acae1e8b 100644 --- a/model.go +++ b/model.go @@ -8,5 +8,5 @@ import ( // in any other struct to make it a model for uadmin type Model struct { ID uint `gorm:"primary_key"` - DeletedAt gorm.DeletedAt `sql:"index"` + DeletedAt gorm.DeletedAt `sql:"index" json:"-"` } diff --git a/openid_config_handler.go b/openid_config_handler.go new file mode 100644 index 00000000..f22d2e4a --- /dev/null +++ b/openid_config_handler.go @@ -0,0 +1,50 @@ +package uadmin + +import "net/http" + +func JWTConfigHandler(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "issuer": JWTIssuer, + "authorization_endpoint": JWTIssuer + "/api/d/auth/openidlogin", + "token_endpoint": "", + "userinfo_endpoint": JWTIssuer + "/api/d/auth/userinfo", + "jwks_uri": JWTIssuer + "/api/d/auth/certs", + "scopes_supported": []string{ + "openid", + "email", + "profile", + }, + "response_types_supported": []string{ + "code", + "token", + "id_token", + "code token", + "code id_token", + "token id_token", + "code token id_token", + "none", + }, + "subject_types_supported": []string{ + "public", + }, + "id_token_signing_alg_values_supported": []string{ + "RS256", + }, + "claims_supported": []string{ + "aud", + "email", + "email_verified", + "exp", + "family_name", + "given_name", + "iat", + "iss", + "locale", + "name", + "picture", + "sub", + }, + } + + ReturnJSON(w, r, data) +} diff --git a/otp.go b/otp.go index e170f8a6..13946b72 100644 --- a/otp.go +++ b/otp.go @@ -88,7 +88,7 @@ func generateOTPSeed(digits int, algorithm string, skew uint, period uint, user key, _ := totp.Generate(opts) img, _ := key.Image(250, 250) - os.MkdirAll("./media/otp/", 0744) + os.MkdirAll("./media/otp/", 0755) fName := "./media/otp/" + key.Secret() + ".png" for _, err := os.Stat(fName); os.IsExist(err); { diff --git a/params_to_instance.go b/params_to_instance.go new file mode 100644 index 00000000..dfeaee8c --- /dev/null +++ b/params_to_instance.go @@ -0,0 +1,60 @@ +package uadmin + +// func setParams(params map[string]string, m reflect.Value, schema ModelSchema) (reflect.Value, error) { +// paramMap := map[string]interface{}{} +// for k, v := range params { +// key := k +// if key == "" { +// continue +// } +// if key[0] == '_' { +// key = key[1:] +// } +// f := schema.FieldByColumnName(key) +// if f != nil { +// key = f.Name +// } +// paramMap[key] = v + +// // fix value for numbers +// if f.Type == cNUMBER { +// if strings.HasPrefix(f.TypeName, "float") { +// paramMap[key], _ = strconv.ParseFloat(v, 64) +// } else if strings.HasPrefix(f.TypeName, "uint") { +// paramMap[key], _ = strconv.ParseUint(v, 10, 64) +// } else if strings.HasPrefix(f.TypeName, "int") { +// paramMap[key], _ = strconv.ParseInt(v, 10, 64) +// } +// } else if f.Type == cBOOL { +// if paramMap[key] == "true" || paramMap[key] == "1" { +// paramMap[key] = true +// } else { +// paramMap[key] = false +// } +// } else if f.Type == cLIST { +// paramMap[key], _ = strconv.ParseInt(v, 10, 64) +// } else if f.Type == cDATE { + +// } +// } + +// buf, _ := json.Marshal(params) +// var err error +// if m.Kind() == reflect.Pointer { +// err = json.Unmarshal(buf, m.Interface()) +// } else if m.Kind() == reflect.Struct { +// err = json.Unmarshal(buf, m.Addr().Interface()) +// } + +// return m, err +// } + +// func parseDate(v string) interface{} { +// if v == "" || v == "null" { +// return nil +// } +// dt, err := time.Parse("2006-05-04T15:02:01Z", v) +// if err != nil { +// return dt +// } +// } diff --git a/process_upload.go b/process_upload.go index 73484757..53bb1a99 100644 --- a/process_upload.go +++ b/process_upload.go @@ -65,7 +65,7 @@ func processUpload(r *http.Request, f *F, modelName string, session *Session, s uploadTo = f.UploadTo } if _, err = os.Stat("." + uploadTo); os.IsNotExist(err) { - err = os.MkdirAll("."+uploadTo, os.ModePerm) + err = os.MkdirAll("."+uploadTo, 0755) if err != nil { Trail(ERROR, "processForm.MkdirAll. %s", err) return "" @@ -103,7 +103,7 @@ func processUpload(r *http.Request, f *F, modelName string, session *Session, s // Sanitize the file name fName = pathName + path.Clean(fName) - err = os.MkdirAll(pathName, os.ModePerm) + err = os.MkdirAll(pathName, 0755) if err != nil { Trail(ERROR, "processForm.MkdirAll. unable to create folder for uploaded file. %s", err) return "" @@ -255,5 +255,9 @@ func processUpload(r *http.Request, f *F, modelName string, session *Session, s os.RemoveAll(strings.Join(oldFileParts[0:len(oldFileParts)-1], "/")) } + if PostUploadHandler != nil { + val = PostUploadHandler(val, modelName, f) + } + return val } diff --git a/register.go b/register.go index 9bdb3f03..a90d7bd4 100644 --- a/register.go +++ b/register.go @@ -80,6 +80,17 @@ func Register(m ...interface{}) { // Setup languages initializeLanguage() + // check if trail dashboard menu item is added + if Count([]DashboardMenu{}, "menu_name = ?", "Trail") == 0 { + dashboard := DashboardMenu{ + MenuName: "Trail", + URL: "trail", + Hidden: false, + Cat: "System", + } + Save(&dashboard) + } + // Store models in Model global variable // and initialize the dashboard dashboardMenus := []DashboardMenu{} @@ -142,17 +153,6 @@ func Register(m ...interface{}) { } } - // check if trail dashboard menu item is added - if Count([]DashboardMenu{}, "menu_name = ?", "Trail") == 0 { - dashboard := DashboardMenu{ - MenuName: "Trail", - URL: "trail", - Hidden: false, - Cat: "System", - } - Save(&dashboard) - } - // Check if encrypt key is there or generate it if _, err := os.Stat(".key"); os.IsNotExist(err) && os.Getenv("UADMIN_KEY") == "" { EncryptKey = generateByteArray(32) @@ -382,5 +382,9 @@ func registerHandlers() { http.HandleFunc(RootURL+"api/", Handler(apiHandler)) } + if !DisableDAPIAuth { + http.HandleFunc(RootURL+".well-known/openid-configuration/", Handler(JWTConfigHandler)) + } + handlersRegistered = true } diff --git a/representation.go b/representation.go index 8373bafe..6bf25512 100644 --- a/representation.go +++ b/representation.go @@ -19,7 +19,10 @@ func GetID(m reflect.Value) uint { func GetString(a interface{}) string { str, ok := a.(fmt.Stringer) if ok { - return str.String() + if a != nil { + return str.String() + } + return "" } t := reflect.TypeOf(a) v := reflect.ValueOf(a) diff --git a/server_test.go b/server_test.go index e1de492d..e137c1bd 100644 --- a/server_test.go +++ b/server_test.go @@ -174,7 +174,7 @@ func TestRunner(t *testing.T) { t.Run(dbSetup.Name+"=GroupPermissions", func(t *testing.T) { uTest.TestGroupPermission() }) - t.Run(dbSetup.Name+"=HomeHamdler", func(t *testing.T) { + t.Run(dbSetup.Name+"=HomeHandler", func(t *testing.T) { uTest.TestHomeHandler() }) t.Run(dbSetup.Name+"=Language", func(t *testing.T) { diff --git a/session.go b/session.go index 774ca3ef..b0fd9d2d 100644 --- a/session.go +++ b/session.go @@ -41,6 +41,8 @@ func (s *Session) Save() { Save(s) s.User = u if CacheSessions { + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits if s.Active { Preload(s) cachedSessions[s.Key] = *s @@ -81,6 +83,9 @@ func loadSessions() { if !CacheSessions { return } + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits + sList := []Session{} Filter(&sList, "`active` = ? AND (expires_on IS NULL OR expires_on > ?)", true, time.Now()) cachedSessions = map[string]Session{} diff --git a/setting.go b/setting.go index 0b86a1af..bf5fc771 100644 --- a/setting.go +++ b/setting.go @@ -287,6 +287,10 @@ func (s *Setting) ApplyValue() { Logo = v.(string) case "uAdmin.FavIcon": FavIcon = v.(string) + case "uAdmin.CompressJSON": + CompressJSON = v.(bool) + case "uAdmin.RemoveZeroValueJSON": + RemoveZeroValueJSON = v.(bool) } } @@ -821,6 +825,32 @@ func syncSystemSettings() { DataType: t.File(), Help: "the fav icon that shows on uAdmin UI", }, + { + Name: "Compress JSON", + Value: func(v bool) string { + n := 0 + if v { + n = 1 + } + return fmt.Sprint(n) + }(CompressJSON), + DefaultValue: "0", + DataType: t.Boolean(), + Help: "Compress JSON allows the system to reduce the size of json responses", + }, + { + Name: "Remove Zero Value JSON", + Value: func(v bool) string { + n := 0 + if v { + n = 1 + } + return fmt.Sprint(n) + }(RemoveZeroValueJSON), + DefaultValue: "0", + DataType: t.Boolean(), + Help: "Compress JSON allows the system to reduce the size of json responses", + }, } // Prepare uAdmin Settings diff --git a/setting_handler.go b/setting_handler.go index 20b0a108..e6721849 100644 --- a/setting_handler.go +++ b/setting_handler.go @@ -63,7 +63,7 @@ func settingsHandler(w http.ResponseWriter, r *http.Request, session *Session) { schema, _ := getSchema(s) schema.FieldByName(sParts[1]) - f := F{Name: s.Code, Type: tMap[s.DataType], UploadTo: "/static/settings/"} + f := F{Name: s.Code, Type: tMap[s.DataType], UploadTo: "/media/settings/"} val := processUpload(r, &f, "setting", session, &schema) if val == "" { diff --git a/static_handler.go b/static_handler.go index 02260a9a..1b1ac833 100644 --- a/static_handler.go +++ b/static_handler.go @@ -71,7 +71,7 @@ func StaticHandler(w http.ResponseWriter, r *http.Request) { return } modTime = stat.ModTime() - w.Header().Add("Cache-Control", "private, max-age=3600") + w.Header().Add("Cache-Control", "private, max-age=604800") } else { modTime = time.Now() } diff --git a/templates/uadmin/default/login.html b/templates/uadmin/default/login.html index dca60b1b..081a4d82 100644 --- a/templates/uadmin/default/login.html +++ b/templates/uadmin/default/login.html @@ -95,6 +95,7 @@

{{Tf "uadmin/system" .La Forgot Password + {{if .SSOURL}}SSO Login{{end}}
{{if .ErrExists}}
@@ -150,6 +151,16 @@

+ + +
+ +
+ +
+
+
+ +
+
+
+ +
+
+
+ + + + +

+ Click Continue +

+

+ to login to {{.OpenIDWebsiteURL}} as {{.user.Username}} +

+
+
+
+ +
+ + +
+
+ +
+
+
+
+
+
+ + + + + + + + + + + + \ No newline at end of file diff --git a/upload_image_handler.go b/upload_image_handler.go index 2f103c55..9d05c486 100644 --- a/upload_image_handler.go +++ b/upload_image_handler.go @@ -20,7 +20,7 @@ func UploadImageHandler(w http.ResponseWriter, r *http.Request, session *Session } folderPath = "./media/htmlimages/" + GenerateBase64(24) + "/" } - os.MkdirAll(folderPath, 0744) + os.MkdirAll(folderPath, 0755) fileName := strings.Replace(f.Filename, "/", " ", -1)