diff --git a/internal/application/forward/services/trafficlimitenforcement.go b/internal/application/forward/services/trafficlimitenforcement.go index 92ad48a..3ab94d1 100644 --- a/internal/application/forward/services/trafficlimitenforcement.go +++ b/internal/application/forward/services/trafficlimitenforcement.go @@ -74,7 +74,7 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex // Find the highest traffic limit across all Forward-type subscriptions // and collect their subscription IDs for traffic query - trafficLimit, hasLimit, forwardSubscriptionIDs, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions) + trafficLimit, hasLimit, forwardSubscriptionIDs, periodStart, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions) if err != nil { s.logger.Errorw("failed to determine traffic limit", "user_id", userID, @@ -99,8 +99,8 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex return nil } - // Get user's total forward traffic by combining Redis (recent 24h) and MySQL (historical) - usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs) + // Get user's total forward traffic within the resolved traffic period + usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs, periodStart) if err != nil { s.logger.Errorw("failed to get total traffic for user", "user_id", userID, @@ -262,12 +262,14 @@ func (s *TrafficLimitEnforcementService) OnTrafficUpdate(ctx context.Context, ru // getHighestTrafficLimitAndIDs returns the highest traffic limit across all Forward-type subscriptions // and collects their subscription IDs for traffic query. -// Returns (limit, hasLimit, subscriptionIDs, error) where hasLimit is false if any subscription has unlimited traffic. +// Returns (limit, hasLimit, subscriptionIDs, periodStart, error) where hasLimit is false if any subscription has unlimited traffic. +// periodStart is the latest traffic period start across all forward subscriptions (respects manual resets). // Only considers subscriptions with PlanType = "forward". -func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, error) { +func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, time.Time, error) { var highestLimit uint64 hasLimit := false var forwardSubscriptionIDs []uint + var latestPeriodStart time.Time // Collect plan IDs for batch query planIDs := make([]uint, 0, len(subscriptions)) @@ -279,7 +281,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex plansList, err := s.planRepo.GetByIDs(ctx, planIDs) if err != nil { s.logger.Errorw("failed to batch fetch plans", "error", err) - return 0, false, nil, err + return 0, false, nil, time.Time{}, err } // Convert to map for quick lookup @@ -312,6 +314,13 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex // Collect Forward-type subscription ID forwardSubscriptionIDs = append(forwardSubscriptionIDs, sub.ID()) + // Resolve traffic period and track the latest period start across all forward subscriptions. + // This ensures manual usage resets are respected: traffic before the reset is excluded. + period := subscription.ResolveTrafficPeriod(plan, sub) + if latestPeriodStart.IsZero() || period.Start.After(latestPeriodStart) { + latestPeriodStart = period.Start + } + // Determine traffic limit: subscription override takes priority over plan var limit uint64 if sub.TrafficLimitOverride() != nil { @@ -323,7 +332,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex "subscription_id", sub.ID(), "plan_id", sub.PlanID(), ) - return 0, false, forwardSubscriptionIDs, nil // Unlimited traffic - don't enforce + return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil } var err error @@ -344,7 +353,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex "subscription_id", sub.ID(), "plan_id", sub.PlanID(), ) - return 0, false, forwardSubscriptionIDs, nil + return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil } // Track the highest limit @@ -354,7 +363,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex } } - return highestLimit, hasLimit, forwardSubscriptionIDs, nil + return highestLimit, hasLimit, forwardSubscriptionIDs, latestPeriodStart, nil } // getForwardTrafficLimit extracts the traffic limit from a plan. @@ -370,10 +379,10 @@ func (s *TrafficLimitEnforcementService) getForwardTrafficLimit(plan *subscripti } // getTotalTrafficForSubscriptions calculates total traffic for given subscription IDs -// by combining data from two sources: -// - Last 24 hours: from Redis HourlyTrafficCache -// - Before 24 hours: from MySQL subscription_usage_stats table -func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint) (uint64, error) { +// within the given traffic period by combining data from two sources: +// - Recent window: from Redis HourlyTrafficCache +// - Historical window: from MySQL subscription_usage_stats table +func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint, periodStart time.Time) (uint64, error) { if len(subscriptionIDs) == 0 { return 0, nil } @@ -386,10 +395,15 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con var total uint64 - // Get recent traffic from Redis (yesterday + today, filter by forward_rule type) + // Determine Redis query range: max(recentBoundary, periodStart) to now + // This ensures traffic before periodStart (e.g. after manual reset) is excluded. resourceType := subscription.ResourceTypeForwardRule.String() + redisFrom := recentBoundary + if periodStart.After(redisFrom) { + redisFrom = periodStart + } recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs( - ctx, subscriptionIDs, resourceType, recentBoundary, now, + ctx, subscriptionIDs, resourceType, redisFrom, now, ) if err != nil { // Log warning but don't fail - Redis unavailability shouldn't block limit checks @@ -403,16 +417,21 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con for _, traffic := range recentTraffic { total += traffic.Total } - s.logger.Debugw("got recent 24h traffic from Redis", + s.logger.Debugw("got recent traffic from Redis", "subscription_ids_count", len(subscriptionIDs), "recent_total", total, ) } - // Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday) - // Use daily granularity for historical aggregation, filter by forward_rule type + // Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary) + historicalEnd := recentBoundary.Add(-time.Second) + if historicalEnd.Before(periodStart) { + // Period started after recentBoundary, no historical data needed + return total, nil + } + historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs( - ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second), + ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd, ) if err != nil { s.logger.Warnw("failed to get historical traffic from stats, using Redis data only", diff --git a/internal/application/node/services/trafficlimitenforcement.go b/internal/application/node/services/trafficlimitenforcement.go index 5eba65b..0237f05 100644 --- a/internal/application/node/services/trafficlimitenforcement.go +++ b/internal/application/node/services/trafficlimitenforcement.go @@ -146,8 +146,11 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con return nil } - // Get total traffic for this subscription - usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID) + // Resolve traffic period so we only count traffic within the current period + period := subscription.ResolveTrafficPeriod(plan, sub) + + // Get total traffic for this subscription within the resolved period + usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID, period.Start) if err != nil { s.logger.Errorw("failed to get total traffic", "subscription_id", subscriptionID, @@ -234,10 +237,10 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con } // getTotalTrafficForSubscription calculates total traffic for a subscription -// by combining data from two sources: -// - Last 24 hours: from Redis HourlyTrafficCache -// - Before 24 hours: from MySQL subscription_usage_stats table -func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint) (uint64, error) { +// within the given traffic period by combining data from two sources: +// - Recent window: from Redis HourlyTrafficCache +// - Historical window: from MySQL subscription_usage_stats table +func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint, periodStart time.Time) (uint64, error) { now := biztime.NowUTC() // Use start of yesterday's business day as batch/speed boundary (Lambda architecture) @@ -246,10 +249,15 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx var total uint64 - // Get recent traffic from Redis (yesterday + today, filter by node type) + // Determine Redis query range: max(recentBoundary, periodStart) to now + // This ensures traffic before periodStart (e.g. after manual reset) is excluded. resourceType := subscription.ResourceTypeNode.String() + redisFrom := recentBoundary + if periodStart.After(redisFrom) { + redisFrom = periodStart + } recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs( - ctx, []uint{subscriptionID}, resourceType, recentBoundary, now, + ctx, []uint{subscriptionID}, resourceType, redisFrom, now, ) if err != nil { // Log warning but don't fail - Redis unavailability shouldn't block limit checks @@ -263,16 +271,21 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx for _, traffic := range recentTraffic { total += traffic.Total } - s.logger.Debugw("got recent 24h traffic from Redis", + s.logger.Debugw("got recent traffic from Redis", "subscription_id", subscriptionID, "recent_total", total, ) } - // Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday) - // Use daily granularity for historical aggregation, filter by node type + // Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary) + historicalEnd := recentBoundary.Add(-time.Second) + if historicalEnd.Before(periodStart) { + // Period started after recentBoundary, no historical data needed + return total, nil + } + historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs( - ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second), + ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd, ) if err != nil { s.logger.Warnw("failed to get historical traffic from stats, using Redis data only", diff --git a/internal/application/node/usecases/quotacache.go b/internal/application/node/usecases/quotacache.go new file mode 100644 index 0000000..cea7541 --- /dev/null +++ b/internal/application/node/usecases/quotacache.go @@ -0,0 +1,35 @@ +package usecases + +import ( + "context" + "time" +) + +// CachedQuotaInfo represents the cached subscription quota information. +// This mirrors cache.CachedQuota to avoid import cycle. +type CachedQuotaInfo struct { + Limit int64 // Traffic limit in bytes + PeriodStart time.Time // Billing period start + PeriodEnd time.Time // Billing period end + PlanType string // node/forward/hybrid + Suspended bool // Whether the subscription is suspended + NotFound bool // Null marker: subscription confirmed not found/inactive in DB +} + +// NodeSubscriptionQuotaCache defines the interface for subscription quota caching. +type NodeSubscriptionQuotaCache interface { + // GetQuota retrieves subscription quota information from cache. + // Returns nil if cache does not exist. + GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error) + + // MarkSuspended marks the subscription as suspended in cache. + MarkSuspended(ctx context.Context, subscriptionID uint) error +} + +// NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota. +// This is used when quota cache miss occurs to load quota from database. +type NodeSubscriptionQuotaLoader interface { + // LoadQuotaByID loads subscription quota from database and caches it. + // Returns the cached quota info, or nil if subscription/plan not found. + LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error) +} diff --git a/internal/application/subscription/dto/dto.go b/internal/application/subscription/dto/dto.go index dfcdbc9..e05f32d 100644 --- a/internal/application/subscription/dto/dto.go +++ b/internal/application/subscription/dto/dto.go @@ -89,6 +89,16 @@ type SubscriptionTokenDTO struct { CreatedAt time.Time `json:"created_at"` } +// SubscriptionStatusCounts represents subscription status aggregation counts. +// Used by admin subscription list to return status counts alongside the list response. +type SubscriptionStatusCounts struct { + Active int64 `json:"active"` + Expired int64 `json:"expired"` + Suspended int64 `json:"suspended"` + PendingPayment int64 `json:"pending_payment"` + ExpiringIn7Days int64 `json:"expiring_in_7_days"` +} + var ( // SubscriptionMapper is a generic mapper for basic conversions. // WARNING: Does not populate OnlineDeviceCount, DeviceLimit, DataUsedBytes, or DataLimitBytes. diff --git a/internal/application/subscription/usecases/createsubscription.go b/internal/application/subscription/usecases/createsubscription.go index d5525a3..8ebf25f 100644 --- a/internal/application/subscription/usecases/createsubscription.go +++ b/internal/application/subscription/usecases/createsubscription.go @@ -241,8 +241,13 @@ func (uc *CreateSubscriptionUseCase) Execute(ctx context.Context, cmd CreateSubs } func (uc *CreateSubscriptionUseCase) calculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time { - // Use fixed days to ensure consistent subscription periods - // This prevents "drifting" when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28) + return CalculateEndDate(startDate, billingCycle) +} + +// CalculateEndDate computes the subscription end date from a start date and billing cycle. +// Uses fixed days to ensure consistent subscription periods and prevent "drifting" +// when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28). +func CalculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time { switch billingCycle { case vo.BillingCycleWeekly: return startDate.Add(7 * 24 * time.Hour) // 7 days diff --git a/internal/application/subscription/usecases/listusersubscriptions.go b/internal/application/subscription/usecases/listusersubscriptions.go index b9b4d77..8c3b846 100644 --- a/internal/application/subscription/usecases/listusersubscriptions.go +++ b/internal/application/subscription/usecases/listusersubscriptions.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "golang.org/x/sync/errgroup" + "github.com/orris-inc/orris/internal/application/subscription/dto" "github.com/orris-inc/orris/internal/domain/subscription" "github.com/orris-inc/orris/internal/domain/user" @@ -24,13 +26,15 @@ type ListUserSubscriptionsQuery struct { PageSize int SortBy string SortDesc *bool // nil means default (true = DESC) + IncludeCounts bool // When true, also return subscription status counts } type ListUserSubscriptionsResult struct { - Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"` - Total int64 `json:"total"` - Page int `json:"page"` - PageSize int `json:"page_size"` + Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + StatusCounts *dto.SubscriptionStatusCounts `json:"status_counts,omitempty"` // Present when IncludeCounts is true } type ListUserSubscriptionsUseCase struct { @@ -189,6 +193,60 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU dtos = append(dtos, result) } + // Query status counts if requested + var statusCounts *dto.SubscriptionStatusCounts + if query.IncludeCounts { + statusCounts = &dto.SubscriptionStatusCounts{} + g, gctx := errgroup.WithContext(ctx) + + g.Go(func() error { + count, err := uc.subscriptionRepo.CountByStatus(gctx, "active") + if err != nil { + return fmt.Errorf("count active: %w", err) + } + statusCounts.Active = count + return nil + }) + g.Go(func() error { + count, err := uc.subscriptionRepo.CountByStatus(gctx, "expired") + if err != nil { + return fmt.Errorf("count expired: %w", err) + } + statusCounts.Expired = count + return nil + }) + g.Go(func() error { + count, err := uc.subscriptionRepo.CountByStatus(gctx, "suspended") + if err != nil { + return fmt.Errorf("count suspended: %w", err) + } + statusCounts.Suspended = count + return nil + }) + g.Go(func() error { + count, err := uc.subscriptionRepo.CountByStatus(gctx, "pending_payment") + if err != nil { + return fmt.Errorf("count pending_payment: %w", err) + } + statusCounts.PendingPayment = count + return nil + }) + g.Go(func() error { + subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(gctx, 7) + if err != nil { + return fmt.Errorf("find expiring subscriptions: %w", err) + } + statusCounts.ExpiringIn7Days = int64(len(subs)) + return nil + }) + + if err := g.Wait(); err != nil { + uc.logger.Warnw("failed to get subscription status counts", "error", err) + // Non-fatal: return list without counts + statusCounts = nil + } + } + uc.logger.Debugw("subscriptions listed successfully", "user_id", query.UserID, "total", total, @@ -201,5 +259,6 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU Total: total, Page: query.Page, PageSize: query.PageSize, + StatusCounts: statusCounts, }, nil } diff --git a/internal/application/subscription/usecases/updateplan.go b/internal/application/subscription/usecases/updateplan.go index 3ac45f9..33274cd 100644 --- a/internal/application/subscription/usecases/updateplan.go +++ b/internal/application/subscription/usecases/updateplan.go @@ -3,6 +3,7 @@ package usecases import ( "context" "fmt" + "math" "github.com/orris-inc/orris/internal/application/subscription/dto" "github.com/orris-inc/orris/internal/domain/subscription" @@ -26,10 +27,13 @@ type PlanChangeNotifier interface { } type UpdatePlanUseCase struct { - planRepo subscription.PlanRepository - pricingRepo subscription.PlanPricingRepository - planChangeNotifier PlanChangeNotifier - logger logger.Interface + planRepo subscription.PlanRepository + pricingRepo subscription.PlanPricingRepository + subscriptionRepo subscription.SubscriptionRepository + planChangeNotifier PlanChangeNotifier + quotaCacheManager QuotaCacheManager + subscriptionNotifier SubscriptionChangeNotifier + logger logger.Interface } // SetPlanChangeNotifier sets the notifier for plan feature changes. @@ -37,6 +41,21 @@ func (uc *UpdatePlanUseCase) SetPlanChangeNotifier(notifier PlanChangeNotifier) uc.planChangeNotifier = notifier } +// SetSubscriptionRepo sets the subscription repository for cascading plan changes. +func (uc *UpdatePlanUseCase) SetSubscriptionRepo(repo subscription.SubscriptionRepository) { + uc.subscriptionRepo = repo +} + +// SetQuotaCacheManager sets the quota cache manager for invalidating cached quotas. +func (uc *UpdatePlanUseCase) SetQuotaCacheManager(manager QuotaCacheManager) { + uc.quotaCacheManager = manager +} + +// SetSubscriptionNotifier sets the notifier for subscription changes. +func (uc *UpdatePlanUseCase) SetSubscriptionNotifier(notifier SubscriptionChangeNotifier) { + uc.subscriptionNotifier = notifier +} + func NewUpdatePlanUseCase( planRepo subscription.PlanRepository, pricingRepo subscription.PlanPricingRepository, @@ -108,6 +127,11 @@ func (uc *UpdatePlanUseCase) Execute( } } + // Invalidate quota cache for all active subscriptions when limits change + if cmd.Limits != nil { + uc.invalidateQuotaCacheForPlan(ctx, planID) + } + // Sync pricing options if provided (delete old, create new) if cmd.Pricings != nil { uc.logger.Infow("syncing pricing options", "plan_id", planID, "count", len(*cmd.Pricings)) @@ -160,6 +184,9 @@ func (uc *UpdatePlanUseCase) Execute( uc.logger.Infow("pricing options synced successfully", "plan_id", planID, "count", len(*cmd.Pricings)) + + // Migrate orphaned subscriptions whose billing cycle is no longer available + uc.migrateOrphanedBillingCycles(ctx, planID, *cmd.Pricings) } // Reload the plan from database to get the accurate state after update @@ -183,3 +210,163 @@ func (uc *UpdatePlanUseCase) Execute( return dto.ToPlanDTOWithPricings(updatedPlan, pricings), nil } + +// invalidateQuotaCacheForPlan invalidates Redis quota cache for all active +// subscriptions on the given plan so enforcement uses updated limits. +func (uc *UpdatePlanUseCase) invalidateQuotaCacheForPlan(ctx context.Context, planID uint) { + if uc.subscriptionRepo == nil || uc.quotaCacheManager == nil { + return + } + + subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{ + PlanID: &planID, + Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing)}, + Page: 1, + PageSize: 10000, + }) + if err != nil { + uc.logger.Warnw("failed to list subscriptions for quota cache invalidation", + "plan_id", planID, "error", err) + return + } + + for _, sub := range subs { + if err := uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID()); err != nil { + uc.logger.Warnw("failed to invalidate quota cache", + "subscription_id", sub.ID(), "error", err) + } + } + + if len(subs) > 0 { + uc.logger.Infow("quota cache invalidated for plan subscriptions", + "plan_id", planID, "count", len(subs)) + } +} + +// migrateOrphanedBillingCycles migrates subscriptions whose billing cycle +// is no longer available in the new pricing options. +// e.g., plan changed from monthly to lifetime -> existing monthly subscriptions +// are migrated to lifetime with recalculated end_date. +func (uc *UpdatePlanUseCase) migrateOrphanedBillingCycles( + ctx context.Context, planID uint, newPricings []dto.PricingOptionInput, +) { + if uc.subscriptionRepo == nil { + return + } + + // Build set of active billing cycles from new pricings + availableCycles := make(map[string]bool) + for _, p := range newPricings { + if p.IsActive { + availableCycles[p.BillingCycle] = true + } + } + if len(availableCycles) == 0 { + return + } + + // Query affected subscriptions + subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{ + PlanID: &planID, + Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing), string(vo.StatusSuspended)}, + Page: 1, + PageSize: 10000, + }) + if err != nil { + uc.logger.Warnw("failed to list subscriptions for billing cycle migration", + "plan_id", planID, "error", err) + return + } + + // Find orphaned subscriptions + migratedCount := 0 + for _, sub := range subs { + if sub.BillingCycle() == nil { + continue + } + if availableCycles[sub.BillingCycle().String()] { + continue // billing cycle still available + } + + oldCycle := sub.BillingCycle().String() + targetCycle := findClosestBillingCycle(sub.BillingCycle(), availableCycles) + newEndDate := CalculateEndDate(sub.CurrentPeriodStart(), targetCycle) + + if err := sub.ChangeBillingCycle(targetCycle, newEndDate); err != nil { + uc.logger.Warnw("failed to change billing cycle", + "subscription_id", sub.ID(), "old_cycle", oldCycle, "error", err) + continue + } + + if err := uc.subscriptionRepo.Update(ctx, sub); err != nil { + uc.logger.Warnw("failed to persist subscription after billing cycle migration", + "subscription_id", sub.ID(), "error", err) + continue + } + + // Invalidate quota cache for migrated subscription + if uc.quotaCacheManager != nil { + _ = uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID()) + } + + // Notify nodes of subscription update + if uc.subscriptionNotifier != nil { + _ = uc.subscriptionNotifier.NotifySubscriptionUpdate(ctx, sub) + } + + uc.logger.Infow("subscription billing cycle migrated", + "subscription_id", sub.ID(), + "subscription_sid", sub.SID(), + "old_cycle", oldCycle, + "new_cycle", targetCycle.String(), + "new_end_date", newEndDate, + ) + migratedCount++ + } + + if migratedCount > 0 { + uc.logger.Infow("billing cycle migration completed", + "plan_id", planID, "migrated_count", migratedCount) + } +} + +// findClosestBillingCycle finds the billing cycle from availableCycles +// that is closest in duration to the given oldCycle. +// On tie, prefers the longer duration to minimize disruption. +func findClosestBillingCycle(oldCycle *vo.BillingCycle, availableCycles map[string]bool) vo.BillingCycle { + oldDays := 0 + if oldCycle != nil { + oldDays = oldCycle.Days() + if oldCycle.IsLifetime() { + oldDays = math.MaxInt32 + } + } + + var bestCycle vo.BillingCycle + bestDiff := math.MaxInt32 + bestDays := 0 + + for c := range availableCycles { + parsed, err := vo.ParseBillingCycle(c) + if err != nil { + continue + } + days := parsed.Days() + if parsed.IsLifetime() { + days = math.MaxInt32 + } + + diff := oldDays - days + if diff < 0 { + diff = -diff + } + // Prefer closer; on tie prefer longer duration + if diff < bestDiff || (diff == bestDiff && days > bestDays) { + bestDiff = diff + bestCycle = parsed + bestDays = days + } + } + + return bestCycle +} diff --git a/internal/domain/subscription/repository.go b/internal/domain/subscription/repository.go index 971762a..93b4556 100644 --- a/internal/domain/subscription/repository.go +++ b/internal/domain/subscription/repository.go @@ -34,6 +34,7 @@ type SubscriptionFilter struct { UserID *uint PlanID *uint Status *string + Statuses []string // Multiple status filter (IN clause), takes precedence over Status BillingCycle *string CreatedFrom *time.Time CreatedTo *time.Time diff --git a/internal/domain/subscription/subscription.go b/internal/domain/subscription/subscription.go index 6b4b20e..d0a2ab0 100644 --- a/internal/domain/subscription/subscription.go +++ b/internal/domain/subscription/subscription.go @@ -570,6 +570,23 @@ func (s *Subscription) ChangePlan(newPlanID uint) error { return nil } +// ChangeBillingCycle updates the billing cycle and end date. +// Used when plan pricing changes make the current billing cycle unavailable. +func (s *Subscription) ChangeBillingCycle(newCycle vo.BillingCycle, newEndDate time.Time) error { + if !newCycle.IsValid() { + return fmt.Errorf("invalid billing cycle: %s", newCycle) + } + if newEndDate.Before(s.startDate) { + return fmt.Errorf("new end date must be after start date") + } + s.billingCycle = &newCycle + s.endDate = newEndDate + s.currentPeriodEnd = newEndDate + s.updatedAt = biztime.NowUTC() + s.version++ + return nil +} + // IsExpired checks if subscription is expired func (s *Subscription) IsExpired() bool { return biztime.NowUTC().After(s.endDate) diff --git a/internal/domain/subscription/trafficperiod.go b/internal/domain/subscription/trafficperiod.go index 927cd0c..5a95c6f 100644 --- a/internal/domain/subscription/trafficperiod.go +++ b/internal/domain/subscription/trafficperiod.go @@ -44,6 +44,9 @@ func GetTrafficResetMode(plan *Plan) TrafficResetMode { // For calendar_month: uses business timezone month boundaries (backward compatible default). // For billing_cycle: uses the subscription's CurrentPeriodStart/CurrentPeriodEnd. // Falls back to calendar_month if sub is nil. +// +// In both modes, if the subscription's CurrentPeriodStart is after the resolved period start +// (e.g. after a manual usage reset), it is used as a floor to exclude pre-reset traffic. func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod { mode := GetTrafficResetMode(plan) @@ -56,8 +59,16 @@ func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod { // Fallback: calendar month (default, or billing_cycle with nil sub) bizNow := biztime.ToBizTimezone(biztime.NowUTC()) - return TrafficPeriod{ + period := TrafficPeriod{ Start: biztime.StartOfMonthUTC(bizNow.Year(), bizNow.Month()), End: biztime.EndOfMonthUTC(bizNow.Year(), bizNow.Month()), } + + // If subscription's period start is after the calendar month start (e.g. after manual + // usage reset), use it as the floor so pre-reset traffic is excluded. + if sub != nil && sub.CurrentPeriodStart().After(period.Start) { + period.Start = sub.CurrentPeriodStart() + } + + return period } diff --git a/internal/interfaces/adapters/nodeutil/builder.go b/internal/infrastructure/persistence/nodeutil/builder.go similarity index 100% rename from internal/interfaces/adapters/nodeutil/builder.go rename to internal/infrastructure/persistence/nodeutil/builder.go diff --git a/internal/interfaces/adapters/nodeutil/builder_test.go b/internal/infrastructure/persistence/nodeutil/builder_test.go similarity index 100% rename from internal/interfaces/adapters/nodeutil/builder_test.go rename to internal/infrastructure/persistence/nodeutil/builder_test.go diff --git a/internal/interfaces/adapters/nodeutil/configloader.go b/internal/infrastructure/persistence/nodeutil/configloader.go similarity index 100% rename from internal/interfaces/adapters/nodeutil/configloader.go rename to internal/infrastructure/persistence/nodeutil/configloader.go diff --git a/internal/interfaces/adapters/nodeutil/forwardbuilder.go b/internal/infrastructure/persistence/nodeutil/forwardbuilder.go similarity index 100% rename from internal/interfaces/adapters/nodeutil/forwardbuilder.go rename to internal/infrastructure/persistence/nodeutil/forwardbuilder.go diff --git a/internal/interfaces/adapters/nodeutil/forwardbuilder_test.go b/internal/infrastructure/persistence/nodeutil/forwardbuilder_test.go similarity index 100% rename from internal/interfaces/adapters/nodeutil/forwardbuilder_test.go rename to internal/infrastructure/persistence/nodeutil/forwardbuilder_test.go diff --git a/internal/interfaces/adapters/noderepositoryadapter.go b/internal/infrastructure/repository/nodesubscriptionrepository.go similarity index 87% rename from internal/interfaces/adapters/noderepositoryadapter.go rename to internal/infrastructure/repository/nodesubscriptionrepository.go index 54a57e7..1eeaf11 100644 --- a/internal/interfaces/adapters/noderepositoryadapter.go +++ b/internal/infrastructure/repository/nodesubscriptionrepository.go @@ -1,4 +1,4 @@ -package adapters +package repository import ( "context" @@ -10,7 +10,7 @@ import ( "github.com/orris-inc/orris/internal/domain/node" nodevo "github.com/orris-inc/orris/internal/domain/node/valueobjects" "github.com/orris-inc/orris/internal/domain/subscription/valueobjects" - "github.com/orris-inc/orris/internal/interfaces/adapters/nodeutil" + "github.com/orris-inc/orris/internal/infrastructure/persistence/nodeutil" "github.com/orris-inc/orris/internal/infrastructure/persistence/models" "github.com/orris-inc/orris/internal/shared/logger" "github.com/orris-inc/orris/internal/shared/utils/jsonutil" @@ -18,21 +18,19 @@ import ( "github.com/orris-inc/orris/internal/shared/utils/setutil" ) -// NodeRepository defines the interface for node persistence operations -type NodeRepository interface { - GetByToken(ctx context.Context, tokenHash string) (*node.Node, error) -} - -type NodeRepositoryAdapter struct { - nodeRepo NodeRepository +// NodeSubscriptionRepository implements usecases.NodeRepository by querying +// subscriptions, plans, resource groups, and nodes to build subscription node lists. +type NodeSubscriptionRepository struct { + nodeRepo node.NodeRepository forwardRuleRepo forward.Repository db *gorm.DB logger logger.Interface configLoader *nodeutil.ConfigLoader } -func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeRepositoryAdapter { - return &NodeRepositoryAdapter{ +// NewNodeSubscriptionRepository creates a new NodeSubscriptionRepository. +func NewNodeSubscriptionRepository(nodeRepo node.NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeSubscriptionRepository { + return &NodeSubscriptionRepository{ nodeRepo: nodeRepo, forwardRuleRepo: forwardRuleRepo, db: db, @@ -41,7 +39,7 @@ func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.R } } -func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) { +func (r *NodeSubscriptionRepository) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) { var subscriptionModel models.SubscriptionModel // Query subscription by link_token @@ -148,7 +146,7 @@ func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, link } } -func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) { +func (r *NodeSubscriptionRepository) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) { nodeEntity, err := r.nodeRepo.GetByToken(ctx, tokenHash) if err != nil { r.logger.Warnw("failed to get node by token hash", @@ -172,7 +170,7 @@ func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash st // Rules are selected based on resource group membership, regardless of whether their // target nodes are in the same resource groups. // Uses Repository method to ensure proper scope isolation (system rules only). -func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node { +func (r *NodeSubscriptionRepository) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node { if len(groupIDs) == 0 { return nil } @@ -238,7 +236,7 @@ func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs // For forward plans, users see their own forward rules as subscription nodes // Forward plans have no "origin" nodes - all nodes are forwarded by nature // Uses Repository method to ensure proper scope isolation (user's own rules only). -func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) { +func (r *NodeSubscriptionRepository) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) { // Forward plans have no origin nodes, return empty for origin mode if mode == usecases.NodeModeOrigin { r.logger.Debugw("forward plan has no origin nodes", "user_id", userID, "mode", mode) @@ -262,7 +260,7 @@ func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscri // getHybridPlanNodes returns nodes for hybrid plan subscriptions // For hybrid plans, users see both resource group nodes AND their own forward rules -func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) { +func (r *NodeSubscriptionRepository) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) { // Step 1: Get resource group nodes (same as node plan logic) // Query resource group IDs for this plan var groupIDs []uint @@ -349,7 +347,7 @@ func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscrip } // buildNodesWithConfigs builds use case nodes from node models with protocol configs loaded -func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node { +func (r *NodeSubscriptionRepository) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node { configs := r.configLoader.LoadProtocolConfigs(ctx, nodeModels) nodes := make([]*usecases.Node, 0, len(nodeModels)) @@ -365,7 +363,7 @@ func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeM // getUserForwardNodes retrieves user's forward rules with target nodes as subscription nodes // Only returns forward rules where target_node_id is NOT NULL // Uses Repository method to ensure proper scope isolation (user's own rules only). -func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) { +func (r *NodeSubscriptionRepository) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) { forwardRules, err := r.forwardRuleRepo.ListUserRulesForDelivery(ctx, userID) if err != nil { r.logger.Errorw("failed to query user forward rules", "user_id", userID, "error", err) @@ -395,7 +393,7 @@ func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID } // collectIDsFromRules extracts node IDs and agent IDs from forward rules -func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) { +func (r *NodeSubscriptionRepository) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) { nodeIDSet := setutil.NewUintSet() agentIDSet := setutil.NewUintSet() @@ -412,7 +410,7 @@ func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule } // collectAgentIDsFromRules extracts agent IDs from forward rules (skipping external rules). -func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint { +func (r *NodeSubscriptionRepository) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint { agentIDSet := setutil.NewUintSet() for _, rule := range rules { if rule.AgentID() > 0 { @@ -423,7 +421,7 @@ func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.Forwar } // queryActiveNodes queries active nodes by IDs and returns both slice and map -func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) { +func (r *NodeSubscriptionRepository) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) { nodeMap := make(map[uint]*models.NodeModel) if len(nodeIDs) == 0 { return nil, nodeMap, nil @@ -447,7 +445,7 @@ func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs [] } // loadForwardAgents loads forward agents by IDs and returns a map -func (r *NodeRepositoryAdapter) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel { +func (r *NodeSubscriptionRepository) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel { agentMap := make(map[uint]*models.ForwardAgentModel) if len(agentIDs) == 0 { return agentMap diff --git a/internal/infrastructure/repository/subscriptionrepository.go b/internal/infrastructure/repository/subscriptionrepository.go index 7c0cdea..a39a494 100644 --- a/internal/infrastructure/repository/subscriptionrepository.go +++ b/internal/infrastructure/repository/subscriptionrepository.go @@ -460,7 +460,9 @@ func (r *SubscriptionRepositoryImpl) List(ctx context.Context, filter subscripti if filter.PlanID != nil { query = query.Where("plan_id = ?", *filter.PlanID) } - if filter.Status != nil { + if len(filter.Statuses) > 0 { + query = query.Where("status IN ?", filter.Statuses) + } else if filter.Status != nil { query = query.Where("status = ?", *filter.Status) } if filter.BillingCycle != nil { diff --git a/internal/interfaces/adapters/subscriptiontokenvalidator.go b/internal/infrastructure/repository/subscriptiontokenvalidator.go similarity index 87% rename from internal/interfaces/adapters/subscriptiontokenvalidator.go rename to internal/infrastructure/repository/subscriptiontokenvalidator.go index 9e32121..2c893ed 100644 --- a/internal/interfaces/adapters/subscriptiontokenvalidator.go +++ b/internal/infrastructure/repository/subscriptiontokenvalidator.go @@ -1,4 +1,4 @@ -package adapters +package repository import ( "context" @@ -13,13 +13,15 @@ import ( "github.com/orris-inc/orris/internal/shared/logger" ) -type SubscriptionTokenValidatorAdapter struct { +// SubscriptionTokenValidator validates subscription tokens by querying the database directly. +type SubscriptionTokenValidator struct { db *gorm.DB logger logger.Interface } -func NewSubscriptionTokenValidatorAdapter(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidatorAdapter { - return &SubscriptionTokenValidatorAdapter{ +// NewSubscriptionTokenValidator creates a new SubscriptionTokenValidator. +func NewSubscriptionTokenValidator(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidator { + return &SubscriptionTokenValidator{ db: db, logger: logger, } @@ -54,7 +56,7 @@ func subscriptionStatusError(status string) error { } } -func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkToken string) error { +func (v *SubscriptionTokenValidator) Validate(ctx context.Context, linkToken string) error { var subscriptionModel models.SubscriptionModel if err := v.db.WithContext(ctx). Where("link_token = ?", linkToken). @@ -82,7 +84,7 @@ func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkTo return nil } -func (v *SubscriptionTokenValidatorAdapter) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) { +func (v *SubscriptionTokenValidator) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) { var subscriptionModel models.SubscriptionModel if err := v.db.WithContext(ctx). Where("link_token = ?", linkToken). diff --git a/internal/interfaces/adapters/nodequotaadapters.go b/internal/interfaces/adapters/nodequotaadapters.go index 8c4865c..f15658c 100644 --- a/internal/interfaces/adapters/nodequotaadapters.go +++ b/internal/interfaces/adapters/nodequotaadapters.go @@ -4,9 +4,9 @@ import ( "context" "time" + nodeUsecases "github.com/orris-inc/orris/internal/application/node/usecases" "github.com/orris-inc/orris/internal/domain/subscription" "github.com/orris-inc/orris/internal/infrastructure/cache" - nodeHandlers "github.com/orris-inc/orris/internal/interfaces/http/handlers/node" "github.com/orris-inc/orris/internal/shared/biztime" "github.com/orris-inc/orris/internal/shared/logger" ) @@ -29,7 +29,7 @@ func NewNodeSubscriptionQuotaCacheAdapter( } // GetQuota retrieves subscription quota from cache -func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) { +func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) { cached, err := a.cache.GetQuota(ctx, subscriptionID) if err != nil { return nil, err @@ -38,7 +38,7 @@ func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscr return nil, nil } - return &nodeHandlers.CachedQuotaInfo{ + return &nodeUsecases.CachedQuotaInfo{ Limit: cached.Limit, PeriodStart: cached.PeriodStart, PeriodEnd: cached.PeriodEnd, @@ -79,7 +79,7 @@ func NewNodeSubscriptionQuotaLoaderAdapter( // LoadQuotaByID loads subscription quota from database and caches it. // When subscription is not found or inactive, a null marker is cached to prevent // repeated DB lookups (cache penetration protection). -func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) { +func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) { // Get subscription from database sub, err := a.subscriptionRepo.GetByID(ctx, subscriptionID) if err != nil { @@ -134,7 +134,7 @@ func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, ) } - return &nodeHandlers.CachedQuotaInfo{ + return &nodeUsecases.CachedQuotaInfo{ Limit: cachedQuota.Limit, PeriodStart: cachedQuota.PeriodStart, PeriodEnd: cachedQuota.PeriodEnd, @@ -186,10 +186,15 @@ func (a *NodeSubscriptionUsageReaderAdapter) GetCurrentPeriodUsage( var total int64 - // Get recent traffic from Redis (yesterday + today, filter by node type) + // Get recent traffic from Redis (yesterday + today, filter by node type). + // Use max(recentBoundary, periodStart) so traffic before periodStart (e.g. after reset) is excluded. resourceType := subscription.ResourceTypeNode.String() + redisFrom := recentBoundary + if periodStart.After(redisFrom) { + redisFrom = periodStart + } recentTraffic, err := a.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs( - ctx, []uint{subscriptionID}, resourceType, recentBoundary, now, + ctx, []uint{subscriptionID}, resourceType, redisFrom, now, ) if err != nil { a.logger.Warnw("failed to get recent traffic from Redis", diff --git a/internal/interfaces/http/container.go b/internal/interfaces/http/container.go index 2a5a781..4220c83 100644 --- a/internal/interfaces/http/container.go +++ b/internal/interfaces/http/container.go @@ -22,6 +22,7 @@ import ( "github.com/orris-inc/orris/internal/infrastructure/email" infraPayment "github.com/orris-inc/orris/internal/infrastructure/payment" "github.com/orris-inc/orris/internal/infrastructure/pubsub" + "github.com/orris-inc/orris/internal/infrastructure/repository" "github.com/orris-inc/orris/internal/infrastructure/scheduler" "github.com/orris-inc/orris/internal/infrastructure/services" telegramInfra "github.com/orris-inc/orris/internal/infrastructure/telegram" @@ -100,8 +101,8 @@ type Container struct { adminNotificationServiceDDD *telegramAdminApp.ServiceDDD // Cross-cutting adapters and services (created in one section, used in another) - nodeRepoAdapter *adapters.NodeRepositoryAdapter - tokenValidator *adapters.SubscriptionTokenValidatorAdapter + nodeRepoAdapter *repository.NodeSubscriptionRepository + tokenValidator *repository.SubscriptionTokenValidator templateLoader *template.SubscriptionTemplateLoader nodeStatusQuerier *adapters.NodeSystemStatusQuerierAdapter forwardAgentReleaseService *services.GitHubReleaseService diff --git a/internal/interfaces/http/handlers/admin/subscription/handler.go b/internal/interfaces/http/handlers/admin/subscription/handler.go index adc98ad..40a4910 100644 --- a/internal/interfaces/http/handlers/admin/subscription/handler.go +++ b/internal/interfaces/http/handlers/admin/subscription/handler.go @@ -270,6 +270,9 @@ func (h *Handler) List(c *gin.Context) { sortDesc = &desc } + // Parse include_counts parameter + includeCounts := c.Query("include_counts") == "true" + query := usecases.ListUserSubscriptionsQuery{ UserID: userID, PlanID: planID, @@ -282,6 +285,7 @@ func (h *Handler) List(c *gin.Context) { PageSize: p.PageSize, SortBy: sortBy, SortDesc: sortDesc, + IncludeCounts: includeCounts, } result, err := h.listUseCase.Execute(c.Request.Context(), query) @@ -291,6 +295,20 @@ func (h *Handler) List(c *gin.Context) { return } + // Build response with optional status_counts + if result.StatusCounts != nil { + response := map[string]any{ + "items": result.Subscriptions, + "total": result.Total, + "page": result.Page, + "page_size": result.PageSize, + "total_pages": utils.TotalPages(result.Total, result.PageSize), + "status_counts": result.StatusCounts, + } + utils.SuccessResponse(c, http.StatusOK, "", response) + return + } + utils.ListSuccessResponse(c, result.Subscriptions, result.Total, result.Page, result.PageSize) } diff --git a/internal/interfaces/http/handlers/node/hubhandler.go b/internal/interfaces/http/handlers/node/hubhandler.go index 29c1912..657fb29 100644 --- a/internal/interfaces/http/handlers/node/hubhandler.go +++ b/internal/interfaces/http/handlers/node/hubhandler.go @@ -90,34 +90,13 @@ type NodeSubscriptionUsageReader interface { } // NodeSubscriptionQuotaCache defines the interface for subscription quota caching. -// This is a local interface to avoid import cycle with cache package. -type NodeSubscriptionQuotaCache interface { - // GetQuota retrieves subscription quota information from cache. - // Returns nil if cache does not exist. - GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error) - - // MarkSuspended marks the subscription as suspended in cache. - MarkSuspended(ctx context.Context, subscriptionID uint) error -} +type NodeSubscriptionQuotaCache = usecases.NodeSubscriptionQuotaCache // CachedQuotaInfo represents the cached subscription quota information. -// This mirrors cache.CachedQuota to avoid import cycle. -type CachedQuotaInfo struct { - Limit int64 // Traffic limit in bytes - PeriodStart time.Time // Billing period start - PeriodEnd time.Time // Billing period end - PlanType string // node/forward/hybrid - Suspended bool // Whether the subscription is suspended - NotFound bool // Null marker: subscription confirmed not found/inactive in DB -} +type CachedQuotaInfo = usecases.CachedQuotaInfo // NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota. -// This is used when quota cache miss occurs to load quota from database. -type NodeSubscriptionQuotaLoader interface { - // LoadQuotaByID loads subscription quota from database and caches it. - // Returns the cached quota info, or nil if subscription/plan not found. - LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error) -} +type NodeSubscriptionQuotaLoader = usecases.NodeSubscriptionQuotaLoader // NodeHubHandler handles WebSocket connections for node agents. type NodeHubHandler struct { diff --git a/internal/interfaces/http/wire_services.go b/internal/interfaces/http/wire_services.go index 961104b..1151d7d 100644 --- a/internal/interfaces/http/wire_services.go +++ b/internal/interfaces/http/wire_services.go @@ -311,8 +311,8 @@ func (c *Container) initNode() { hdlrs := c.hdlrs // Initialize adapters - c.nodeRepoAdapter = adapters.NewNodeRepositoryAdapter(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log) - c.tokenValidator = adapters.NewSubscriptionTokenValidatorAdapter(db, log) + c.nodeRepoAdapter = repository.NewNodeSubscriptionRepository(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log) + c.tokenValidator = repository.NewSubscriptionTokenValidator(db, log) c.nodeStatusQuerier = adapters.NewNodeSystemStatusQuerierAdapter(c.redis, log) // Initialize GitHub release services for version checking @@ -1359,6 +1359,9 @@ func (c *Container) initCallbacksAndNotifiers() { // Set plan change notifier to propagate plan feature changes (e.g. device_limit) to nodes ucs.updatePlanUC.SetPlanChangeNotifier(c.subscriptionSyncService) + ucs.updatePlanUC.SetSubscriptionRepo(repos.subscriptionRepo) + ucs.updatePlanUC.SetQuotaCacheManager(c.quotaCacheSyncService) + ucs.updatePlanUC.SetSubscriptionNotifier(c.subscriptionSyncService) } // ============================================================