mirror of
https://github.com/orris-inc/orris.git
synced 2026-05-06 21:44:01 +08:00
fix: enforce traffic period filtering in all limit check paths
Traffic reset was not taking effect because multiple enforcement paths ignored the subscription's currentPeriodStart when calculating usage: - ResolveTrafficPeriod now uses currentPeriodStart as a floor for calendar_month mode, so manual resets exclude pre-reset traffic - Node and forward enforcement services now query traffic within the resolved period instead of from time zero - NodeSubscriptionUsageReaderAdapter Redis query respects periodStart Also refactors node subscription generation and token validation into infrastructure layer, adds include_counts support for admin subscription list endpoint, and extracts shared nodeutil package.
This commit is contained in:
@@ -74,7 +74,7 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex
|
||||
|
||||
// Find the highest traffic limit across all Forward-type subscriptions
|
||||
// and collect their subscription IDs for traffic query
|
||||
trafficLimit, hasLimit, forwardSubscriptionIDs, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions)
|
||||
trafficLimit, hasLimit, forwardSubscriptionIDs, periodStart, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions)
|
||||
if err != nil {
|
||||
s.logger.Errorw("failed to determine traffic limit",
|
||||
"user_id", userID,
|
||||
@@ -99,8 +99,8 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get user's total forward traffic by combining Redis (recent 24h) and MySQL (historical)
|
||||
usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs)
|
||||
// Get user's total forward traffic within the resolved traffic period
|
||||
usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs, periodStart)
|
||||
if err != nil {
|
||||
s.logger.Errorw("failed to get total traffic for user",
|
||||
"user_id", userID,
|
||||
@@ -262,12 +262,14 @@ func (s *TrafficLimitEnforcementService) OnTrafficUpdate(ctx context.Context, ru
|
||||
|
||||
// getHighestTrafficLimitAndIDs returns the highest traffic limit across all Forward-type subscriptions
|
||||
// and collects their subscription IDs for traffic query.
|
||||
// Returns (limit, hasLimit, subscriptionIDs, error) where hasLimit is false if any subscription has unlimited traffic.
|
||||
// Returns (limit, hasLimit, subscriptionIDs, periodStart, error) where hasLimit is false if any subscription has unlimited traffic.
|
||||
// periodStart is the latest traffic period start across all forward subscriptions (respects manual resets).
|
||||
// Only considers subscriptions with PlanType = "forward".
|
||||
func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, error) {
|
||||
func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, time.Time, error) {
|
||||
var highestLimit uint64
|
||||
hasLimit := false
|
||||
var forwardSubscriptionIDs []uint
|
||||
var latestPeriodStart time.Time
|
||||
|
||||
// Collect plan IDs for batch query
|
||||
planIDs := make([]uint, 0, len(subscriptions))
|
||||
@@ -279,7 +281,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
|
||||
plansList, err := s.planRepo.GetByIDs(ctx, planIDs)
|
||||
if err != nil {
|
||||
s.logger.Errorw("failed to batch fetch plans", "error", err)
|
||||
return 0, false, nil, err
|
||||
return 0, false, nil, time.Time{}, err
|
||||
}
|
||||
|
||||
// Convert to map for quick lookup
|
||||
@@ -312,6 +314,13 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
|
||||
// Collect Forward-type subscription ID
|
||||
forwardSubscriptionIDs = append(forwardSubscriptionIDs, sub.ID())
|
||||
|
||||
// Resolve traffic period and track the latest period start across all forward subscriptions.
|
||||
// This ensures manual usage resets are respected: traffic before the reset is excluded.
|
||||
period := subscription.ResolveTrafficPeriod(plan, sub)
|
||||
if latestPeriodStart.IsZero() || period.Start.After(latestPeriodStart) {
|
||||
latestPeriodStart = period.Start
|
||||
}
|
||||
|
||||
// Determine traffic limit: subscription override takes priority over plan
|
||||
var limit uint64
|
||||
if sub.TrafficLimitOverride() != nil {
|
||||
@@ -323,7 +332,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
|
||||
"subscription_id", sub.ID(),
|
||||
"plan_id", sub.PlanID(),
|
||||
)
|
||||
return 0, false, forwardSubscriptionIDs, nil // Unlimited traffic - don't enforce
|
||||
return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
@@ -344,7 +353,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
|
||||
"subscription_id", sub.ID(),
|
||||
"plan_id", sub.PlanID(),
|
||||
)
|
||||
return 0, false, forwardSubscriptionIDs, nil
|
||||
return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil
|
||||
}
|
||||
|
||||
// Track the highest limit
|
||||
@@ -354,7 +363,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
|
||||
}
|
||||
}
|
||||
|
||||
return highestLimit, hasLimit, forwardSubscriptionIDs, nil
|
||||
return highestLimit, hasLimit, forwardSubscriptionIDs, latestPeriodStart, nil
|
||||
}
|
||||
|
||||
// getForwardTrafficLimit extracts the traffic limit from a plan.
|
||||
@@ -370,10 +379,10 @@ func (s *TrafficLimitEnforcementService) getForwardTrafficLimit(plan *subscripti
|
||||
}
|
||||
|
||||
// getTotalTrafficForSubscriptions calculates total traffic for given subscription IDs
|
||||
// by combining data from two sources:
|
||||
// - Last 24 hours: from Redis HourlyTrafficCache
|
||||
// - Before 24 hours: from MySQL subscription_usage_stats table
|
||||
func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint) (uint64, error) {
|
||||
// within the given traffic period by combining data from two sources:
|
||||
// - Recent window: from Redis HourlyTrafficCache
|
||||
// - Historical window: from MySQL subscription_usage_stats table
|
||||
func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint, periodStart time.Time) (uint64, error) {
|
||||
if len(subscriptionIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -386,10 +395,15 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con
|
||||
|
||||
var total uint64
|
||||
|
||||
// Get recent traffic from Redis (yesterday + today, filter by forward_rule type)
|
||||
// Determine Redis query range: max(recentBoundary, periodStart) to now
|
||||
// This ensures traffic before periodStart (e.g. after manual reset) is excluded.
|
||||
resourceType := subscription.ResourceTypeForwardRule.String()
|
||||
redisFrom := recentBoundary
|
||||
if periodStart.After(redisFrom) {
|
||||
redisFrom = periodStart
|
||||
}
|
||||
recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
|
||||
ctx, subscriptionIDs, resourceType, recentBoundary, now,
|
||||
ctx, subscriptionIDs, resourceType, redisFrom, now,
|
||||
)
|
||||
if err != nil {
|
||||
// Log warning but don't fail - Redis unavailability shouldn't block limit checks
|
||||
@@ -403,16 +417,21 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con
|
||||
for _, traffic := range recentTraffic {
|
||||
total += traffic.Total
|
||||
}
|
||||
s.logger.Debugw("got recent 24h traffic from Redis",
|
||||
s.logger.Debugw("got recent traffic from Redis",
|
||||
"subscription_ids_count", len(subscriptionIDs),
|
||||
"recent_total", total,
|
||||
)
|
||||
}
|
||||
|
||||
// Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday)
|
||||
// Use daily granularity for historical aggregation, filter by forward_rule type
|
||||
// Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary)
|
||||
historicalEnd := recentBoundary.Add(-time.Second)
|
||||
if historicalEnd.Before(periodStart) {
|
||||
// Period started after recentBoundary, no historical data needed
|
||||
return total, nil
|
||||
}
|
||||
|
||||
historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs(
|
||||
ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second),
|
||||
ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Warnw("failed to get historical traffic from stats, using Redis data only",
|
||||
|
||||
@@ -146,8 +146,11 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get total traffic for this subscription
|
||||
usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID)
|
||||
// Resolve traffic period so we only count traffic within the current period
|
||||
period := subscription.ResolveTrafficPeriod(plan, sub)
|
||||
|
||||
// Get total traffic for this subscription within the resolved period
|
||||
usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID, period.Start)
|
||||
if err != nil {
|
||||
s.logger.Errorw("failed to get total traffic",
|
||||
"subscription_id", subscriptionID,
|
||||
@@ -234,10 +237,10 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con
|
||||
}
|
||||
|
||||
// getTotalTrafficForSubscription calculates total traffic for a subscription
|
||||
// by combining data from two sources:
|
||||
// - Last 24 hours: from Redis HourlyTrafficCache
|
||||
// - Before 24 hours: from MySQL subscription_usage_stats table
|
||||
func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint) (uint64, error) {
|
||||
// within the given traffic period by combining data from two sources:
|
||||
// - Recent window: from Redis HourlyTrafficCache
|
||||
// - Historical window: from MySQL subscription_usage_stats table
|
||||
func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint, periodStart time.Time) (uint64, error) {
|
||||
now := biztime.NowUTC()
|
||||
|
||||
// Use start of yesterday's business day as batch/speed boundary (Lambda architecture)
|
||||
@@ -246,10 +249,15 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx
|
||||
|
||||
var total uint64
|
||||
|
||||
// Get recent traffic from Redis (yesterday + today, filter by node type)
|
||||
// Determine Redis query range: max(recentBoundary, periodStart) to now
|
||||
// This ensures traffic before periodStart (e.g. after manual reset) is excluded.
|
||||
resourceType := subscription.ResourceTypeNode.String()
|
||||
redisFrom := recentBoundary
|
||||
if periodStart.After(redisFrom) {
|
||||
redisFrom = periodStart
|
||||
}
|
||||
recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
|
||||
ctx, []uint{subscriptionID}, resourceType, recentBoundary, now,
|
||||
ctx, []uint{subscriptionID}, resourceType, redisFrom, now,
|
||||
)
|
||||
if err != nil {
|
||||
// Log warning but don't fail - Redis unavailability shouldn't block limit checks
|
||||
@@ -263,16 +271,21 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx
|
||||
for _, traffic := range recentTraffic {
|
||||
total += traffic.Total
|
||||
}
|
||||
s.logger.Debugw("got recent 24h traffic from Redis",
|
||||
s.logger.Debugw("got recent traffic from Redis",
|
||||
"subscription_id", subscriptionID,
|
||||
"recent_total", total,
|
||||
)
|
||||
}
|
||||
|
||||
// Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday)
|
||||
// Use daily granularity for historical aggregation, filter by node type
|
||||
// Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary)
|
||||
historicalEnd := recentBoundary.Add(-time.Second)
|
||||
if historicalEnd.Before(periodStart) {
|
||||
// Period started after recentBoundary, no historical data needed
|
||||
return total, nil
|
||||
}
|
||||
|
||||
historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs(
|
||||
ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second),
|
||||
ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Warnw("failed to get historical traffic from stats, using Redis data only",
|
||||
|
||||
35
internal/application/node/usecases/quotacache.go
Normal file
35
internal/application/node/usecases/quotacache.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CachedQuotaInfo represents the cached subscription quota information.
|
||||
// This mirrors cache.CachedQuota to avoid import cycle.
|
||||
type CachedQuotaInfo struct {
|
||||
Limit int64 // Traffic limit in bytes
|
||||
PeriodStart time.Time // Billing period start
|
||||
PeriodEnd time.Time // Billing period end
|
||||
PlanType string // node/forward/hybrid
|
||||
Suspended bool // Whether the subscription is suspended
|
||||
NotFound bool // Null marker: subscription confirmed not found/inactive in DB
|
||||
}
|
||||
|
||||
// NodeSubscriptionQuotaCache defines the interface for subscription quota caching.
|
||||
type NodeSubscriptionQuotaCache interface {
|
||||
// GetQuota retrieves subscription quota information from cache.
|
||||
// Returns nil if cache does not exist.
|
||||
GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
|
||||
|
||||
// MarkSuspended marks the subscription as suspended in cache.
|
||||
MarkSuspended(ctx context.Context, subscriptionID uint) error
|
||||
}
|
||||
|
||||
// NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota.
|
||||
// This is used when quota cache miss occurs to load quota from database.
|
||||
type NodeSubscriptionQuotaLoader interface {
|
||||
// LoadQuotaByID loads subscription quota from database and caches it.
|
||||
// Returns the cached quota info, or nil if subscription/plan not found.
|
||||
LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
|
||||
}
|
||||
@@ -89,6 +89,16 @@ type SubscriptionTokenDTO struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SubscriptionStatusCounts represents subscription status aggregation counts.
|
||||
// Used by admin subscription list to return status counts alongside the list response.
|
||||
type SubscriptionStatusCounts struct {
|
||||
Active int64 `json:"active"`
|
||||
Expired int64 `json:"expired"`
|
||||
Suspended int64 `json:"suspended"`
|
||||
PendingPayment int64 `json:"pending_payment"`
|
||||
ExpiringIn7Days int64 `json:"expiring_in_7_days"`
|
||||
}
|
||||
|
||||
var (
|
||||
// SubscriptionMapper is a generic mapper for basic conversions.
|
||||
// WARNING: Does not populate OnlineDeviceCount, DeviceLimit, DataUsedBytes, or DataLimitBytes.
|
||||
|
||||
@@ -241,8 +241,13 @@ func (uc *CreateSubscriptionUseCase) Execute(ctx context.Context, cmd CreateSubs
|
||||
}
|
||||
|
||||
func (uc *CreateSubscriptionUseCase) calculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time {
|
||||
// Use fixed days to ensure consistent subscription periods
|
||||
// This prevents "drifting" when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28)
|
||||
return CalculateEndDate(startDate, billingCycle)
|
||||
}
|
||||
|
||||
// CalculateEndDate computes the subscription end date from a start date and billing cycle.
|
||||
// Uses fixed days to ensure consistent subscription periods and prevent "drifting"
|
||||
// when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28).
|
||||
func CalculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time {
|
||||
switch billingCycle {
|
||||
case vo.BillingCycleWeekly:
|
||||
return startDate.Add(7 * 24 * time.Hour) // 7 days
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/orris-inc/orris/internal/application/subscription/dto"
|
||||
"github.com/orris-inc/orris/internal/domain/subscription"
|
||||
"github.com/orris-inc/orris/internal/domain/user"
|
||||
@@ -24,13 +26,15 @@ type ListUserSubscriptionsQuery struct {
|
||||
PageSize int
|
||||
SortBy string
|
||||
SortDesc *bool // nil means default (true = DESC)
|
||||
IncludeCounts bool // When true, also return subscription status counts
|
||||
}
|
||||
|
||||
type ListUserSubscriptionsResult struct {
|
||||
Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StatusCounts *dto.SubscriptionStatusCounts `json:"status_counts,omitempty"` // Present when IncludeCounts is true
|
||||
}
|
||||
|
||||
type ListUserSubscriptionsUseCase struct {
|
||||
@@ -189,6 +193,60 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU
|
||||
dtos = append(dtos, result)
|
||||
}
|
||||
|
||||
// Query status counts if requested
|
||||
var statusCounts *dto.SubscriptionStatusCounts
|
||||
if query.IncludeCounts {
|
||||
statusCounts = &dto.SubscriptionStatusCounts{}
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
count, err := uc.subscriptionRepo.CountByStatus(gctx, "active")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count active: %w", err)
|
||||
}
|
||||
statusCounts.Active = count
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
count, err := uc.subscriptionRepo.CountByStatus(gctx, "expired")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count expired: %w", err)
|
||||
}
|
||||
statusCounts.Expired = count
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
count, err := uc.subscriptionRepo.CountByStatus(gctx, "suspended")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count suspended: %w", err)
|
||||
}
|
||||
statusCounts.Suspended = count
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
count, err := uc.subscriptionRepo.CountByStatus(gctx, "pending_payment")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count pending_payment: %w", err)
|
||||
}
|
||||
statusCounts.PendingPayment = count
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(gctx, 7)
|
||||
if err != nil {
|
||||
return fmt.Errorf("find expiring subscriptions: %w", err)
|
||||
}
|
||||
statusCounts.ExpiringIn7Days = int64(len(subs))
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
uc.logger.Warnw("failed to get subscription status counts", "error", err)
|
||||
// Non-fatal: return list without counts
|
||||
statusCounts = nil
|
||||
}
|
||||
}
|
||||
|
||||
uc.logger.Debugw("subscriptions listed successfully",
|
||||
"user_id", query.UserID,
|
||||
"total", total,
|
||||
@@ -201,5 +259,6 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU
|
||||
Total: total,
|
||||
Page: query.Page,
|
||||
PageSize: query.PageSize,
|
||||
StatusCounts: statusCounts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package usecases
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/orris-inc/orris/internal/application/subscription/dto"
|
||||
"github.com/orris-inc/orris/internal/domain/subscription"
|
||||
@@ -26,10 +27,13 @@ type PlanChangeNotifier interface {
|
||||
}
|
||||
|
||||
type UpdatePlanUseCase struct {
|
||||
planRepo subscription.PlanRepository
|
||||
pricingRepo subscription.PlanPricingRepository
|
||||
planChangeNotifier PlanChangeNotifier
|
||||
logger logger.Interface
|
||||
planRepo subscription.PlanRepository
|
||||
pricingRepo subscription.PlanPricingRepository
|
||||
subscriptionRepo subscription.SubscriptionRepository
|
||||
planChangeNotifier PlanChangeNotifier
|
||||
quotaCacheManager QuotaCacheManager
|
||||
subscriptionNotifier SubscriptionChangeNotifier
|
||||
logger logger.Interface
|
||||
}
|
||||
|
||||
// SetPlanChangeNotifier sets the notifier for plan feature changes.
|
||||
@@ -37,6 +41,21 @@ func (uc *UpdatePlanUseCase) SetPlanChangeNotifier(notifier PlanChangeNotifier)
|
||||
uc.planChangeNotifier = notifier
|
||||
}
|
||||
|
||||
// SetSubscriptionRepo sets the subscription repository for cascading plan changes.
|
||||
func (uc *UpdatePlanUseCase) SetSubscriptionRepo(repo subscription.SubscriptionRepository) {
|
||||
uc.subscriptionRepo = repo
|
||||
}
|
||||
|
||||
// SetQuotaCacheManager sets the quota cache manager for invalidating cached quotas.
|
||||
func (uc *UpdatePlanUseCase) SetQuotaCacheManager(manager QuotaCacheManager) {
|
||||
uc.quotaCacheManager = manager
|
||||
}
|
||||
|
||||
// SetSubscriptionNotifier sets the notifier for subscription changes.
|
||||
func (uc *UpdatePlanUseCase) SetSubscriptionNotifier(notifier SubscriptionChangeNotifier) {
|
||||
uc.subscriptionNotifier = notifier
|
||||
}
|
||||
|
||||
func NewUpdatePlanUseCase(
|
||||
planRepo subscription.PlanRepository,
|
||||
pricingRepo subscription.PlanPricingRepository,
|
||||
@@ -108,6 +127,11 @@ func (uc *UpdatePlanUseCase) Execute(
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate quota cache for all active subscriptions when limits change
|
||||
if cmd.Limits != nil {
|
||||
uc.invalidateQuotaCacheForPlan(ctx, planID)
|
||||
}
|
||||
|
||||
// Sync pricing options if provided (delete old, create new)
|
||||
if cmd.Pricings != nil {
|
||||
uc.logger.Infow("syncing pricing options", "plan_id", planID, "count", len(*cmd.Pricings))
|
||||
@@ -160,6 +184,9 @@ func (uc *UpdatePlanUseCase) Execute(
|
||||
uc.logger.Infow("pricing options synced successfully",
|
||||
"plan_id", planID,
|
||||
"count", len(*cmd.Pricings))
|
||||
|
||||
// Migrate orphaned subscriptions whose billing cycle is no longer available
|
||||
uc.migrateOrphanedBillingCycles(ctx, planID, *cmd.Pricings)
|
||||
}
|
||||
|
||||
// Reload the plan from database to get the accurate state after update
|
||||
@@ -183,3 +210,163 @@ func (uc *UpdatePlanUseCase) Execute(
|
||||
|
||||
return dto.ToPlanDTOWithPricings(updatedPlan, pricings), nil
|
||||
}
|
||||
|
||||
// invalidateQuotaCacheForPlan invalidates Redis quota cache for all active
|
||||
// subscriptions on the given plan so enforcement uses updated limits.
|
||||
func (uc *UpdatePlanUseCase) invalidateQuotaCacheForPlan(ctx context.Context, planID uint) {
|
||||
if uc.subscriptionRepo == nil || uc.quotaCacheManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{
|
||||
PlanID: &planID,
|
||||
Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing)},
|
||||
Page: 1,
|
||||
PageSize: 10000,
|
||||
})
|
||||
if err != nil {
|
||||
uc.logger.Warnw("failed to list subscriptions for quota cache invalidation",
|
||||
"plan_id", planID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range subs {
|
||||
if err := uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID()); err != nil {
|
||||
uc.logger.Warnw("failed to invalidate quota cache",
|
||||
"subscription_id", sub.ID(), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(subs) > 0 {
|
||||
uc.logger.Infow("quota cache invalidated for plan subscriptions",
|
||||
"plan_id", planID, "count", len(subs))
|
||||
}
|
||||
}
|
||||
|
||||
// migrateOrphanedBillingCycles migrates subscriptions whose billing cycle
|
||||
// is no longer available in the new pricing options.
|
||||
// e.g., plan changed from monthly to lifetime -> existing monthly subscriptions
|
||||
// are migrated to lifetime with recalculated end_date.
|
||||
func (uc *UpdatePlanUseCase) migrateOrphanedBillingCycles(
|
||||
ctx context.Context, planID uint, newPricings []dto.PricingOptionInput,
|
||||
) {
|
||||
if uc.subscriptionRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build set of active billing cycles from new pricings
|
||||
availableCycles := make(map[string]bool)
|
||||
for _, p := range newPricings {
|
||||
if p.IsActive {
|
||||
availableCycles[p.BillingCycle] = true
|
||||
}
|
||||
}
|
||||
if len(availableCycles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Query affected subscriptions
|
||||
subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{
|
||||
PlanID: &planID,
|
||||
Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing), string(vo.StatusSuspended)},
|
||||
Page: 1,
|
||||
PageSize: 10000,
|
||||
})
|
||||
if err != nil {
|
||||
uc.logger.Warnw("failed to list subscriptions for billing cycle migration",
|
||||
"plan_id", planID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Find orphaned subscriptions
|
||||
migratedCount := 0
|
||||
for _, sub := range subs {
|
||||
if sub.BillingCycle() == nil {
|
||||
continue
|
||||
}
|
||||
if availableCycles[sub.BillingCycle().String()] {
|
||||
continue // billing cycle still available
|
||||
}
|
||||
|
||||
oldCycle := sub.BillingCycle().String()
|
||||
targetCycle := findClosestBillingCycle(sub.BillingCycle(), availableCycles)
|
||||
newEndDate := CalculateEndDate(sub.CurrentPeriodStart(), targetCycle)
|
||||
|
||||
if err := sub.ChangeBillingCycle(targetCycle, newEndDate); err != nil {
|
||||
uc.logger.Warnw("failed to change billing cycle",
|
||||
"subscription_id", sub.ID(), "old_cycle", oldCycle, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := uc.subscriptionRepo.Update(ctx, sub); err != nil {
|
||||
uc.logger.Warnw("failed to persist subscription after billing cycle migration",
|
||||
"subscription_id", sub.ID(), "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Invalidate quota cache for migrated subscription
|
||||
if uc.quotaCacheManager != nil {
|
||||
_ = uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID())
|
||||
}
|
||||
|
||||
// Notify nodes of subscription update
|
||||
if uc.subscriptionNotifier != nil {
|
||||
_ = uc.subscriptionNotifier.NotifySubscriptionUpdate(ctx, sub)
|
||||
}
|
||||
|
||||
uc.logger.Infow("subscription billing cycle migrated",
|
||||
"subscription_id", sub.ID(),
|
||||
"subscription_sid", sub.SID(),
|
||||
"old_cycle", oldCycle,
|
||||
"new_cycle", targetCycle.String(),
|
||||
"new_end_date", newEndDate,
|
||||
)
|
||||
migratedCount++
|
||||
}
|
||||
|
||||
if migratedCount > 0 {
|
||||
uc.logger.Infow("billing cycle migration completed",
|
||||
"plan_id", planID, "migrated_count", migratedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// findClosestBillingCycle finds the billing cycle from availableCycles
|
||||
// that is closest in duration to the given oldCycle.
|
||||
// On tie, prefers the longer duration to minimize disruption.
|
||||
func findClosestBillingCycle(oldCycle *vo.BillingCycle, availableCycles map[string]bool) vo.BillingCycle {
|
||||
oldDays := 0
|
||||
if oldCycle != nil {
|
||||
oldDays = oldCycle.Days()
|
||||
if oldCycle.IsLifetime() {
|
||||
oldDays = math.MaxInt32
|
||||
}
|
||||
}
|
||||
|
||||
var bestCycle vo.BillingCycle
|
||||
bestDiff := math.MaxInt32
|
||||
bestDays := 0
|
||||
|
||||
for c := range availableCycles {
|
||||
parsed, err := vo.ParseBillingCycle(c)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
days := parsed.Days()
|
||||
if parsed.IsLifetime() {
|
||||
days = math.MaxInt32
|
||||
}
|
||||
|
||||
diff := oldDays - days
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
// Prefer closer; on tie prefer longer duration
|
||||
if diff < bestDiff || (diff == bestDiff && days > bestDays) {
|
||||
bestDiff = diff
|
||||
bestCycle = parsed
|
||||
bestDays = days
|
||||
}
|
||||
}
|
||||
|
||||
return bestCycle
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ type SubscriptionFilter struct {
|
||||
UserID *uint
|
||||
PlanID *uint
|
||||
Status *string
|
||||
Statuses []string // Multiple status filter (IN clause), takes precedence over Status
|
||||
BillingCycle *string
|
||||
CreatedFrom *time.Time
|
||||
CreatedTo *time.Time
|
||||
|
||||
@@ -570,6 +570,23 @@ func (s *Subscription) ChangePlan(newPlanID uint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeBillingCycle updates the billing cycle and end date.
|
||||
// Used when plan pricing changes make the current billing cycle unavailable.
|
||||
func (s *Subscription) ChangeBillingCycle(newCycle vo.BillingCycle, newEndDate time.Time) error {
|
||||
if !newCycle.IsValid() {
|
||||
return fmt.Errorf("invalid billing cycle: %s", newCycle)
|
||||
}
|
||||
if newEndDate.Before(s.startDate) {
|
||||
return fmt.Errorf("new end date must be after start date")
|
||||
}
|
||||
s.billingCycle = &newCycle
|
||||
s.endDate = newEndDate
|
||||
s.currentPeriodEnd = newEndDate
|
||||
s.updatedAt = biztime.NowUTC()
|
||||
s.version++
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if subscription is expired
|
||||
func (s *Subscription) IsExpired() bool {
|
||||
return biztime.NowUTC().After(s.endDate)
|
||||
|
||||
@@ -44,6 +44,9 @@ func GetTrafficResetMode(plan *Plan) TrafficResetMode {
|
||||
// For calendar_month: uses business timezone month boundaries (backward compatible default).
|
||||
// For billing_cycle: uses the subscription's CurrentPeriodStart/CurrentPeriodEnd.
|
||||
// Falls back to calendar_month if sub is nil.
|
||||
//
|
||||
// In both modes, if the subscription's CurrentPeriodStart is after the resolved period start
|
||||
// (e.g. after a manual usage reset), it is used as a floor to exclude pre-reset traffic.
|
||||
func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod {
|
||||
mode := GetTrafficResetMode(plan)
|
||||
|
||||
@@ -56,8 +59,16 @@ func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod {
|
||||
|
||||
// Fallback: calendar month (default, or billing_cycle with nil sub)
|
||||
bizNow := biztime.ToBizTimezone(biztime.NowUTC())
|
||||
return TrafficPeriod{
|
||||
period := TrafficPeriod{
|
||||
Start: biztime.StartOfMonthUTC(bizNow.Year(), bizNow.Month()),
|
||||
End: biztime.EndOfMonthUTC(bizNow.Year(), bizNow.Month()),
|
||||
}
|
||||
|
||||
// If subscription's period start is after the calendar month start (e.g. after manual
|
||||
// usage reset), use it as the floor so pre-reset traffic is excluded.
|
||||
if sub != nil && sub.CurrentPeriodStart().After(period.Start) {
|
||||
period.Start = sub.CurrentPeriodStart()
|
||||
}
|
||||
|
||||
return period
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package adapters
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/orris-inc/orris/internal/domain/node"
|
||||
nodevo "github.com/orris-inc/orris/internal/domain/node/valueobjects"
|
||||
"github.com/orris-inc/orris/internal/domain/subscription/valueobjects"
|
||||
"github.com/orris-inc/orris/internal/interfaces/adapters/nodeutil"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/persistence/nodeutil"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/persistence/models"
|
||||
"github.com/orris-inc/orris/internal/shared/logger"
|
||||
"github.com/orris-inc/orris/internal/shared/utils/jsonutil"
|
||||
@@ -18,21 +18,19 @@ import (
|
||||
"github.com/orris-inc/orris/internal/shared/utils/setutil"
|
||||
)
|
||||
|
||||
// NodeRepository defines the interface for node persistence operations
|
||||
type NodeRepository interface {
|
||||
GetByToken(ctx context.Context, tokenHash string) (*node.Node, error)
|
||||
}
|
||||
|
||||
type NodeRepositoryAdapter struct {
|
||||
nodeRepo NodeRepository
|
||||
// NodeSubscriptionRepository implements usecases.NodeRepository by querying
|
||||
// subscriptions, plans, resource groups, and nodes to build subscription node lists.
|
||||
type NodeSubscriptionRepository struct {
|
||||
nodeRepo node.NodeRepository
|
||||
forwardRuleRepo forward.Repository
|
||||
db *gorm.DB
|
||||
logger logger.Interface
|
||||
configLoader *nodeutil.ConfigLoader
|
||||
}
|
||||
|
||||
func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeRepositoryAdapter {
|
||||
return &NodeRepositoryAdapter{
|
||||
// NewNodeSubscriptionRepository creates a new NodeSubscriptionRepository.
|
||||
func NewNodeSubscriptionRepository(nodeRepo node.NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeSubscriptionRepository {
|
||||
return &NodeSubscriptionRepository{
|
||||
nodeRepo: nodeRepo,
|
||||
forwardRuleRepo: forwardRuleRepo,
|
||||
db: db,
|
||||
@@ -41,7 +39,7 @@ func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.R
|
||||
}
|
||||
}
|
||||
|
||||
func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) {
|
||||
func (r *NodeSubscriptionRepository) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) {
|
||||
var subscriptionModel models.SubscriptionModel
|
||||
|
||||
// Query subscription by link_token
|
||||
@@ -148,7 +146,7 @@ func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, link
|
||||
}
|
||||
}
|
||||
|
||||
func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) {
|
||||
func (r *NodeSubscriptionRepository) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) {
|
||||
nodeEntity, err := r.nodeRepo.GetByToken(ctx, tokenHash)
|
||||
if err != nil {
|
||||
r.logger.Warnw("failed to get node by token hash",
|
||||
@@ -172,7 +170,7 @@ func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash st
|
||||
// Rules are selected based on resource group membership, regardless of whether their
|
||||
// target nodes are in the same resource groups.
|
||||
// Uses Repository method to ensure proper scope isolation (system rules only).
|
||||
func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node {
|
||||
func (r *NodeSubscriptionRepository) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -238,7 +236,7 @@ func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs
|
||||
// For forward plans, users see their own forward rules as subscription nodes
|
||||
// Forward plans have no "origin" nodes - all nodes are forwarded by nature
|
||||
// Uses Repository method to ensure proper scope isolation (user's own rules only).
|
||||
func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) {
|
||||
func (r *NodeSubscriptionRepository) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) {
|
||||
// Forward plans have no origin nodes, return empty for origin mode
|
||||
if mode == usecases.NodeModeOrigin {
|
||||
r.logger.Debugw("forward plan has no origin nodes", "user_id", userID, "mode", mode)
|
||||
@@ -262,7 +260,7 @@ func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscri
|
||||
|
||||
// getHybridPlanNodes returns nodes for hybrid plan subscriptions
|
||||
// For hybrid plans, users see both resource group nodes AND their own forward rules
|
||||
func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) {
|
||||
func (r *NodeSubscriptionRepository) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) {
|
||||
// Step 1: Get resource group nodes (same as node plan logic)
|
||||
// Query resource group IDs for this plan
|
||||
var groupIDs []uint
|
||||
@@ -349,7 +347,7 @@ func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscrip
|
||||
}
|
||||
|
||||
// buildNodesWithConfigs builds use case nodes from node models with protocol configs loaded
|
||||
func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node {
|
||||
func (r *NodeSubscriptionRepository) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node {
|
||||
configs := r.configLoader.LoadProtocolConfigs(ctx, nodeModels)
|
||||
|
||||
nodes := make([]*usecases.Node, 0, len(nodeModels))
|
||||
@@ -365,7 +363,7 @@ func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeM
|
||||
// getUserForwardNodes retrieves user's forward rules with target nodes as subscription nodes
|
||||
// Only returns forward rules where target_node_id is NOT NULL
|
||||
// Uses Repository method to ensure proper scope isolation (user's own rules only).
|
||||
func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) {
|
||||
func (r *NodeSubscriptionRepository) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) {
|
||||
forwardRules, err := r.forwardRuleRepo.ListUserRulesForDelivery(ctx, userID)
|
||||
if err != nil {
|
||||
r.logger.Errorw("failed to query user forward rules", "user_id", userID, "error", err)
|
||||
@@ -395,7 +393,7 @@ func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID
|
||||
}
|
||||
|
||||
// collectIDsFromRules extracts node IDs and agent IDs from forward rules
|
||||
func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) {
|
||||
func (r *NodeSubscriptionRepository) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) {
|
||||
nodeIDSet := setutil.NewUintSet()
|
||||
agentIDSet := setutil.NewUintSet()
|
||||
|
||||
@@ -412,7 +410,7 @@ func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule
|
||||
}
|
||||
|
||||
// collectAgentIDsFromRules extracts agent IDs from forward rules (skipping external rules).
|
||||
func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint {
|
||||
func (r *NodeSubscriptionRepository) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint {
|
||||
agentIDSet := setutil.NewUintSet()
|
||||
for _, rule := range rules {
|
||||
if rule.AgentID() > 0 {
|
||||
@@ -423,7 +421,7 @@ func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.Forwar
|
||||
}
|
||||
|
||||
// queryActiveNodes queries active nodes by IDs and returns both slice and map
|
||||
func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) {
|
||||
func (r *NodeSubscriptionRepository) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) {
|
||||
nodeMap := make(map[uint]*models.NodeModel)
|
||||
if len(nodeIDs) == 0 {
|
||||
return nil, nodeMap, nil
|
||||
@@ -447,7 +445,7 @@ func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs []
|
||||
}
|
||||
|
||||
// loadForwardAgents loads forward agents by IDs and returns a map
|
||||
func (r *NodeRepositoryAdapter) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel {
|
||||
func (r *NodeSubscriptionRepository) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel {
|
||||
agentMap := make(map[uint]*models.ForwardAgentModel)
|
||||
if len(agentIDs) == 0 {
|
||||
return agentMap
|
||||
@@ -460,7 +460,9 @@ func (r *SubscriptionRepositoryImpl) List(ctx context.Context, filter subscripti
|
||||
if filter.PlanID != nil {
|
||||
query = query.Where("plan_id = ?", *filter.PlanID)
|
||||
}
|
||||
if filter.Status != nil {
|
||||
if len(filter.Statuses) > 0 {
|
||||
query = query.Where("status IN ?", filter.Statuses)
|
||||
} else if filter.Status != nil {
|
||||
query = query.Where("status = ?", *filter.Status)
|
||||
}
|
||||
if filter.BillingCycle != nil {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package adapters
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,13 +13,15 @@ import (
|
||||
"github.com/orris-inc/orris/internal/shared/logger"
|
||||
)
|
||||
|
||||
type SubscriptionTokenValidatorAdapter struct {
|
||||
// SubscriptionTokenValidator validates subscription tokens by querying the database directly.
|
||||
type SubscriptionTokenValidator struct {
|
||||
db *gorm.DB
|
||||
logger logger.Interface
|
||||
}
|
||||
|
||||
func NewSubscriptionTokenValidatorAdapter(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidatorAdapter {
|
||||
return &SubscriptionTokenValidatorAdapter{
|
||||
// NewSubscriptionTokenValidator creates a new SubscriptionTokenValidator.
|
||||
func NewSubscriptionTokenValidator(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidator {
|
||||
return &SubscriptionTokenValidator{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
@@ -54,7 +56,7 @@ func subscriptionStatusError(status string) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkToken string) error {
|
||||
func (v *SubscriptionTokenValidator) Validate(ctx context.Context, linkToken string) error {
|
||||
var subscriptionModel models.SubscriptionModel
|
||||
if err := v.db.WithContext(ctx).
|
||||
Where("link_token = ?", linkToken).
|
||||
@@ -82,7 +84,7 @@ func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkTo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *SubscriptionTokenValidatorAdapter) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) {
|
||||
func (v *SubscriptionTokenValidator) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) {
|
||||
var subscriptionModel models.SubscriptionModel
|
||||
if err := v.db.WithContext(ctx).
|
||||
Where("link_token = ?", linkToken).
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
nodeUsecases "github.com/orris-inc/orris/internal/application/node/usecases"
|
||||
"github.com/orris-inc/orris/internal/domain/subscription"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/cache"
|
||||
nodeHandlers "github.com/orris-inc/orris/internal/interfaces/http/handlers/node"
|
||||
"github.com/orris-inc/orris/internal/shared/biztime"
|
||||
"github.com/orris-inc/orris/internal/shared/logger"
|
||||
)
|
||||
@@ -29,7 +29,7 @@ func NewNodeSubscriptionQuotaCacheAdapter(
|
||||
}
|
||||
|
||||
// GetQuota retrieves subscription quota from cache
|
||||
func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) {
|
||||
func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) {
|
||||
cached, err := a.cache.GetQuota(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -38,7 +38,7 @@ func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &nodeHandlers.CachedQuotaInfo{
|
||||
return &nodeUsecases.CachedQuotaInfo{
|
||||
Limit: cached.Limit,
|
||||
PeriodStart: cached.PeriodStart,
|
||||
PeriodEnd: cached.PeriodEnd,
|
||||
@@ -79,7 +79,7 @@ func NewNodeSubscriptionQuotaLoaderAdapter(
|
||||
// LoadQuotaByID loads subscription quota from database and caches it.
|
||||
// When subscription is not found or inactive, a null marker is cached to prevent
|
||||
// repeated DB lookups (cache penetration protection).
|
||||
func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) {
|
||||
func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) {
|
||||
// Get subscription from database
|
||||
sub, err := a.subscriptionRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
@@ -134,7 +134,7 @@ func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context,
|
||||
)
|
||||
}
|
||||
|
||||
return &nodeHandlers.CachedQuotaInfo{
|
||||
return &nodeUsecases.CachedQuotaInfo{
|
||||
Limit: cachedQuota.Limit,
|
||||
PeriodStart: cachedQuota.PeriodStart,
|
||||
PeriodEnd: cachedQuota.PeriodEnd,
|
||||
@@ -186,10 +186,15 @@ func (a *NodeSubscriptionUsageReaderAdapter) GetCurrentPeriodUsage(
|
||||
|
||||
var total int64
|
||||
|
||||
// Get recent traffic from Redis (yesterday + today, filter by node type)
|
||||
// Get recent traffic from Redis (yesterday + today, filter by node type).
|
||||
// Use max(recentBoundary, periodStart) so traffic before periodStart (e.g. after reset) is excluded.
|
||||
resourceType := subscription.ResourceTypeNode.String()
|
||||
redisFrom := recentBoundary
|
||||
if periodStart.After(redisFrom) {
|
||||
redisFrom = periodStart
|
||||
}
|
||||
recentTraffic, err := a.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
|
||||
ctx, []uint{subscriptionID}, resourceType, recentBoundary, now,
|
||||
ctx, []uint{subscriptionID}, resourceType, redisFrom, now,
|
||||
)
|
||||
if err != nil {
|
||||
a.logger.Warnw("failed to get recent traffic from Redis",
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/orris-inc/orris/internal/infrastructure/email"
|
||||
infraPayment "github.com/orris-inc/orris/internal/infrastructure/payment"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/pubsub"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/repository"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/scheduler"
|
||||
"github.com/orris-inc/orris/internal/infrastructure/services"
|
||||
telegramInfra "github.com/orris-inc/orris/internal/infrastructure/telegram"
|
||||
@@ -100,8 +101,8 @@ type Container struct {
|
||||
adminNotificationServiceDDD *telegramAdminApp.ServiceDDD
|
||||
|
||||
// Cross-cutting adapters and services (created in one section, used in another)
|
||||
nodeRepoAdapter *adapters.NodeRepositoryAdapter
|
||||
tokenValidator *adapters.SubscriptionTokenValidatorAdapter
|
||||
nodeRepoAdapter *repository.NodeSubscriptionRepository
|
||||
tokenValidator *repository.SubscriptionTokenValidator
|
||||
templateLoader *template.SubscriptionTemplateLoader
|
||||
nodeStatusQuerier *adapters.NodeSystemStatusQuerierAdapter
|
||||
forwardAgentReleaseService *services.GitHubReleaseService
|
||||
|
||||
@@ -270,6 +270,9 @@ func (h *Handler) List(c *gin.Context) {
|
||||
sortDesc = &desc
|
||||
}
|
||||
|
||||
// Parse include_counts parameter
|
||||
includeCounts := c.Query("include_counts") == "true"
|
||||
|
||||
query := usecases.ListUserSubscriptionsQuery{
|
||||
UserID: userID,
|
||||
PlanID: planID,
|
||||
@@ -282,6 +285,7 @@ func (h *Handler) List(c *gin.Context) {
|
||||
PageSize: p.PageSize,
|
||||
SortBy: sortBy,
|
||||
SortDesc: sortDesc,
|
||||
IncludeCounts: includeCounts,
|
||||
}
|
||||
|
||||
result, err := h.listUseCase.Execute(c.Request.Context(), query)
|
||||
@@ -291,6 +295,20 @@ func (h *Handler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build response with optional status_counts
|
||||
if result.StatusCounts != nil {
|
||||
response := map[string]any{
|
||||
"items": result.Subscriptions,
|
||||
"total": result.Total,
|
||||
"page": result.Page,
|
||||
"page_size": result.PageSize,
|
||||
"total_pages": utils.TotalPages(result.Total, result.PageSize),
|
||||
"status_counts": result.StatusCounts,
|
||||
}
|
||||
utils.SuccessResponse(c, http.StatusOK, "", response)
|
||||
return
|
||||
}
|
||||
|
||||
utils.ListSuccessResponse(c, result.Subscriptions, result.Total, result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
|
||||
@@ -90,34 +90,13 @@ type NodeSubscriptionUsageReader interface {
|
||||
}
|
||||
|
||||
// NodeSubscriptionQuotaCache defines the interface for subscription quota caching.
|
||||
// This is a local interface to avoid import cycle with cache package.
|
||||
type NodeSubscriptionQuotaCache interface {
|
||||
// GetQuota retrieves subscription quota information from cache.
|
||||
// Returns nil if cache does not exist.
|
||||
GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
|
||||
|
||||
// MarkSuspended marks the subscription as suspended in cache.
|
||||
MarkSuspended(ctx context.Context, subscriptionID uint) error
|
||||
}
|
||||
type NodeSubscriptionQuotaCache = usecases.NodeSubscriptionQuotaCache
|
||||
|
||||
// CachedQuotaInfo represents the cached subscription quota information.
|
||||
// This mirrors cache.CachedQuota to avoid import cycle.
|
||||
type CachedQuotaInfo struct {
|
||||
Limit int64 // Traffic limit in bytes
|
||||
PeriodStart time.Time // Billing period start
|
||||
PeriodEnd time.Time // Billing period end
|
||||
PlanType string // node/forward/hybrid
|
||||
Suspended bool // Whether the subscription is suspended
|
||||
NotFound bool // Null marker: subscription confirmed not found/inactive in DB
|
||||
}
|
||||
type CachedQuotaInfo = usecases.CachedQuotaInfo
|
||||
|
||||
// NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota.
|
||||
// This is used when quota cache miss occurs to load quota from database.
|
||||
type NodeSubscriptionQuotaLoader interface {
|
||||
// LoadQuotaByID loads subscription quota from database and caches it.
|
||||
// Returns the cached quota info, or nil if subscription/plan not found.
|
||||
LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
|
||||
}
|
||||
type NodeSubscriptionQuotaLoader = usecases.NodeSubscriptionQuotaLoader
|
||||
|
||||
// NodeHubHandler handles WebSocket connections for node agents.
|
||||
type NodeHubHandler struct {
|
||||
|
||||
@@ -311,8 +311,8 @@ func (c *Container) initNode() {
|
||||
hdlrs := c.hdlrs
|
||||
|
||||
// Initialize adapters
|
||||
c.nodeRepoAdapter = adapters.NewNodeRepositoryAdapter(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log)
|
||||
c.tokenValidator = adapters.NewSubscriptionTokenValidatorAdapter(db, log)
|
||||
c.nodeRepoAdapter = repository.NewNodeSubscriptionRepository(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log)
|
||||
c.tokenValidator = repository.NewSubscriptionTokenValidator(db, log)
|
||||
c.nodeStatusQuerier = adapters.NewNodeSystemStatusQuerierAdapter(c.redis, log)
|
||||
|
||||
// Initialize GitHub release services for version checking
|
||||
@@ -1359,6 +1359,9 @@ func (c *Container) initCallbacksAndNotifiers() {
|
||||
|
||||
// Set plan change notifier to propagate plan feature changes (e.g. device_limit) to nodes
|
||||
ucs.updatePlanUC.SetPlanChangeNotifier(c.subscriptionSyncService)
|
||||
ucs.updatePlanUC.SetSubscriptionRepo(repos.subscriptionRepo)
|
||||
ucs.updatePlanUC.SetQuotaCacheManager(c.quotaCacheSyncService)
|
||||
ucs.updatePlanUC.SetSubscriptionNotifier(c.subscriptionSyncService)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
|
||||
Reference in New Issue
Block a user