package middlewares import ( "bssapp-backend/internal/authz" "bssapp-backend/permissions" "bytes" "database/sql" "encoding/json" "io" "log" "net/http" "strconv" "strings" "sync" "time" "bssapp-backend/auth" "github.com/gorilla/mux" ) /* AuthzGuardV2 - module+action role permission check (mk_sys_role_permissions) - optional scope checks (department / piyasa) via intersection: user_allowed ∩ role_allowed - cache with TTL Expected: - AuthMiddleware runs before this and sets JWT claims in context. - claims should contain RoleID and UserID. */ // ===================================================== // 🔧 CONFIG / CONSTANTS // ===================================================== const ( defaultPermTTL = 60 * time.Second defaultScopeTTL = 30 * time.Second maxBodyRead = 1 << 20 // 1MB ) // ===================================================== // 🧠 CACHE // ===================================================== type cacheItem struct { val any expires time.Time } type ttlCache struct { mu sync.RWMutex ttl time.Duration m map[string]cacheItem } // ===================================================== // 🌍 GLOBAL SCOPE CACHE (for invalidation) // ===================================================== var globalScopeCache *ttlCache func newTTLCache(ttl time.Duration) *ttlCache { return &ttlCache{ ttl: ttl, m: make(map[string]cacheItem), } } func (c *ttlCache) get(key string) (any, bool) { now := time.Now() c.mu.RLock() item, ok := c.m[key] c.mu.RUnlock() if !ok { return nil, false } if now.After(item.expires) { // lazy delete c.mu.Lock() delete(c.m, key) c.mu.Unlock() return nil, false } return item.val, true } func (c *ttlCache) set(key string, val any) { c.mu.Lock() c.m[key] = cacheItem{val: val, expires: time.Now().Add(c.ttl)} c.mu.Unlock() } // ===================================================== // ✅ MAIN MIDDLEWARE // ===================================================== type AuthzV2Options struct { // If true, scope checks are attempted when scope can be extracted. EnableScopeChecks bool // If true, when scope is required but cannot be extracted, deny. // If false, when scope cannot be extracted, scope check is skipped. StrictScope bool // Override TTLs (optional) PermTTL time.Duration ScopeTTL time.Duration // Custom extractors (optional). If nil, built-in extractors are used. ExtractDepartmentCodes func(r *http.Request) []string ExtractPiyasaCodes func(r *http.Request) []string // Decide whether this request should be treated as scope-sensitive. // If nil, built-in heuristic is used. IsScopeSensitive func(module string, r *http.Request) bool } func AuthzGuardV2(pg *sql.DB, module string, action string) func(http.Handler) http.Handler { return AuthzGuardV2WithOptions(pg, module, action, AuthzV2Options{ EnableScopeChecks: true, StrictScope: false, }) } func AuthzGuardV2WithOptions(pg *sql.DB, module string, action string, opt AuthzV2Options) func(http.Handler) http.Handler { permTTL := opt.PermTTL if permTTL <= 0 { permTTL = defaultPermTTL } scopeTTL := opt.ScopeTTL if scopeTTL <= 0 { scopeTTL = defaultScopeTTL } permCache := newTTLCache(permTTL) if globalScopeCache == nil { globalScopeCache = newTTLCache(scopeTTL) } scopeCache := globalScopeCache // default extractors extractDept := opt.ExtractDepartmentCodes if extractDept == nil { extractDept = defaultExtractDepartmentCodes } extractPiy := opt.ExtractPiyasaCodes if extractPiy == nil { extractPiy = defaultExtractPiyasaCodes } isScopeSensitive := opt.IsScopeSensitive if isScopeSensitive == nil { isScopeSensitive = defaultIsScopeSensitive } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // OPTIONS passthrough if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return } claims, ok := auth.GetClaimsFromContext(r.Context()) if !ok || claims == nil { http.Error(w, "unauthorized", 401) return } userID := claims.ID roleCode := claims.RoleCode // ADMIN BYPASS if claims.IsAdmin() { next.ServeHTTP(w, r) return } // resolve role_id from role_code roleID, err := cachedRoleID(pg, permCache, roleCode) if err != nil { log.Println("❌ role resolve error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } // -------------------------------------------------- // 🔐 PERMISSION RESOLUTION (USER > ROLE > DENY) // -------------------------------------------------- permRepo := permissions.NewPermissionRepository(pg) allowed := false resolved := false // karar verildi mi? // -------------------------------------------------- // 1️⃣ USER OVERRIDE (ÖNCELİK) // -------------------------------------------------- overrides, err := permRepo.GetUserOverrides(userID) if err != nil { log.Println("❌ override load error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } for _, o := range overrides { if o.Module == module && o.Action == action { log.Printf( "🔁 USER OVERRIDE → %s:%s = %v", module, action, o.Allowed, ) allowed = o.Allowed resolved = true break } } // -------------------------------------------------- // 2️⃣ ROLE + DEPARTMENT (NEW SYSTEM) // -------------------------------------------------- if !resolved { deptCodes := claims.DepartmentCodes roleDeptAllowed, err := permRepo.ResolvePermissionChain( userID, roleID, deptCodes, module, action, ) if err != nil { log.Println("❌ perm resolve error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } if roleDeptAllowed { log.Printf("🆕 ROLE+DEPT → %s:%s = true", module, action) allowed = true resolved = true } else { log.Printf("🆕 ROLE+DEPT → %s:%s = false (try legacy)", module, action) } } // -------------------------------------------------- // 3️⃣ ROLE ONLY (LEGACY FALLBACK) // -------------------------------------------------- if !resolved { legacyAllowed, err := cachedRolePermission( pg, permCache, roleID, module, action, ) if err != nil { log.Println("❌ legacy perm error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } log.Printf("🕰️ LEGACY ROLE → %s:%s = %v", module, action, legacyAllowed) allowed = legacyAllowed resolved = true } // -------------------------------------------------- // 3️⃣ FINAL DECISION // -------------------------------------------------- if !allowed { log.Printf( "⛔ ACCESS DENIED user=%d %s:%s path=%s", claims.ID, module, action, r.URL.Path, ) http.Error(w, "forbidden", http.StatusForbidden) return } log.Printf( "✅ ACCESS OK user=%d %s:%s %s", claims.ID, module, action, r.URL.Path, ) // -------------------------------------------------- // 4️⃣ OPTIONAL SCOPE CHECKS (FINAL - SECURE) // -------------------------------------------------- if opt.EnableScopeChecks && isScopeSensitive(module, r) { // 🔹 Request’ten gelenler reqDepts := normalizeCodes(extractDept(r)) reqPiy := normalizeCodes(extractPiy(r)) ctx := r.Context() // 🔹 USER PIYASA (DB’den) userPiy, err := authz.GetUserPiyasaCodes(pg, int(userID)) if err != nil { log.Println("❌ piyasa load error:", err) http.Error(w, "forbidden", 403) return } userPiy = normalizeCodes(userPiy) // ------------------------------------------------ // ✅ PIYASA INTERSECTION // ------------------------------------------------ var effectivePiy []string switch { case len(reqPiy) > 0 && len(userPiy) > 0: effectivePiy = intersect(reqPiy, userPiy) case len(reqPiy) > 0 && len(userPiy) == 0: // user piyasa tanımlı değilse request'e güvenme → boş kalsın (StrictScope varsa deny) effectivePiy = nil case len(reqPiy) == 0 && len(userPiy) > 0: effectivePiy = userPiy } if len(reqPiy) > 0 && len(effectivePiy) == 0 { // request piyasa istiyor ama user scope karşılamıyor http.Error(w, "forbidden", http.StatusForbidden) return } // ------------------------------------------------ // ✅ CONTEXT’E YAZ // ------------------------------------------------ if len(reqDepts) > 0 { ctx = authz.WithDeptCodes(ctx, reqDepts) } if len(effectivePiy) > 0 { ctx = authz.WithPiyasaCodes(ctx, effectivePiy) } r = r.WithContext(ctx) // ------------------------------------------------ // ❗ STRICT MODE // ------------------------------------------------ if len(reqDepts) == 0 && len(effectivePiy) == 0 { if opt.StrictScope { http.Error(w, "forbidden", http.StatusForbidden) return } next.ServeHTTP(w, r) return } // ------------------------------------------------ // 🔍 DEPARTMENT CHECK // ------------------------------------------------ if len(reqDepts) > 0 { okDept, err := cachedDeptIntersectionAny( pg, scopeCache, userID, roleID, reqDepts, ) if err != nil { log.Println("❌ dept scope error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } if !okDept { http.Error(w, "forbidden", http.StatusForbidden) return } } // ------------------------------------------------ // 🔍 PIYASA CHECK // ------------------------------------------------ if len(effectivePiy) > 0 { okPiy, err := cachedPiyasaIntersectionAny( pg, scopeCache, userID, roleID, effectivePiy, ) if err != nil { log.Println("❌ piyasa scope error:", err) http.Error(w, "forbidden", http.StatusForbidden) return } if !okPiy { http.Error(w, "forbidden", http.StatusForbidden) return } } } // -------------------------------------------------- // ✅ ALLOW // -------------------------------------------------- next.ServeHTTP(w, r) }) } } // ===================================================== // 🔐 PERMISSION CHECK (mk_sys_role_permissions) // ===================================================== func cachedRolePermission(pg *sql.DB, c *ttlCache, roleID int64, module, action string) (bool, error) { key := "perm|" + itoa(roleID) + "|" + module + "|" + action if v, ok := c.get(key); ok { return v.(bool), nil } var allowed bool err := pg.QueryRow(` SELECT allowed FROM mk_sys_role_permissions WHERE role_id = $1 AND module_code = $2 AND action = $3 `, roleID, module, action).Scan(&allowed) if err == sql.ErrNoRows { c.set(key, false) return false, nil } if err != nil { return false, err } c.set(key, allowed) return allowed, nil } // ===================================================== // 🧩 SCOPE INTERSECTION // user scope ∩ role scope // ===================================================== func cachedDeptIntersectionAny(pg *sql.DB, c *ttlCache, userID, roleID int64, deptCodes []string) (bool, error) { // cache by exact request list (sorted would be ideal; normalizeCodes already stabilizes somewhat) key := "deptAny|" + itoa(userID) + "|" + itoa(roleID) + "|" + strings.Join(deptCodes, ",") if v, ok := c.get(key); ok { return v.(bool), nil } // ANY match: if request wants multiple codes, allow if at least one is in intersection // Intersection query: // user departments: dfusr_dprt -> mk_dprt(code) // role allowed: dfrole_dprt -> mk_dprt(id) OR directly by id okAny := false // We do it in a single query with ANY($3) var dummy int err := pg.QueryRow(` SELECT 1 FROM dfusr_dprt ud JOIN mk_dprt d ON d.id = ud.dprt_id JOIN dfrole_dprt rd ON rd.dprt_id = ud.dprt_id WHERE ud.dfusr_id = $1 AND rd.dfrole_id = $2 AND ud.is_active = true AND rd.is_allowed = true AND d.code = ANY($3) LIMIT 1 `, userID, roleID, pqArray(deptCodes)).Scan(&dummy) if err == sql.ErrNoRows { c.set(key, false) return false, nil } if err != nil { return false, err } okAny = true c.set(key, okAny) return okAny, nil } func cachedPiyasaIntersectionAny(pg *sql.DB, c *ttlCache, userID, roleID int64, piyasaCodes []string) (bool, error) { key := "piyAny|" + itoa(userID) + "|" + itoa(roleID) + "|" + strings.Join(piyasaCodes, ",") if v, ok := c.get(key); ok { return v.(bool), nil } var dummy int err := pg.QueryRow(` SELECT 1 FROM dfusr_piyasa up WHERE up.dfusr_id = $1 AND up.is_allowed = true AND up.piyasa_code = ANY($2) LIMIT 1 `, userID, pqArray(piyasaCodes)).Scan(&dummy) if err == sql.ErrNoRows { c.set(key, false) return false, nil } if err != nil { return false, err } c.set(key, true) return true, nil } // ===================================================== // 🧲 DEFAULT SCOPE DETECTION // ===================================================== // defaultIsScopeSensitive decides whether this module likely needs dept/piyasa checks. // You can tighten/extend later. func defaultIsScopeSensitive(module string, r *http.Request) bool { switch module { case "order", "customer", "report", "finance": return true default: return false } } // ===================================================== // 🔎 DEFAULT EXTRACTORS // ===================================================== // We try to extract scope from: // - query params: dprt, dprt_code, department, department_code, piyasa, piyasa_code, market // - headers: X-Department, X-Piyasa // - json body fields: department / department_code / dprt_code, piyasa / piyasa_code / market_code // (body read is safe: we re-inject the body) func defaultExtractDepartmentCodes(r *http.Request) []string { var out []string // query params for _, k := range []string{"dprt", "dprt_code", "department", "department_code"} { out = append(out, splitCSV(r.URL.Query().Get(k))...) } // headers out = append(out, splitCSV(r.Header.Get("X-Department"))...) // JSON body (if any) out = append(out, extractFromJSONBody(r, []string{ "department", "department_code", "dprt", "dprt_code", })...) return out } func defaultExtractPiyasaCodes(r *http.Request) []string { var out []string for _, k := range []string{"piyasa", "piyasa_code", "market", "market_code"} { out = append(out, splitCSV(r.URL.Query().Get(k))...) } out = append(out, splitCSV(r.Header.Get("X-Piyasa"))...) out = append(out, extractFromJSONBody(r, []string{ "piyasa", "piyasa_code", "market", "market_code", "customer_attribute", })...) return out } func extractFromJSONBody(r *http.Request, keys []string) []string { // Only for methods that might have body switch r.Method { case http.MethodPost, http.MethodPut, http.MethodPatch: default: return nil } // read body (and restore) raw, err := readBodyAndRestore(r, maxBodyRead) if err != nil || len(raw) == 0 { return nil } // try parse object var obj map[string]any if err := json.Unmarshal(raw, &obj); err != nil { return nil } var out []string for _, k := range keys { if v, ok := obj[k]; ok { switch t := v.(type) { case string: out = append(out, splitCSV(t)...) case []any: for _, it := range t { if s, ok := it.(string); ok { out = append(out, splitCSV(s)...) } } } } } return out } func readBodyAndRestore(r *http.Request, limit int64) ([]byte, error) { if r.Body == nil { return nil, nil } // Read with limit raw, err := io.ReadAll(io.LimitReader(r.Body, limit)) if err != nil { return nil, err } // restore r.Body = io.NopCloser(bytes.NewBuffer(raw)) return raw, nil } // ===================================================== // 🧼 HELPERS // ===================================================== func splitCSV(s string) []string { s = strings.TrimSpace(s) if s == "" { return nil } parts := strings.Split(s, ",") out := make([]string, 0, len(parts)) for _, p := range parts { p = strings.TrimSpace(p) if p != "" { out = append(out, p) } } return out } func normalizeCodes(in []string) []string { if len(in) == 0 { return nil } seen := map[string]struct{}{} out := make([]string, 0, len(in)) for _, s := range in { s = strings.ToUpper(strings.TrimSpace(s)) if s == "" { continue } if _, ok := seen[s]; ok { continue } seen[s] = struct{}{} out = append(out, s) } return out } // pqArray: minimal adapter to pass []string as Postgres array. // If you already use lib/pq, replace this with pq.Array. // Here we use a simple JSON array -> Postgres can cast text[] from ARRAY[]? Not directly. // So: YOU SHOULD USE lib/pq in your project. If it's already there, change pqArray() to pq.Array(slice). // // For now, we implement as a driver.Value using "{A,B}" format (Postgres text[] literal). type pgTextArray string func (a pgTextArray) Value() (any, error) { return string(a), nil } func pqArray(ss []string) any { // produce "{A,B,C}" as text[] literal // escape quotes minimally if len(ss) == 0 { return pgTextArray("{}") } var b strings.Builder b.WriteString("{") for i, s := range ss { if i > 0 { b.WriteString(",") } s = strings.ReplaceAll(s, `"`, `\"`) b.WriteString(`"`) b.WriteString(s) b.WriteString(`"`) } b.WriteString("}") return pgTextArray(b.String()) } func itoa(n int64) string { return strconv.FormatInt(n, 10) } // isolate strconv usage without importing it globally in this snippet func strconvFormatInt(n int64) string { // local minimal // NOTE: in real code just import strconv and use strconv.FormatInt(n, 10) if n == 0 { return "0" } neg := n < 0 if neg { n = -n } var buf [32]byte i := len(buf) for n > 0 { i-- buf[i] = byte('0' + (n % 10)) n /= 10 } if neg { i-- buf[i] = '-' } return string(buf[i:]) } // optional: allow passing scope explicitly from handlers via context (advanced use) type scopeKey string // ===================================================== // 🔍 ROLE RESOLVER (code -> id) WITH CACHE // ===================================================== func cachedRoleID(pg *sql.DB, c *ttlCache, roleCode string) (int64, error) { key := "role|" + strings.ToLower(roleCode) if v, ok := c.get(key); ok { return v.(int64), nil } var id int64 err := pg.QueryRow(` SELECT id FROM dfrole WHERE LOWER(code) = LOWER($1) `, roleCode).Scan(&id) if err != nil { return 0, err } c.set(key, id) return id, nil } // ===================================================== // 🧹 CACHE INVALIDATION (ADMIN) // ===================================================== func ClearAuthzScopeCacheForUser(userID int64) { // NOTE: this clears ALL scope cache. // Simple & safe. Optimize later if needed. if globalScopeCache != nil { globalScopeCache.mu.Lock() defer globalScopeCache.mu.Unlock() for k := range globalScopeCache.m { if strings.Contains(k, "|"+itoa(userID)+"|") { delete(globalScopeCache.m, k) } } } } // intersect: A ∩ B func intersect(a, b []string) []string { set := make(map[string]struct{}, len(a)) for _, v := range a { set[v] = struct{}{} } var out []string for _, v := range b { if _, ok := set[v]; ok { out = append(out, v) } } return out } func AuthzGuardByRoute(pg *sql.DB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // ===================================================== // 0️⃣ OPTIONS → PASS (CORS preflight) // ===================================================== if r.Method == http.MethodOptions { next.ServeHTTP(w, r) return } // ===================================================== // 1️⃣ AUTH // ===================================================== claims, ok := auth.GetClaimsFromContext(r.Context()) if !ok || claims == nil { log.Printf( "AUTHZ_BY_ROUTE 401 reason=claims_missing method=%s path=%s", r.Method, r.URL.Path, ) http.Error(w, "unauthorized: token missing or invalid", http.StatusUnauthorized) return } // ===================================================== // 2️⃣ REAL ROUTE TEMPLATE ( /api/users/{id} ) // ===================================================== route := mux.CurrentRoute(r) if route == nil { log.Printf("❌ AUTHZ: route not resolved: %s %s", r.Method, r.URL.Path, ) http.Error(w, "route not resolved", http.StatusForbidden) return } pathTemplate, err := route.GetPathTemplate() if err != nil { log.Printf("❌ AUTHZ: path template error: %v", err) http.Error(w, "route template error", http.StatusForbidden) return } // Password change must be reachable for every authenticated user. // This avoids permission deadlocks during forced first-password flow. if pathTemplate == "/api/password/change" { next.ServeHTTP(w, r) return } // Self permission endpoints are required right after login // to hydrate UI permission state for the authenticated user. switch pathTemplate { case "/api/permissions/routes", "/api/permissions/effective": next.ServeHTTP(w, r) return } // ===================================================== // 3️⃣ ROUTE LOOKUP (path + method) // ===================================================== var module, action string err = pg.QueryRow(` SELECT module_code, action FROM mk_sys_routes WHERE path = $1 AND method = $2 `, pathTemplate, r.Method, ).Scan(&module, &action) if err != nil { log.Printf( "❌ AUTHZ: route not registered: %s %s", r.Method, pathTemplate, ) if pathTemplate == "/api/password/change" { http.Error(w, "password change route permission not found", http.StatusForbidden) return } http.Error(w, "route permission not found", http.StatusForbidden) return } // ===================================================== // 4️⃣ PERMISSION RESOLVE // ===================================================== repo := permissions.NewPermissionRepository(pg) allowed, err := repo.ResolvePermissionChain( int64(claims.ID), int64(claims.RoleID), claims.DepartmentCodes, module, action, ) if err != nil { log.Printf( "❌ AUTHZ: resolve error user=%d %s:%s err=%v", claims.ID, module, action, err, ) if pathTemplate == "/api/password/change" { http.Error(w, "password change permission check failed", http.StatusForbidden) return } http.Error(w, "forbidden", http.StatusForbidden) return } if !allowed { log.Printf( "⛔ AUTHZ: denied user=%d %s:%s", claims.ID, module, action, ) if pathTemplate == "/api/password/change" { http.Error(w, "password change forbidden: permission denied", http.StatusForbidden) return } http.Error(w, "forbidden", http.StatusForbidden) return } // ===================================================== // 5️⃣ PASS // ===================================================== next.ServeHTTP(w, r) }) } }