diff --git a/admin.go b/admin.go index 1f11fb44..ed2a5880 100644 --- a/admin.go +++ b/admin.go @@ -154,7 +154,7 @@ func toSnakeCase(str string) string { // JSONMarshal Generates JSON format from an object func JSONMarshal(v interface{}, safeEncoding bool) ([]byte, error) { // b, err := json.Marshal(v) - b, err := json.MarshalIndent(v, "", " ") + b, err := jsonMarshal(v) if safeEncoding { b = bytes.Replace(b, []byte("\\u003c"), []byte("<"), -1) @@ -164,19 +164,73 @@ func JSONMarshal(v interface{}, safeEncoding bool) ([]byte, error) { return b, err } +func nullZeroValueStructs(record map[string]interface{}) map[string]interface{} { + for k := range record { + switch v := record[k].(type) { + case map[string]interface{}: + if id, ok := v["ID"].(float64); ok && id == 0 { + record[k] = nil + } else if id, ok := v["id"].(float64); ok && id == 0 { + record[k] = nil + } else { + record[k] = nullZeroValueStructs(v) + } + } + } + return record +} + +func removeZeroValueStructs(buf []byte) []byte { + response := map[string]interface{}{} + json.Unmarshal(buf, &response) + if val, ok := response["result"].(map[string]interface{}); ok { + val = nullZeroValueStructs(val) + buf, _ = json.Marshal(val) + return buf + } + if _, ok := response["result"].([]interface{}); !ok { + return buf + } + val := response["result"].([]interface{}) + var record map[string]interface{} + for i := range val { + record = val[i].(map[string]interface{}) + record = nullZeroValueStructs(record) + val[i] = record + } + response["result"] = val + buf, _ = json.Marshal(response) + return buf +} + +func jsonMarshal(v interface{}) ([]byte, error) { + var buf []byte + var err error + if CompressJSON { + buf, err = json.Marshal(v) + if err == nil && RemoveZeroValueJSON { + buf = removeZeroValueStructs(buf) + } + } else { + buf, err = json.MarshalIndent(v, "", " ") + } + + return buf, err +} + // ReturnJSON returns json to the client func ReturnJSON(w http.ResponseWriter, r *http.Request, v interface{}) { // Set content type in header w.Header().Set("Content-Type", "application/json") // Marshal content - b, err := json.MarshalIndent(v, "", " ") + b, err := jsonMarshal(v) if err != nil { response := map[string]interface{}{ "status": "error", "error_msg": fmt.Sprintf("unable to encode JSON. %s", err), } - b, _ = json.MarshalIndent(response, "", " ") + b, _ = jsonMarshal(response) w.Write(b) return } diff --git a/admin_test.go b/admin_test.go index 84bbe5dd..041741dc 100644 --- a/admin_test.go +++ b/admin_test.go @@ -204,12 +204,12 @@ func (t *UAdminTests) TestReturnJSON() { out string }{ {map[string]interface{}{"ID": 1, "Name": "Test"}, `{ - "ID": 1, - "Name": "Test" + "ID": 1, + "Name": "Test" }`}, {math.NaN(), `{ - "error_msg": "unable to encode JSON. json: unsupported value: NaN", - "status": "error" + "error_msg": "unable to encode JSON. json: unsupported value: NaN", + "status": "error" }`}, } diff --git a/auth.go b/auth.go index 7258ac2e..ac9e5120 100644 --- a/auth.go +++ b/auth.go @@ -4,12 +4,17 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" + "io" "math/big" "net" + "os" "path" + "sync" "crypto/hmac" "crypto/rand" + "crypto/rsa" "crypto/sha256" "math" "net/http" @@ -17,6 +22,7 @@ import ( "strings" "time" + "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" ) @@ -36,6 +42,8 @@ var JWT = "" // used to identify the as JWT audience. var JWTIssuer = "" +var JWTAlgo = "HS256" //"RS256" + // AcceptedJWTIssuers is a list of accepted JWT issuers. By default the // local JWTIssuer is accepted. To accept other issuers, add them to // this list @@ -47,6 +55,9 @@ var bcryptDiff = 12 // cachedSessions is variable for keeping active sessions var cachedSessions map[string]Session +// Need to have a lock to protect it from race conditions during concurrent writes. +var cachedSessionsMutex sync.RWMutex + // invalidAttempts keeps track of invalid password attempts // per IP address var invalidAttempts = map[string]int{} @@ -157,15 +168,22 @@ func createJWT(r *http.Request, s *Session) string { if !isValidSession(r, s) { return "" } + alg := JWTAlgo + aud := JWTIssuer + SSO := false + if r.Context().Value(CKey("aud")) != nil { + aud = r.Context().Value(CKey("aud")).(string) + SSO = true + } header := map[string]interface{}{ - "alg": "HS256", + "alg": alg, "typ": "JWT", } payload := map[string]interface{}{ "sub": s.User.Username, "iat": s.LastLogin.Unix(), "iss": JWTIssuer, - "aud": JWTIssuer, + "aud": aud, } if s.ExpiresOn != nil { payload["exp"] = s.ExpiresOn.Unix() @@ -176,16 +194,98 @@ func createJWT(r *http.Request, s *Session) string { payload = CustomJWT(r, s, payload) } - jHeader, _ := json.Marshal(header) - jPayload, _ := json.Marshal(payload) - b64Header := base64.RawURLEncoding.EncodeToString(jHeader) - b64Payload := base64.RawURLEncoding.EncodeToString(jPayload) + // TODO: Add custom handler to customize JWT + // This custom function show have parameters for: + // JWT Object + // SSO boolean + // Algorithm + // User + // *Session + + if alg == "HS256" { + jHeader, _ := json.Marshal(header) + jPayload, _ := json.Marshal(payload) + b64Header := base64.RawURLEncoding.EncodeToString(jHeader) + b64Payload := base64.RawURLEncoding.EncodeToString(jPayload) + + hash := hmac.New(sha256.New, []byte(JWT+s.Key)) + hash.Write([]byte(b64Header + "." + b64Payload)) + signature := hash.Sum(nil) + b64Signature := base64.RawURLEncoding.EncodeToString(signature) + return b64Header + "." + b64Payload + "." + b64Signature + } else if alg == "RS256" { + buf, err := os.ReadFile(".jwt-rsa-private.pem") + if err != nil { + return "" + } + key, err := jwt.ParseRSAPrivateKeyFromPEM(buf) + if err != nil { + return "" + } + + // Customize JWT Data + header["kid"] = "1" + + // Extra customization for SSO + if SSO { + payload["name"] = s.User.String() + payload["given_name"] = s.User.FirstName + payload["family_name"] = s.User.LastName + payload["email"] = s.User.Email + if s.User.Photo != "" { + payload["picture"] = JWTIssuer + strings.TrimSuffix(RootURL, "/") + s.User.Photo + "?token=" + strings.TrimPrefix(hashPass(s.User.Photo), "$2a$12$") + } + + groups := []map[string]interface{}{} + + if s.User.UserGroupID != 0 { + Preload(&s.User, "UserGroup") + groups = append(groups, map[string]interface{}{ + "displayName": s.User.UserGroup.GroupName, + "id": s.User.UserGroupID, + }) + } + if s.User.Admin { + groups = append(groups, map[string]interface{}{ + "displayName": "$admin", + "id": 0, + }) + } + payload["groups"] = groups + + entitlements := []map[string]interface{}{} + for k := range models { + perm := s.User.GetAccess(k) + entitlements = append(entitlements, map[string]interface{}{ + "modelName": k, + "read": perm.Read, + "add": perm.Add, + "edit": perm.Edit, + "delete": perm.Delete, + "approval": perm.Approval, + }) + } + + payload["entitlements"] = entitlements + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(payload)) + + for k, v := range header { + token.Header[k] = v + } + + tokenRaw, err := token.SignedString(key) + + if err != nil { + return "" + } + return tokenRaw + } else { + Trail(ERROR, "Unknown algorithm for JWT (%s)", alg) + return "" + } - hash := hmac.New(sha256.New, []byte(JWT+s.Key)) - hash.Write([]byte(b64Header + "." + b64Payload)) - signature := hash.Sum(nil) - b64Signature := base64.RawURLEncoding.EncodeToString(signature) - return b64Header + "." + b64Payload + "." + b64Signature } func isValidSession(r *http.Request, s *Session) bool { @@ -360,6 +460,8 @@ func Logout(r *http.Request) { // Delete the cookie from memory if we sessions are cached if CacheSessions { + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits delete(cachedSessions, s.Key) } @@ -565,6 +667,8 @@ func getNetSize(r *http.Request, net string) int { func getSessionByKey(key string) *Session { s := Session{} if CacheSessions { + cachedSessionsMutex.RLock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.RUnlock() // Ensure the mutex is unlocked when the function exits s = cachedSessions[key] } else { Get(&s, "`key` = ?", key) @@ -575,140 +679,353 @@ func getSessionByKey(key string) *Session { return &s } -func getSession(r *http.Request) string { - key, err := r.Cookie("session") - if err == nil && key != nil { - return key.Value +func getJWT(r *http.Request) string { + // JWT + if r.Header.Get("Authorization") == "" { + return "" } - if r.Method == "GET" && r.FormValue("session") != "" { - return r.FormValue("session") + if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") { + return "" } - if r.Method != "GET" { - err := r.ParseMultipartForm(2 << 10) - if err != nil { - r.ParseForm() - } - if r.FormValue("session") != "" { - return r.FormValue("session") - } + + jwtToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + jwtParts := strings.Split(jwtToken, ".") + + if len(jwtParts) != 3 { + return "" } - // JWT - if r.Header.Get("Authorization") != "" { - if strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") { - jwt := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - jwtParts := strings.Split(jwt, ".") - if len(jwtParts) != 3 { - return "" - } + jHeader, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[0]) + if err != nil { + return "" + } + jPayload, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1]) + if err != nil { + return "" + } - jHeader, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[0]) - if err != nil { - return "" + header := map[string]interface{}{} + err = json.Unmarshal(jHeader, &header) + if err != nil { + return "" + } + + // Get data from payload + payload := map[string]interface{}{} + err = json.Unmarshal(jPayload, &payload) + if err != nil { + return "" + } + + // Verify issuer + SSOLogin := false + if iss, ok := payload["iss"].(string); ok { + if iss != JWTIssuer { + accepted := false + for _, fiss := range AcceptedJWTIssuers { + if fiss == iss { + accepted = true + break + } } - jPayload, err := base64.RawURLEncoding.WithPadding(base64.NoPadding).DecodeString(jwtParts[1]) - if err != nil { + if !accepted { return "" } + SSOLogin = true + } + } else { + return "" + } - header := map[string]interface{}{} - err = json.Unmarshal(jHeader, &header) - if err != nil { - return "" + // verify audience + if aud, ok := payload["aud"].(string); ok { + if aud != JWTIssuer { + return "" + } + } else if aud, ok := payload["aud"].([]string); ok { + accepted := false + for _, audItem := range aud { + if audItem == JWTIssuer { + accepted = true + break } + } + if !accepted { + return "" + } + } else { + return "" + } - // Get data from payload - payload := map[string]interface{}{} - err = json.Unmarshal(jPayload, &payload) - if err != nil { - return "" - } + // if there is no subject, return empty session + if _, ok := payload["sub"].(string); !ok { + return "" + } - // Verify issuer - if iss, ok := payload["iss"].(string); ok { - if iss != JWTIssuer { - accepted := false - for _, fiss := range AcceptedJWTIssuers { - if fiss == iss { - accepted = true - break - } - } - if !accepted { - return "" + sub := payload["sub"].(string) + user := User{} + Get(&user, "username = ?", sub) + + if user.ID == 0 && SSOLogin { + now := time.Now() + user := User{ + Username: sub, + FirstName: payload["given_name"].(string), + LastName: payload["family_name"].(string), + Active: true, + Admin: func() bool { + for _, group := range payload["groups"].([]interface{}) { + if group.(map[string]interface{})["id"].(float64) == 0 { + return true } } - } else { - return "" - } + return false + }(), + LastLogin: &now, + RemoteAccess: true, //TODO: add remote access in JWT + Password: GenerateBase64(64), + } - // verify audience - if aud, ok := payload["aud"].(string); ok { - if aud != JWTIssuer { - return "" - } - } else if aud, ok := payload["aud"].([]string); ok { - accepted := false - for _, audItem := range aud { - if audItem == JWTIssuer { - accepted = true - break - } - } - if !accepted { - return "" - } - } else { - return "" - } + // TODO: Add custom function to customize the user before saving + // this function will receive the following parameters: + // payload, *user - // if there is no subject, return empty session - if _, ok := payload["sub"].(string); !ok { - return "" - } + user.Save() - sub := payload["sub"].(string) - user := User{} - Get(&user, "username = ?", sub) + // process entitlements + // TODO: find a way to refresh entitlements every login + } else if user.ID == 0 { + return "" + } - if user.ID == 0 { - return "" - } + session := user.GetActiveSession() + if session == nil && SSOLogin { + session = &Session{ + UserID: user.ID, + Active: true, + LoginTime: time.Now(), + IP: GetRemoteIP(r), + } + session.GenerateKey() - session := user.GetActiveSession() - if session == nil { - return "" - } + // TODO: Add custom function to customize the user session + // this function will receive the following parameters: + // payload, user, *session - // TODO: verify exp + session.Save() + } else if session == nil { + return "" + } - // Verify the signature - alg := "HS256" - if v, ok := header["alg"].(string); ok { - alg = v - } - if _, ok := header["typ"]; ok { - if v, ok := header["typ"].(string); !ok || v != "JWT" { - return "" - } - } - switch alg { - case "HS256": - // TODO: allow third party JWT signature authentication - hash := hmac.New(sha256.New, []byte(JWT+session.Key)) - hash.Write([]byte(jwtParts[0] + "." + jwtParts[1])) - token := hash.Sum(nil) - b64Token := base64.RawURLEncoding.EncodeToString(token) - if b64Token != jwtParts[2] { - return "" - } - default: - // For now, only support HMAC-SHA256 - return "" - } - return session.Key + // TODO: verify exp + + // Verify the signature + alg := "HS256" + if v, ok := header["alg"].(string); ok { + alg = v + } + if _, ok := header["typ"]; ok { + if v, ok := header["typ"].(string); !ok || v != "JWT" { + return "" } } + // verify signature + switch alg { + case "HS256": + // TODO: allow third party JWT signature authentication + hash := hmac.New(sha256.New, []byte(JWT+session.Key)) + hash.Write([]byte(jwtParts[0] + "." + jwtParts[1])) + token := hash.Sum(nil) + b64Token := base64.RawURLEncoding.EncodeToString(token) + if b64Token != jwtParts[2] { + return "" + } + case "RS256": + if !verifyRSA(jwtToken, SSOLogin) { + return "" + } + default: + // For now, only support HMAC-SHA256 + return "" + } + + return session.Key + +} + +var jwtIssuerCerts = map[[2]string][]byte{} + +func getJWTRSAPublicKeySSO(jwtToken *jwt.Token) *rsa.PublicKey { + iss, err := jwtToken.Claims.GetIssuer() + if err != nil { + return nil + } + + kid, _ := jwtToken.Header["kid"].(string) + if kid == "" { + return nil + } + + if val, ok := jwtIssuerCerts[[2]string{iss, kid}]; ok { + cert, _ := jwt.ParseRSAPublicKeyFromPEM(val) + return cert + } + + res, err := http.Get(iss + "/.well-known/openid-configuration") + if err != nil { + return nil + } + + if res.StatusCode != 200 { + return nil + } + + buf, err := io.ReadAll(res.Body) + if err != nil { + return nil + } + + obj := map[string]interface{}{} + err = json.Unmarshal(buf, &obj) + if err != nil { + return nil + } + + crtURL := "" + if val, ok := obj["jwks_uri"].(string); !ok || val == "" { + return nil + } else { + crtURL = val + } + + res, err = http.Get(crtURL) + if err != nil { + return nil + } + + if res.StatusCode != 200 { + return nil + } + + buf, err = io.ReadAll(res.Body) + if err != nil { + return nil + } + + certObj := map[string][]map[string]string{} + err = json.Unmarshal(buf, &certObj) + if err != nil { + return nil + } + + if val, ok := certObj["keys"]; !ok || len(val) == 0 { + return nil + } + + var cert map[string]string + for i := range certObj["keys"] { + if certObj["keys"][i]["kid"] == kid { + cert = certObj["keys"][i] + break + } + } + + if cert == nil { + return nil + } + + N := new(big.Int) + buf, _ = base64.RawURLEncoding.DecodeString(cert["n"]) + N = N.SetBytes(buf) + + E := new(big.Int) + buf, _ = base64.RawURLEncoding.DecodeString(cert["e"]) + E = E.SetBytes(buf) + publicCert := rsa.PublicKey{ + N: N, + E: int(E.Int64()), + } + + return &publicCert +} + +func getJWTRSAPublicKeyLocal(jwtToken *jwt.Token) *rsa.PublicKey { + pubKeyPEM, err := os.ReadFile(".jwt-rsa-public.pem") + if err != nil { + return nil + } + + pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pubKeyPEM) + if err != nil { + return nil + } + + return pubKey +} + +func verifyRSA(token string, SSOLogin bool) bool { + tok, err := jwt.Parse(token, func(jwtToken *jwt.Token) (interface{}, error) { + if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected method: %s", jwtToken.Header["alg"]) + } + + var pubKey *rsa.PublicKey + + if SSOLogin { + pubKey = getJWTRSAPublicKeySSO(jwtToken) + } else { + pubKey = getJWTRSAPublicKeyLocal(jwtToken) + } + + if pubKey == nil { + return nil, fmt.Errorf("Unable to load local public key") + } + + return pubKey, nil + }) + if err != nil { + return false + } + + _, ok := tok.Claims.(jwt.MapClaims) + if !ok || !tok.Valid { + return false + } + + return true +} + +func getSession(r *http.Request) string { + // First, try JWT + if val := getJWT(r); val != "" { + return val + } + + if r.URL.Query().Get("access-token") != "" { + r.Header.Add("Authorization", "Bearer "+r.URL.Query().Get("access-token")) + if val := getJWT(r); val != "" { + return val + } + } + + // Then try session + key, err := r.Cookie("session") + if err == nil && key != nil { + return key.Value + } + if r.Method == "GET" && r.FormValue("session") != "" { + return r.FormValue("session") + } + if r.Method != "GET" { + err := r.ParseMultipartForm(2 << 10) + if err != nil { + r.ParseForm() + } + if r.FormValue("session") != "" { + return r.FormValue("session") + } + } + return "" } diff --git a/check_csrf.go b/check_csrf.go index 64f9c57c..5036244f 100644 --- a/check_csrf.go +++ b/check_csrf.go @@ -40,13 +40,16 @@ If you you call this API: It will return an error message and the system will create a CRITICAL level log with details about the possible attack. To make the request -work, `x-csrf-token` paramtere should be added. +work, `x-csrf-token` parameter should be added. http://0.0.0.0:8080/myapi/?x-csrf-token=MY_SESSION_KEY Where you replace `MY_SESSION_KEY` with the session key. */ func CheckCSRF(r *http.Request) bool { + if getJWT(r) != "" { + return false + } token := getCSRFToken(r) if token != "" && token == getSession(r) { return false diff --git a/d_api_auth.go b/d_api_auth.go index 6481486c..02c2af2d 100644 --- a/d_api_auth.go +++ b/d_api_auth.go @@ -31,6 +31,10 @@ func dAPIAuthHandler(w http.ResponseWriter, r *http.Request, s *Session) { dAPIResetPasswordHandler(w, r, s) case "changepassword": dAPIChangePasswordHandler(w, r, s) + case "openidlogin": + dAPIOpenIDLoginHandler(w, r, s) + case "certs": + dAPIOpenIDCertHandler(w, r) default: w.WriteHeader(http.StatusNotFound) ReturnJSON(w, r, map[string]interface{}{ diff --git a/d_api_openid_cert_handler.go b/d_api_openid_cert_handler.go new file mode 100644 index 00000000..9b82e788 --- /dev/null +++ b/d_api_openid_cert_handler.go @@ -0,0 +1,45 @@ +package uadmin + +import ( + "encoding/base64" + "math/big" + "net/http" + "os" + + "github.com/golang-jwt/jwt/v5" +) + +func dAPIOpenIDCertHandler(w http.ResponseWriter, r *http.Request) { + buf, err := os.ReadFile(".jwt-rsa-public.pem") + if err != nil { + w.WriteHeader(404) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Unable to load public certificate", + }) + return + } + cert, err := jwt.ParseRSAPublicKeyFromPEM(buf) + if err != nil { + w.WriteHeader(404) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Unable to parse public certificate", + }) + return + } + obj := map[string][]map[string]string{ + "keys": { + { + "kid": "1", + "use": "sig", + "kty": "RSA", + "alg": "RS256", + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(cert.E)).Bytes()), + "n": base64.RawURLEncoding.EncodeToString(cert.N.Bytes()), + }, + }, + } + + ReturnJSON(w, r, obj) +} diff --git a/d_api_openid_login.go b/d_api_openid_login.go new file mode 100644 index 00000000..fa2e5f1d --- /dev/null +++ b/d_api_openid_login.go @@ -0,0 +1,73 @@ +package uadmin + +import ( + "context" + "net/http" + "strings" +) + +func dAPIOpenIDLoginHandler(w http.ResponseWriter, r *http.Request, s *Session) { + _ = s + redirectURI := r.FormValue("redirect_uri") + + Trail(DEBUG, "HERE") + + if r.Method == "GET" { + if session := IsAuthenticated(r); session != nil { + Preload(session, "User") + c := map[string]interface{}{ + "SiteName": SiteName, + "Language": getLanguage(r), + "RootURL": RootURL, + "Logo": Logo, + "user": s.User, + "OpenIDWebsiteURL": redirectURI, + } + RenderHTML(w, r, "./templates/uadmin/"+Theme+"/openid_concent.html", c) + return + } + + http.Redirect(w, r, RootURL+"login/?next="+RootURL+"api/d/auth/openidlogin?"+r.URL.Query().Encode(), 303) + return + } + + if s == nil { + w.WriteHeader(http.StatusUnauthorized) + ReturnJSON(w, r, map[string]interface{}{ + "status": "error", + "err_msg": "Invalid credentials", + }) + return + } + + // Preload the user to get the group name + Preload(&s.User) + + ctx := context.WithValue(r.Context(), CKey("aud"), getAUD(redirectURI)) + r = r.WithContext(ctx) + jwt := createJWT(r, s) + + http.Redirect(w, r, redirectURI+"?access-token="+jwt, 303) + +} + +func getAUD(URL string) string { + aud := "" + + if strings.HasPrefix(URL, "https://") { + aud = "https://" + URL = strings.TrimPrefix(URL, "https://") + } + + if strings.HasPrefix(URL, "http://") { + aud = "http://" + URL = strings.TrimPrefix(URL, "http://") + } + + if strings.Contains(URL, "/") { + URL = URL[:strings.Index(URL, "/")] + aud += URL + } + + return aud +} diff --git a/d_api_read.go b/d_api_read.go index cb357a64..8ab946e5 100644 --- a/d_api_read.go +++ b/d_api_read.go @@ -245,6 +245,8 @@ func dAPIReadHandler(w http.ResponseWriter, r *http.Request, s *Session) { if int(GetID(m)) != 0 { i = m.Interface() rowsCount = 1 + } else { + w.WriteHeader(404) } if params["$preload"] == "1" || params["$preload"] == "true" { diff --git a/d_api_signup.go b/d_api_signup.go index 1a03570a..77b29a0b 100644 --- a/d_api_signup.go +++ b/d_api_signup.go @@ -71,6 +71,7 @@ func dAPISignupHandler(w http.ResponseWriter, r *http.Request, s *Session) { "status": "error", "err_msg": "username taken", }) + return } // if the user is active, then login in diff --git a/generate_translation.go b/generate_translation.go index cd08abd2..33c4d6c8 100644 --- a/generate_translation.go +++ b/generate_translation.go @@ -168,7 +168,7 @@ func syncModelTranslation(m ModelSchema) map[string]int { fileName := "./static/i18n/" + pkgName + "/" + m.ModelName + ".en.json" - // Check if the fist doesn't exist and create it + // Check if the first doesn't exist and create it if _, err = os.Stat(fileName); os.IsNotExist(err) { buf, _ = json.MarshalIndent(structLang, "", " ") err = ioutil.WriteFile(fileName, buf, 0644) diff --git a/global.go b/global.go index e2a9b377..e923b301 100644 --- a/global.go +++ b/global.go @@ -81,7 +81,7 @@ const cEMAIL = "email" const cM2M = "m2m" // Version number as per Semantic Versioning 2.0.0 (semver.org) -const Version = "0.10.0" +const Version = "0.10.1" // VersionCodeName is the cool name we give to versions with significant changes. // This name should always be a bug's name starting from A-Z them revolving back. @@ -481,6 +481,15 @@ var PreLoginHandler func(r *http.Request, username string, password string) // PreLoginHandler is a function that runs after all dAPI delete requests var PostUploadHandler func(filePath string, modelName string, f *F) string +// CompressJSON is a variable that allows the user to reduce the size of JSON responses +var CompressJSON = false + +// CompressJSON is a variable that allows the user to reduce the size of JSON responses +var RemoveZeroValueJSON = false + +// SSOURL enables SSO using OpenID Connect +var SSOURL = "" + // Private Global Variables // Regex var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") diff --git a/go.mod b/go.mod index 4c2ac87f..f9aa324a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( require ( github.com/boombuler/barcode v1.0.1 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.3.0 // indirect diff --git a/go.sum b/go.sum index 9a1dc4dc..5dea4bf6 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= diff --git a/login_handler.go b/login_handler.go index 19053791..f87dadfc 100644 --- a/login_handler.go +++ b/login_handler.go @@ -19,6 +19,7 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { Password string Logo string FavIcon string + SSOURL string } c := Context{} @@ -27,6 +28,15 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { c.Language = getLanguage(r) c.Logo = Logo c.FavIcon = FavIcon + c.SSOURL = SSOURL + + if session := IsAuthenticated(r); session != nil { + session = session.User.GetActiveSession() + SetSessionCookie(w, r, session) + if r.URL.Query().Get("next") != "" { + http.Redirect(w, r, r.URL.Query().Get("next"), 303) + } + } if r.Method == cPOST { if r.FormValue("save") == "Send Request" { diff --git a/main_handler.go b/main_handler.go index e17bef0e..e9c4095a 100644 --- a/main_handler.go +++ b/main_handler.go @@ -80,6 +80,10 @@ func mainHandler(w http.ResponseWriter, r *http.Request) { settingsHandler(w, r, session) return } + if URLParts[0] == "login" { + loginHandler(w, r) + return + } listHandler(w, r, session) return } else if len(URLParts) == 2 { diff --git a/media_handler.go b/media_handler.go index b2078e9a..785cfa5a 100644 --- a/media_handler.go +++ b/media_handler.go @@ -9,7 +9,9 @@ import ( func mediaHandler(w http.ResponseWriter, r *http.Request) { session := IsAuthenticated(r) - if session == nil && !PublicMedia { + token := r.URL.Query().Get("token") + if session == nil && !PublicMedia && token == "" { + w.WriteHeader(401) loginHandler(w, r) return } @@ -25,6 +27,14 @@ func mediaHandler(w http.ResponseWriter, r *http.Request) { fName := path.Clean(r.URL.Path) + if session == nil && !PublicMedia && token != "" { + // this request for a limited request for one resource + if verifyPassword("$2a$12$"+token, fName) != nil { + w.WriteHeader(401) + return + } + } + f, err := os.Open("." + fName) if err != nil { w.WriteHeader(404) diff --git a/model.go b/model.go index 7f886063..acae1e8b 100644 --- a/model.go +++ b/model.go @@ -8,5 +8,5 @@ import ( // in any other struct to make it a model for uadmin type Model struct { ID uint `gorm:"primary_key"` - DeletedAt gorm.DeletedAt `sql:"index"` + DeletedAt gorm.DeletedAt `sql:"index" json:"-"` } diff --git a/openid_config_handler.go b/openid_config_handler.go new file mode 100644 index 00000000..f22d2e4a --- /dev/null +++ b/openid_config_handler.go @@ -0,0 +1,50 @@ +package uadmin + +import "net/http" + +func JWTConfigHandler(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "issuer": JWTIssuer, + "authorization_endpoint": JWTIssuer + "/api/d/auth/openidlogin", + "token_endpoint": "", + "userinfo_endpoint": JWTIssuer + "/api/d/auth/userinfo", + "jwks_uri": JWTIssuer + "/api/d/auth/certs", + "scopes_supported": []string{ + "openid", + "email", + "profile", + }, + "response_types_supported": []string{ + "code", + "token", + "id_token", + "code token", + "code id_token", + "token id_token", + "code token id_token", + "none", + }, + "subject_types_supported": []string{ + "public", + }, + "id_token_signing_alg_values_supported": []string{ + "RS256", + }, + "claims_supported": []string{ + "aud", + "email", + "email_verified", + "exp", + "family_name", + "given_name", + "iat", + "iss", + "locale", + "name", + "picture", + "sub", + }, + } + + ReturnJSON(w, r, data) +} diff --git a/register.go b/register.go index fa7586c1..a90d7bd4 100644 --- a/register.go +++ b/register.go @@ -382,5 +382,9 @@ func registerHandlers() { http.HandleFunc(RootURL+"api/", Handler(apiHandler)) } + if !DisableDAPIAuth { + http.HandleFunc(RootURL+".well-known/openid-configuration/", Handler(JWTConfigHandler)) + } + handlersRegistered = true } diff --git a/representation.go b/representation.go index 8373bafe..6bf25512 100644 --- a/representation.go +++ b/representation.go @@ -19,7 +19,10 @@ func GetID(m reflect.Value) uint { func GetString(a interface{}) string { str, ok := a.(fmt.Stringer) if ok { - return str.String() + if a != nil { + return str.String() + } + return "" } t := reflect.TypeOf(a) v := reflect.ValueOf(a) diff --git a/session.go b/session.go index 774ca3ef..b0fd9d2d 100644 --- a/session.go +++ b/session.go @@ -41,6 +41,8 @@ func (s *Session) Save() { Save(s) s.User = u if CacheSessions { + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits if s.Active { Preload(s) cachedSessions[s.Key] = *s @@ -81,6 +83,9 @@ func loadSessions() { if !CacheSessions { return } + cachedSessionsMutex.Lock() // Lock the mutex in order to protect from concurrent writes + defer cachedSessionsMutex.Unlock() // Ensure the mutex is unlocked when the function exits + sList := []Session{} Filter(&sList, "`active` = ? AND (expires_on IS NULL OR expires_on > ?)", true, time.Now()) cachedSessions = map[string]Session{} diff --git a/setting.go b/setting.go index 0b86a1af..bf5fc771 100644 --- a/setting.go +++ b/setting.go @@ -287,6 +287,10 @@ func (s *Setting) ApplyValue() { Logo = v.(string) case "uAdmin.FavIcon": FavIcon = v.(string) + case "uAdmin.CompressJSON": + CompressJSON = v.(bool) + case "uAdmin.RemoveZeroValueJSON": + RemoveZeroValueJSON = v.(bool) } } @@ -821,6 +825,32 @@ func syncSystemSettings() { DataType: t.File(), Help: "the fav icon that shows on uAdmin UI", }, + { + Name: "Compress JSON", + Value: func(v bool) string { + n := 0 + if v { + n = 1 + } + return fmt.Sprint(n) + }(CompressJSON), + DefaultValue: "0", + DataType: t.Boolean(), + Help: "Compress JSON allows the system to reduce the size of json responses", + }, + { + Name: "Remove Zero Value JSON", + Value: func(v bool) string { + n := 0 + if v { + n = 1 + } + return fmt.Sprint(n) + }(RemoveZeroValueJSON), + DefaultValue: "0", + DataType: t.Boolean(), + Help: "Compress JSON allows the system to reduce the size of json responses", + }, } // Prepare uAdmin Settings diff --git a/templates/uadmin/default/login.html b/templates/uadmin/default/login.html index dca60b1b..081a4d82 100644 --- a/templates/uadmin/default/login.html +++ b/templates/uadmin/default/login.html @@ -95,6 +95,7 @@