CreatePartition(
+ HttpContext context,
+ RateLimitPolicyDefinition policy
+ )
{
- switch (policy.PolicyType)
+ var key = ResolvePartitionKey(context, policy.Target);
+
+ return policy.PolicyType switch
{
- case RateLimitPolicyType.FixedWindow:
- {
- var limiterOptions = new FixedWindowRateLimiterOptions
+ RateLimitPolicyType.FixedWindow => RateLimitPartition.GetFixedWindowLimiter(
+ key,
+ _ => new FixedWindowRateLimiterOptions
{
PermitLimit = policy.PermitLimit,
Window = policy.Window,
QueueLimit = policy.QueueLimit,
QueueProcessingOrder = QueueProcessingOrder.OldestFirst,
- };
- options.AddPolicy(
- policy.Name,
- context =>
- RateLimitPartition.GetFixedWindowLimiter(
- ResolvePartitionKey(context, policy.Target),
- _ => limiterOptions
- )
- );
- break;
- }
- case RateLimitPolicyType.SlidingWindow:
- {
- var limiterOptions = new SlidingWindowRateLimiterOptions
+ }
+ ),
+ RateLimitPolicyType.SlidingWindow => RateLimitPartition.GetSlidingWindowLimiter(
+ key,
+ _ => new SlidingWindowRateLimiterOptions
{
PermitLimit = policy.PermitLimit,
Window = policy.Window,
SegmentsPerWindow = policy.SegmentsPerWindow,
QueueLimit = policy.QueueLimit,
QueueProcessingOrder = QueueProcessingOrder.OldestFirst,
- };
- options.AddPolicy(
- policy.Name,
- context =>
- RateLimitPartition.GetSlidingWindowLimiter(
- ResolvePartitionKey(context, policy.Target),
- _ => limiterOptions
- )
- );
- break;
- }
- case RateLimitPolicyType.TokenBucket:
- {
- var limiterOptions = new TokenBucketRateLimiterOptions
+ }
+ ),
+ RateLimitPolicyType.TokenBucket => RateLimitPartition.GetTokenBucketLimiter(
+ key,
+ _ => new TokenBucketRateLimiterOptions
{
TokenLimit = policy.TokenLimit,
TokensPerPeriod = policy.TokensPerPeriod,
ReplenishmentPeriod = policy.ReplenishmentPeriod,
QueueLimit = policy.QueueLimit,
QueueProcessingOrder = QueueProcessingOrder.OldestFirst,
- };
- options.AddPolicy(
- policy.Name,
- context =>
- RateLimitPartition.GetTokenBucketLimiter(
- ResolvePartitionKey(context, policy.Target),
- _ => limiterOptions
- )
- );
- break;
- }
- default:
- options.AddPolicy(
- policy.Name,
- context =>
- RateLimitPartition.GetNoLimiter(ResolvePartitionKey(context, policy.Target))
- );
- break;
- }
+ }
+ ),
+ _ => RateLimitPartition.GetNoLimiter(key),
+ };
}
private static string ResolvePartitionKey(HttpContext context, RateLimitTarget target)
{
return target switch
{
- RateLimitTarget.Ip => context.Connection.RemoteIpAddress?.ToString() ?? "unknown",
+ RateLimitTarget.Ip => context.Connection.RemoteIpAddress?.ToString()
+ ?? UnknownIpPartitionKey,
RateLimitTarget.User => context.User.FindFirstValue(ClaimTypes.NameIdentifier)
- ?? "anonymous",
+ ?? AnonymousUserPartitionKey,
RateLimitTarget.IpAndUser =>
- $"{context.Connection.RemoteIpAddress}:{context.User.FindFirstValue(ClaimTypes.NameIdentifier) ?? "anonymous"}",
- RateLimitTarget.Global => "__global__",
- _ => context.Connection.RemoteIpAddress?.ToString() ?? "unknown",
+ $"{context.Connection.RemoteIpAddress}:{context.User.FindFirstValue(ClaimTypes.NameIdentifier) ?? AnonymousUserPartitionKey}",
+ RateLimitTarget.Global => GlobalPartitionKey,
+ _ => context.Connection.RemoteIpAddress?.ToString() ?? UnknownIpPartitionKey,
};
}
+
+ internal static class RateLimitRejectionHandler
+ {
+ private const string JsonProblemBody =
+ """{"type":"https://httpstatuses.io/429","title":"Too Many Requests","status":429,"detail":"Rate limit exceeded. Please retry after the period indicated in the Retry-After header."}""";
+
+ private const string HtmlBody = """
+
+ 429 Too Many Requests
+
+
+ Too many requests
+ You have hit the rate limit for this endpoint. Please wait and try again.
+
+ """;
+
+ public static async ValueTask HandleAsync(
+ OnRejectedContext context,
+ CancellationToken cancellationToken
+ )
+ {
+ var response = context.HttpContext.Response;
+ response.Headers["Retry-After"] = context.Lease.TryGetMetadata(
+ MetadataName.RetryAfter,
+ out var retryAfter
+ )
+ ? ((int)retryAfter.TotalSeconds).ToString(CultureInfo.InvariantCulture)
+ : "60";
+
+ if (PrefersHtml(context.HttpContext.Request))
+ {
+ response.ContentType = "text/html; charset=utf-8";
+ await response.WriteAsync(HtmlBody, cancellationToken);
+ }
+ else
+ {
+ response.ContentType = "application/problem+json";
+ await response.WriteAsync(JsonProblemBody, cancellationToken);
+ }
+ }
+
+ private static bool PrefersHtml(HttpRequest request)
+ {
+ // Inertia AJAX requests expect JSON even though they originate
+ // from a browser, so this check must come before the Accept sniff.
+ if (request.IsInertia())
+ {
+ return false;
+ }
+
+ var accept = request.Headers.Accept.ToString();
+ return accept.Contains("text/html", StringComparison.OrdinalIgnoreCase);
+ }
+ }
}
diff --git a/modules/Admin/src/SimpleModule.Admin/Locales/keys.ts b/modules/Admin/src/SimpleModule.Admin/Locales/keys.ts
index 19bec2c1..994d1f8d 100644
--- a/modules/Admin/src/SimpleModule.Admin/Locales/keys.ts
+++ b/modules/Admin/src/SimpleModule.Admin/Locales/keys.ts
@@ -111,10 +111,10 @@ export const AdminKeys = {
ConfirmDisable2faTitle: 'UsersEdit.ConfirmDisable2faTitle',
ConfirmReverifyAction: 'UsersEdit.ConfirmReverifyAction',
ConfirmReverifyDescription: 'UsersEdit.ConfirmReverifyDescription',
- ConfirmReverifyTitle: 'UsersEdit.ConfirmReverifyTitle',
ConfirmReverifyPhoneAction: 'UsersEdit.ConfirmReverifyPhoneAction',
ConfirmReverifyPhoneDescription: 'UsersEdit.ConfirmReverifyPhoneDescription',
ConfirmReverifyPhoneTitle: 'UsersEdit.ConfirmReverifyPhoneTitle',
+ ConfirmReverifyTitle: 'UsersEdit.ConfirmReverifyTitle',
ConfirmRevokeAllAction: 'UsersEdit.ConfirmRevokeAllAction',
ConfirmRevokeAllDescription: 'UsersEdit.ConfirmRevokeAllDescription',
ConfirmRevokeAllTitle: 'UsersEdit.ConfirmRevokeAllTitle',
@@ -129,11 +129,6 @@ export const AdminKeys = {
EmailVerificationStatus: 'UsersEdit.EmailVerificationStatus',
EmailVerificationTitle: 'UsersEdit.EmailVerificationTitle',
EmailVerified: 'UsersEdit.EmailVerified',
- PhoneVerificationTitle: 'UsersEdit.PhoneVerificationTitle',
- PhoneVerificationStatus: 'UsersEdit.PhoneVerificationStatus',
- PhoneVerified: 'UsersEdit.PhoneVerified',
- PhoneNotVerified: 'UsersEdit.PhoneNotVerified',
- PhoneNotSet: 'UsersEdit.PhoneNotSet',
ErrorPasswordMismatch: 'UsersEdit.ErrorPasswordMismatch',
FailedLoginAttempts: 'UsersEdit.FailedLoginAttempts',
FieldConfirmPassword: 'UsersEdit.FieldConfirmPassword',
@@ -149,6 +144,11 @@ export const AdminKeys = {
LoginInfoTitle: 'UsersEdit.LoginInfoTitle',
NoActiveSessions: 'UsersEdit.NoActiveSessions',
NoRolesDefined: 'UsersEdit.NoRolesDefined',
+ PhoneNotSet: 'UsersEdit.PhoneNotSet',
+ PhoneNotVerified: 'UsersEdit.PhoneNotVerified',
+ PhoneVerificationStatus: 'UsersEdit.PhoneVerificationStatus',
+ PhoneVerificationTitle: 'UsersEdit.PhoneVerificationTitle',
+ PhoneVerified: 'UsersEdit.PhoneVerified',
ReactivateButton: 'UsersEdit.ReactivateButton',
ResetPasswordButton: 'UsersEdit.ResetPasswordButton',
ResetPasswordTitle: 'UsersEdit.ResetPasswordTitle',
diff --git a/modules/OpenIddict/src/SimpleModule.OpenIddict/Endpoints/Connect/TokenEndpoint.cs b/modules/OpenIddict/src/SimpleModule.OpenIddict/Endpoints/Connect/TokenEndpoint.cs
index 2aa8748e..41454b9a 100644
--- a/modules/OpenIddict/src/SimpleModule.OpenIddict/Endpoints/Connect/TokenEndpoint.cs
+++ b/modules/OpenIddict/src/SimpleModule.OpenIddict/Endpoints/Connect/TokenEndpoint.cs
@@ -8,6 +8,7 @@
using OpenIddict.Abstractions;
using OpenIddict.Server.AspNetCore;
using SimpleModule.Core;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.OpenIddict.Contracts;
using SimpleModule.Permissions.Contracts;
using SimpleModule.Users.Contracts;
@@ -24,7 +25,8 @@ public void Map(IEndpointRouteBuilder app)
{
app.MapPost(ConnectRouteConstants.ConnectToken, (Delegate)HandleAsync)
.ExcludeFromDescription()
- .AllowAnonymous();
+ .AllowAnonymous()
+ .RateLimit(RateLimitPolicies.AuthStrict);
}
private static async Task HandleAsync(HttpContext context)
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/CreateRequestValidator.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/CreateRequestValidator.cs
index 6ebe68fd..639619a1 100644
--- a/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/CreateRequestValidator.cs
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/CreateRequestValidator.cs
@@ -11,5 +11,11 @@ public CreateRequestValidator()
RuleFor(x => x.PermitLimit)
.GreaterThan(0)
.WithMessage("Permit limit must be greater than zero.");
+ RuleFor(x => x.WindowSeconds).GreaterThan(0);
+ RuleFor(x => x.SegmentsPerWindow).GreaterThan(0);
+ RuleFor(x => x.ReplenishmentPeriodSeconds).GreaterThan(0);
+ RuleFor(x => x.TokenLimit).GreaterThan(0);
+ RuleFor(x => x.TokensPerPeriod).GreaterThan(0);
+ RuleFor(x => x.QueueLimit).GreaterThanOrEqualTo(0);
}
}
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateEndpoint.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateEndpoint.cs
index 91bba12e..7ac3f428 100644
--- a/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateEndpoint.cs
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateEndpoint.cs
@@ -1,8 +1,10 @@
+using FluentValidation;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Routing;
using SimpleModule.Core;
using SimpleModule.Core.Authorization;
using SimpleModule.Core.Endpoints;
+using SimpleModule.Core.Validation;
using SimpleModule.RateLimiting.Contracts;
namespace SimpleModule.RateLimiting.Endpoints.Policies;
@@ -15,10 +17,25 @@ public class UpdateEndpoint : IEndpoint
public void Map(IEndpointRouteBuilder app) =>
app.MapPut(
Route,
- (int id, UpdateRateLimitRuleRequest request, IRateLimitingContracts contracts) =>
- CrudEndpoints.Update(() =>
+ async (
+ int id,
+ UpdateRateLimitRuleRequest request,
+ IValidator validator,
+ IRateLimitingContracts contracts
+ ) =>
+ {
+ var validation = await validator.ValidateAsync(request);
+ if (!validation.IsValid)
+ {
+ throw new Core.Exceptions.ValidationException(
+ validation.ToValidationErrors()
+ );
+ }
+
+ return await CrudEndpoints.Update(() =>
contracts.UpdateRuleAsync(RateLimitRuleId.From(id), request)
- )
+ );
+ }
)
.RequirePermission(RateLimitingPermissions.Update);
}
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateRequestValidator.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateRequestValidator.cs
new file mode 100644
index 00000000..8fc5318a
--- /dev/null
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/Endpoints/Policies/UpdateRequestValidator.cs
@@ -0,0 +1,18 @@
+using FluentValidation;
+using SimpleModule.RateLimiting.Contracts;
+
+namespace SimpleModule.RateLimiting.Endpoints.Policies;
+
+public sealed class UpdateRequestValidator : AbstractValidator
+{
+ public UpdateRequestValidator()
+ {
+ RuleFor(x => x.PermitLimit).GreaterThan(0);
+ RuleFor(x => x.WindowSeconds).GreaterThan(0);
+ RuleFor(x => x.SegmentsPerWindow).GreaterThan(0);
+ RuleFor(x => x.ReplenishmentPeriodSeconds).GreaterThan(0);
+ RuleFor(x => x.TokenLimit).GreaterThan(0);
+ RuleFor(x => x.TokensPerPeriod).GreaterThan(0);
+ RuleFor(x => x.QueueLimit).GreaterThanOrEqualTo(0);
+ }
+}
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitRuleCache.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitRuleCache.cs
new file mode 100644
index 00000000..b42ae83b
--- /dev/null
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitRuleCache.cs
@@ -0,0 +1,144 @@
+using Microsoft.AspNetCore.Http;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using SimpleModule.Core.RateLimiting;
+using SimpleModule.RateLimiting.Contracts;
+
+namespace SimpleModule.RateLimiting;
+
+///
+/// Loads enabled s from the database into an
+/// immutable snapshot consulted by the global rate limiter. Rebuilt on
+/// startup and after every admin write via .
+///
+internal sealed partial class RateLimitRuleCache(
+ IServiceScopeFactory scopeFactory,
+ ILogger logger
+) : IRateLimitRuleSource, IHostedService
+{
+ // Volatile so non-x86 readers see the swapped reference promptly.
+ private volatile CompiledRule[] _rules = [];
+
+ public Task StartAsync(CancellationToken cancellationToken) => RefreshAsync(cancellationToken);
+
+ public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+
+ public async Task RefreshAsync(CancellationToken cancellationToken = default)
+ {
+ await using var scope = scopeFactory.CreateAsyncScope();
+ var db = scope.ServiceProvider.GetRequiredService();
+
+ var rules = await db
+ .Rules.AsNoTracking()
+ .Where(r => r.IsEnabled)
+ .ToListAsync(cancellationToken);
+
+ var compiled = rules
+ .Select(Compile)
+ .OrderByDescending(r => Specificity(r.Prefix, r.Kind))
+ .ToArray();
+
+ _rules = compiled;
+ LogRefreshed(logger, compiled.Length);
+ }
+
+ public RateLimitPolicyDefinition? FindForPath(PathString path)
+ {
+ var snapshot = _rules;
+ if (snapshot.Length == 0)
+ {
+ return null;
+ }
+
+ var pathStr = path.HasValue ? path.Value! : "/";
+ foreach (var rule in snapshot)
+ {
+ if (rule.Matches(pathStr))
+ {
+ return rule.Policy;
+ }
+ }
+
+ return null;
+ }
+
+ private static CompiledRule Compile(RateLimitRule rule)
+ {
+ var pattern = string.IsNullOrWhiteSpace(rule.EndpointPattern) ? "*" : rule.EndpointPattern!;
+
+ var policy = new RateLimitPolicyDefinition
+ {
+ Name = rule.PolicyName,
+ PolicyType = rule.PolicyType,
+ Target = rule.Target,
+ PermitLimit = rule.PermitLimit,
+ Window = TimeSpan.FromSeconds(rule.WindowSeconds),
+ SegmentsPerWindow = rule.SegmentsPerWindow,
+ TokenLimit = rule.TokenLimit,
+ TokensPerPeriod = rule.TokensPerPeriod,
+ ReplenishmentPeriod = TimeSpan.FromSeconds(rule.ReplenishmentPeriodSeconds),
+ QueueLimit = rule.QueueLimit,
+ };
+
+ if (pattern == "*")
+ {
+ return new CompiledRule(MatchKind.CatchAll, "", "", policy);
+ }
+
+ var wildcard = pattern.IndexOf('*', StringComparison.Ordinal);
+ return wildcard < 0
+ ? new CompiledRule(MatchKind.Exact, pattern, "", policy)
+ : new CompiledRule(
+ MatchKind.Wildcard,
+ pattern[..wildcard],
+ pattern[(wildcard + 1)..],
+ policy
+ );
+ }
+
+ ///
+ /// Higher score = more specific. Catch-all scores zero so it always loses
+ /// to any concrete pattern; longer prefixes outrank shorter ones.
+ ///
+ private static int Specificity(string prefix, MatchKind kind) =>
+ kind == MatchKind.CatchAll ? 0 : prefix.Length;
+
+ private enum MatchKind
+ {
+ CatchAll,
+ Exact,
+ Wildcard,
+ }
+
+ private readonly record struct CompiledRule(
+ MatchKind Kind,
+ string Prefix,
+ string Suffix,
+ RateLimitPolicyDefinition Policy
+ )
+ {
+ public bool Matches(string path) =>
+ Kind switch
+ {
+ MatchKind.CatchAll => true,
+ MatchKind.Exact => string.Equals(path, Prefix, StringComparison.OrdinalIgnoreCase),
+ MatchKind.Wildcard => path.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)
+ && (
+ Suffix.Length == 0
+ || (
+ path.Length >= Prefix.Length + Suffix.Length
+ && path.EndsWith(Suffix, StringComparison.OrdinalIgnoreCase)
+ )
+ ),
+ _ => false,
+ };
+ }
+
+ [LoggerMessage(
+ Level = LogLevel.Information,
+ Message = "Rate-limit rule cache refreshed: {Count} enabled rules loaded"
+ )]
+ private static partial void LogRefreshed(ILogger logger, int count);
+}
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingModule.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingModule.cs
index b9627df9..6c2d63f9 100644
--- a/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingModule.cs
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingModule.cs
@@ -22,6 +22,11 @@ public void ConfigureServices(IServiceCollection services, IConfiguration config
RateLimitingConstants.ModuleName
);
services.AddValidatorsFromAssemblyContaining();
+ services.AddSingleton();
+ services.AddSingleton(sp =>
+ sp.GetRequiredService()
+ );
+ services.AddHostedService(sp => sp.GetRequiredService());
}
public void ConfigureRateLimits(IRateLimitBuilder builder)
@@ -30,7 +35,7 @@ public void ConfigureRateLimits(IRateLimitBuilder builder)
.Add(
new RateLimitPolicyDefinition
{
- Name = "fixed-default",
+ Name = RateLimitPolicies.FixedDefault,
PolicyType = RateLimitPolicyType.FixedWindow,
Target = RateLimitTarget.Ip,
PermitLimit = 60,
@@ -40,7 +45,7 @@ public void ConfigureRateLimits(IRateLimitBuilder builder)
.Add(
new RateLimitPolicyDefinition
{
- Name = "sliding-strict",
+ Name = RateLimitPolicies.SlidingStrict,
PolicyType = RateLimitPolicyType.SlidingWindow,
Target = RateLimitTarget.IpAndUser,
PermitLimit = 30,
@@ -51,7 +56,7 @@ public void ConfigureRateLimits(IRateLimitBuilder builder)
.Add(
new RateLimitPolicyDefinition
{
- Name = "token-bucket",
+ Name = RateLimitPolicies.TokenBucket,
PolicyType = RateLimitPolicyType.TokenBucket,
Target = RateLimitTarget.Ip,
TokenLimit = 100,
@@ -62,7 +67,7 @@ public void ConfigureRateLimits(IRateLimitBuilder builder)
.Add(
new RateLimitPolicyDefinition
{
- Name = "auth-strict",
+ Name = RateLimitPolicies.AuthStrict,
PolicyType = RateLimitPolicyType.FixedWindow,
Target = RateLimitTarget.Ip,
PermitLimit = 10,
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingService.cs b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingService.cs
index 64e51bef..1199dd35 100644
--- a/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingService.cs
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/RateLimitingService.cs
@@ -1,11 +1,13 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.RateLimiting.Contracts;
namespace SimpleModule.RateLimiting;
public partial class RateLimitingService(
RateLimitingDbContext db,
+ IRateLimitRuleSource ruleSource,
ILogger logger
) : IRateLimitingContracts
{
@@ -43,6 +45,7 @@ public async Task CreateRuleAsync(CreateRateLimitRuleRequest requ
db.Rules.Add(rule);
await db.SaveChangesAsync();
+ await ruleSource.RefreshAsync();
LogRuleCreated(logger, rule.Id, rule.PolicyName);
@@ -73,6 +76,7 @@ UpdateRateLimitRuleRequest request
rule.IsEnabled = request.IsEnabled;
await db.SaveChangesAsync();
+ await ruleSource.RefreshAsync();
LogRuleUpdated(logger, rule.Id, rule.PolicyName);
@@ -89,6 +93,7 @@ public async Task DeleteRuleAsync(RateLimitRuleId id)
db.Rules.Remove(rule);
await db.SaveChangesAsync();
+ await ruleSource.RefreshAsync();
LogRuleDeleted(logger, id);
}
diff --git a/modules/RateLimiting/src/SimpleModule.RateLimiting/SimpleModule.RateLimiting.csproj b/modules/RateLimiting/src/SimpleModule.RateLimiting/SimpleModule.RateLimiting.csproj
index a1081fba..81361c14 100644
--- a/modules/RateLimiting/src/SimpleModule.RateLimiting/SimpleModule.RateLimiting.csproj
+++ b/modules/RateLimiting/src/SimpleModule.RateLimiting/SimpleModule.RateLimiting.csproj
@@ -8,6 +8,7 @@
+
diff --git a/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitRuleCacheTests.cs b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitRuleCacheTests.cs
new file mode 100644
index 00000000..9993e5c3
--- /dev/null
+++ b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitRuleCacheTests.cs
@@ -0,0 +1,154 @@
+using FluentAssertions;
+using Microsoft.AspNetCore.Http;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
+using SimpleModule.Core.RateLimiting;
+using SimpleModule.Database;
+using SimpleModule.RateLimiting.Contracts;
+
+namespace SimpleModule.RateLimiting.Tests;
+
+public sealed class RateLimitRuleCacheTests : IAsyncLifetime, IDisposable
+{
+ private readonly RateLimitingDbContext _db;
+ private readonly RateLimitRuleCache _cache;
+ private readonly ServiceProvider _services;
+
+ public RateLimitRuleCacheTests()
+ {
+ var dbOptions = new DbContextOptionsBuilder()
+ .UseSqlite("Data Source=:memory:")
+ .Options;
+ var databaseOptions = Options.Create(
+ new DatabaseOptions
+ {
+ ModuleConnections = new Dictionary
+ {
+ ["RateLimiting"] = "Data Source=:memory:",
+ },
+ }
+ );
+ _db = new RateLimitingDbContext(dbOptions, databaseOptions);
+ _db.Database.OpenConnection();
+ _db.Database.EnsureCreated();
+
+ var services = new ServiceCollection();
+ services.AddSingleton(_db);
+ _services = services.BuildServiceProvider();
+
+ _cache = new RateLimitRuleCache(
+ _services.GetRequiredService(),
+ NullLogger.Instance
+ );
+ }
+
+ public ValueTask InitializeAsync() => default;
+
+ public ValueTask DisposeAsync() => default;
+
+ [Fact]
+ public async Task FindForPath_ReturnsNull_BeforeRefresh()
+ {
+ _cache.FindForPath("/api/users").Should().BeNull();
+ await Task.CompletedTask;
+ }
+
+ [Fact]
+ public async Task FindForPath_ReturnsNull_WhenNoRulesEnabled()
+ {
+ _db.Rules.Add(NewRule("disabled-rule", "/api/users", isEnabled: false));
+ await _db.SaveChangesAsync();
+
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/api/users").Should().BeNull();
+ }
+
+ [Fact]
+ public async Task FindForPath_MatchesExactPattern()
+ {
+ _db.Rules.Add(NewRule("exact", "/api/users"));
+ await _db.SaveChangesAsync();
+
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/api/users").Should().NotBeNull();
+ _cache.FindForPath("/api/users/123").Should().BeNull();
+ }
+
+ [Fact]
+ public async Task FindForPath_MatchesPrefixWildcard()
+ {
+ _db.Rules.Add(NewRule("prefix", "/api/users/*"));
+ await _db.SaveChangesAsync();
+
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/api/users/123").Should().NotBeNull();
+ _cache.FindForPath("/api/users/").Should().NotBeNull();
+ _cache.FindForPath("/api/orders").Should().BeNull();
+ }
+
+ [Fact]
+ public async Task FindForPath_MatchesBareWildcard()
+ {
+ _db.Rules.Add(NewRule("catchall", "*"));
+ await _db.SaveChangesAsync();
+
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/literally/anything").Should().NotBeNull();
+ }
+
+ [Fact]
+ public async Task FindForPath_PrefersMoreSpecificRule()
+ {
+ _db.Rules.AddRange(
+ NewRule("catchall", "*", permitLimit: 10),
+ NewRule("specific", "/api/users", permitLimit: 99)
+ );
+ await _db.SaveChangesAsync();
+
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/api/users")!.PermitLimit.Should().Be(99);
+ _cache.FindForPath("/api/orders")!.PermitLimit.Should().Be(10);
+ }
+
+ [Fact]
+ public async Task RefreshAsync_PicksUpNewRules()
+ {
+ await _cache.RefreshAsync();
+ _cache.FindForPath("/api/things").Should().BeNull();
+
+ _db.Rules.Add(NewRule("late", "/api/things"));
+ await _db.SaveChangesAsync();
+ await _cache.RefreshAsync();
+
+ _cache.FindForPath("/api/things").Should().NotBeNull();
+ }
+
+ private static RateLimitRule NewRule(
+ string name,
+ string? pattern,
+ bool isEnabled = true,
+ int permitLimit = 60
+ ) =>
+ new()
+ {
+ PolicyName = name,
+ EndpointPattern = pattern,
+ IsEnabled = isEnabled,
+ PermitLimit = permitLimit,
+ };
+
+ public void Dispose()
+ {
+ _db.Database.CloseConnection();
+ _db.Dispose();
+ _services.Dispose();
+ GC.SuppressFinalize(this);
+ }
+}
diff --git a/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingEndpointTests.cs b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingEndpointTests.cs
index 4ab01c3e..a569a5cd 100644
--- a/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingEndpointTests.cs
+++ b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingEndpointTests.cs
@@ -1,6 +1,7 @@
using System.Net;
using System.Security.Claims;
using FluentAssertions;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.Tests.Shared.Fixtures;
namespace SimpleModule.RateLimiting.Tests;
@@ -34,7 +35,7 @@ public async Task ActivePoliciesApi_ReturnsOk_ForAuthenticatedAdmin()
response.StatusCode.Should().Be(HttpStatusCode.OK);
var content = await response.Content.ReadAsStringAsync();
- content.Should().Contain("fixed-default");
+ content.Should().Contain(RateLimitPolicies.FixedDefault);
}
[Fact]
diff --git a/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingServiceTests.cs b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingServiceTests.cs
index 0361f0be..b9bf58ff 100644
--- a/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingServiceTests.cs
+++ b/modules/RateLimiting/tests/SimpleModule.RateLimiting.Tests/RateLimitingServiceTests.cs
@@ -1,4 +1,5 @@
using FluentAssertions;
+using Microsoft.AspNetCore.Http;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
@@ -12,6 +13,7 @@ public sealed class RateLimitingServiceTests : IDisposable
{
private readonly RateLimitingDbContext _db;
private readonly RateLimitingService _service;
+ private readonly RecordingRuleSource _ruleSource = new();
public RateLimitingServiceTests()
{
@@ -30,7 +32,11 @@ public RateLimitingServiceTests()
_db = new RateLimitingDbContext(dbOptions, databaseOptions);
_db.Database.OpenConnection();
_db.Database.EnsureCreated();
- _service = new RateLimitingService(_db, NullLogger.Instance);
+ _service = new RateLimitingService(
+ _db,
+ _ruleSource,
+ NullLogger.Instance
+ );
}
[Fact]
@@ -136,10 +142,57 @@ public async Task UpdateRuleAsync_ShouldThrow_WhenNotFound()
await act.Should().ThrowAsync();
}
+ [Fact]
+ public async Task CreateRuleAsync_ShouldRefreshRuleSource()
+ {
+ await _service.CreateRuleAsync(new CreateRateLimitRuleRequest { PolicyName = "x" });
+
+ _ruleSource.RefreshCount.Should().Be(1);
+ }
+
+ [Fact]
+ public async Task UpdateRuleAsync_ShouldRefreshRuleSource()
+ {
+ var created = await _service.CreateRuleAsync(
+ new CreateRateLimitRuleRequest { PolicyName = "x" }
+ );
+ _ruleSource.RefreshCount = 0;
+
+ await _service.UpdateRuleAsync(created.Id, new UpdateRateLimitRuleRequest());
+
+ _ruleSource.RefreshCount.Should().Be(1);
+ }
+
+ [Fact]
+ public async Task DeleteRuleAsync_ShouldRefreshRuleSource()
+ {
+ var created = await _service.CreateRuleAsync(
+ new CreateRateLimitRuleRequest { PolicyName = "x" }
+ );
+ _ruleSource.RefreshCount = 0;
+
+ await _service.DeleteRuleAsync(created.Id);
+
+ _ruleSource.RefreshCount.Should().Be(1);
+ }
+
public void Dispose()
{
_db.Database.CloseConnection();
_db.Dispose();
GC.SuppressFinalize(this);
}
+
+ private sealed class RecordingRuleSource : IRateLimitRuleSource
+ {
+ public int RefreshCount { get; set; }
+
+ public RateLimitPolicyDefinition? FindForPath(PathString path) => null;
+
+ public Task RefreshAsync(CancellationToken cancellationToken = default)
+ {
+ RefreshCount++;
+ return Task.CompletedTask;
+ }
+ }
}
diff --git a/modules/Users/src/SimpleModule.Users/Pages/Account/ForgotPasswordEndpoint.cs b/modules/Users/src/SimpleModule.Users/Pages/Account/ForgotPasswordEndpoint.cs
index 4d7244e1..d78db359 100644
--- a/modules/Users/src/SimpleModule.Users/Pages/Account/ForgotPasswordEndpoint.cs
+++ b/modules/Users/src/SimpleModule.Users/Pages/Account/ForgotPasswordEndpoint.cs
@@ -7,6 +7,7 @@
using Microsoft.AspNetCore.WebUtilities;
using SimpleModule.Core;
using SimpleModule.Core.Inertia;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.Users.Contracts;
namespace SimpleModule.Users.Pages.Account;
@@ -49,6 +50,7 @@ HttpContext context
}
)
.AllowAnonymous()
- .DisableAntiforgery();
+ .DisableAntiforgery()
+ .RateLimit(RateLimitPolicies.AuthStrict);
}
}
diff --git a/modules/Users/src/SimpleModule.Users/Pages/Account/LoginEndpoint.cs b/modules/Users/src/SimpleModule.Users/Pages/Account/LoginEndpoint.cs
index 7bf30c4c..ccbda1c0 100644
--- a/modules/Users/src/SimpleModule.Users/Pages/Account/LoginEndpoint.cs
+++ b/modules/Users/src/SimpleModule.Users/Pages/Account/LoginEndpoint.cs
@@ -8,6 +8,7 @@
using Microsoft.Extensions.Options;
using SimpleModule.Core;
using SimpleModule.Core.Inertia;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.Core.Settings;
using SimpleModule.Settings.Contracts;
using SimpleModule.Users.Constants;
@@ -117,6 +118,7 @@ IOptions passkeyOptions
}
)
.AllowAnonymous()
- .DisableAntiforgery();
+ .DisableAntiforgery()
+ .RateLimit(RateLimitPolicies.AuthStrict);
}
}
diff --git a/modules/Users/src/SimpleModule.Users/Pages/Account/ResetPasswordEndpoint.cs b/modules/Users/src/SimpleModule.Users/Pages/Account/ResetPasswordEndpoint.cs
index 4580d6dc..e9bd9544 100644
--- a/modules/Users/src/SimpleModule.Users/Pages/Account/ResetPasswordEndpoint.cs
+++ b/modules/Users/src/SimpleModule.Users/Pages/Account/ResetPasswordEndpoint.cs
@@ -7,6 +7,7 @@
using Microsoft.AspNetCore.WebUtilities;
using SimpleModule.Core;
using SimpleModule.Core.Inertia;
+using SimpleModule.Core.RateLimiting;
using SimpleModule.Users.Contracts;
namespace SimpleModule.Users.Pages.Account;
@@ -72,6 +73,7 @@ UserManager userManager
}
)
.AllowAnonymous()
- .DisableAntiforgery();
+ .DisableAntiforgery()
+ .RateLimit(RateLimitPolicies.AuthStrict);
}
}
diff --git a/modules/Users/src/SimpleModule.Users/Pages/Account/SendUnlockEmailEndpoint.cs b/modules/Users/src/SimpleModule.Users/Pages/Account/SendUnlockEmailEndpoint.cs
index d5a3d5a5..ee1b3b26 100644
--- a/modules/Users/src/SimpleModule.Users/Pages/Account/SendUnlockEmailEndpoint.cs
+++ b/modules/Users/src/SimpleModule.Users/Pages/Account/SendUnlockEmailEndpoint.cs
@@ -59,6 +59,6 @@ user is not null
)
.AllowAnonymous()
.DisableAntiforgery()
- .RateLimit("auth-strict");
+ .RateLimit(RateLimitPolicies.AuthStrict);
}
}