fix: enforce traffic period filtering in all limit check paths

Traffic reset was not taking effect because multiple enforcement paths
ignored the subscription's currentPeriodStart when calculating usage:

- ResolveTrafficPeriod now uses currentPeriodStart as a floor for
  calendar_month mode, so manual resets exclude pre-reset traffic
- Node and forward enforcement services now query traffic within the
  resolved period instead of from time zero
- NodeSubscriptionUsageReaderAdapter Redis query respects periodStart

Also refactors node subscription generation and token validation into
infrastructure layer, adds include_counts support for admin subscription
list endpoint, and extracts shared nodeutil package.
This commit is contained in:
orris-inc
2026-03-11 11:57:14 +08:00
parent 82cb5bb20f
commit c5dff3e26b
23 changed files with 471 additions and 106 deletions

View File

@@ -74,7 +74,7 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex
// Find the highest traffic limit across all Forward-type subscriptions
// and collect their subscription IDs for traffic query
trafficLimit, hasLimit, forwardSubscriptionIDs, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions)
trafficLimit, hasLimit, forwardSubscriptionIDs, periodStart, err := s.getHighestTrafficLimitAndIDs(ctx, activeSubscriptions)
if err != nil {
s.logger.Errorw("failed to determine traffic limit",
"user_id", userID,
@@ -99,8 +99,8 @@ func (s *TrafficLimitEnforcementService) CheckAndEnforceLimit(ctx context.Contex
return nil
}
// Get user's total forward traffic by combining Redis (recent 24h) and MySQL (historical)
usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs)
// Get user's total forward traffic within the resolved traffic period
usedTraffic, err := s.getTotalTrafficForSubscriptions(ctx, forwardSubscriptionIDs, periodStart)
if err != nil {
s.logger.Errorw("failed to get total traffic for user",
"user_id", userID,
@@ -262,12 +262,14 @@ func (s *TrafficLimitEnforcementService) OnTrafficUpdate(ctx context.Context, ru
// getHighestTrafficLimitAndIDs returns the highest traffic limit across all Forward-type subscriptions
// and collects their subscription IDs for traffic query.
// Returns (limit, hasLimit, subscriptionIDs, error) where hasLimit is false if any subscription has unlimited traffic.
// Returns (limit, hasLimit, subscriptionIDs, periodStart, error) where hasLimit is false if any subscription has unlimited traffic.
// periodStart is the latest traffic period start across all forward subscriptions (respects manual resets).
// Only considers subscriptions with PlanType = "forward".
func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, error) {
func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx context.Context, subscriptions []*subscription.Subscription) (uint64, bool, []uint, time.Time, error) {
var highestLimit uint64
hasLimit := false
var forwardSubscriptionIDs []uint
var latestPeriodStart time.Time
// Collect plan IDs for batch query
planIDs := make([]uint, 0, len(subscriptions))
@@ -279,7 +281,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
plansList, err := s.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
s.logger.Errorw("failed to batch fetch plans", "error", err)
return 0, false, nil, err
return 0, false, nil, time.Time{}, err
}
// Convert to map for quick lookup
@@ -312,6 +314,13 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
// Collect Forward-type subscription ID
forwardSubscriptionIDs = append(forwardSubscriptionIDs, sub.ID())
// Resolve traffic period and track the latest period start across all forward subscriptions.
// This ensures manual usage resets are respected: traffic before the reset is excluded.
period := subscription.ResolveTrafficPeriod(plan, sub)
if latestPeriodStart.IsZero() || period.Start.After(latestPeriodStart) {
latestPeriodStart = period.Start
}
// Determine traffic limit: subscription override takes priority over plan
var limit uint64
if sub.TrafficLimitOverride() != nil {
@@ -323,7 +332,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
"subscription_id", sub.ID(),
"plan_id", sub.PlanID(),
)
return 0, false, forwardSubscriptionIDs, nil // Unlimited traffic - don't enforce
return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil
}
var err error
@@ -344,7 +353,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
"subscription_id", sub.ID(),
"plan_id", sub.PlanID(),
)
return 0, false, forwardSubscriptionIDs, nil
return 0, false, forwardSubscriptionIDs, latestPeriodStart, nil
}
// Track the highest limit
@@ -354,7 +363,7 @@ func (s *TrafficLimitEnforcementService) getHighestTrafficLimitAndIDs(ctx contex
}
}
return highestLimit, hasLimit, forwardSubscriptionIDs, nil
return highestLimit, hasLimit, forwardSubscriptionIDs, latestPeriodStart, nil
}
// getForwardTrafficLimit extracts the traffic limit from a plan.
@@ -370,10 +379,10 @@ func (s *TrafficLimitEnforcementService) getForwardTrafficLimit(plan *subscripti
}
// getTotalTrafficForSubscriptions calculates total traffic for given subscription IDs
// by combining data from two sources:
// - Last 24 hours: from Redis HourlyTrafficCache
// - Before 24 hours: from MySQL subscription_usage_stats table
func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint) (uint64, error) {
// within the given traffic period by combining data from two sources:
// - Recent window: from Redis HourlyTrafficCache
// - Historical window: from MySQL subscription_usage_stats table
func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx context.Context, subscriptionIDs []uint, periodStart time.Time) (uint64, error) {
if len(subscriptionIDs) == 0 {
return 0, nil
}
@@ -386,10 +395,15 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con
var total uint64
// Get recent traffic from Redis (yesterday + today, filter by forward_rule type)
// Determine Redis query range: max(recentBoundary, periodStart) to now
// This ensures traffic before periodStart (e.g. after manual reset) is excluded.
resourceType := subscription.ResourceTypeForwardRule.String()
redisFrom := recentBoundary
if periodStart.After(redisFrom) {
redisFrom = periodStart
}
recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
ctx, subscriptionIDs, resourceType, recentBoundary, now,
ctx, subscriptionIDs, resourceType, redisFrom, now,
)
if err != nil {
// Log warning but don't fail - Redis unavailability shouldn't block limit checks
@@ -403,16 +417,21 @@ func (s *TrafficLimitEnforcementService) getTotalTrafficForSubscriptions(ctx con
for _, traffic := range recentTraffic {
total += traffic.Total
}
s.logger.Debugw("got recent 24h traffic from Redis",
s.logger.Debugw("got recent traffic from Redis",
"subscription_ids_count", len(subscriptionIDs),
"recent_total", total,
)
}
// Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday)
// Use daily granularity for historical aggregation, filter by forward_rule type
// Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary)
historicalEnd := recentBoundary.Add(-time.Second)
if historicalEnd.Before(periodStart) {
// Period started after recentBoundary, no historical data needed
return total, nil
}
historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs(
ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second),
ctx, subscriptionIDs, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd,
)
if err != nil {
s.logger.Warnw("failed to get historical traffic from stats, using Redis data only",

View File

@@ -146,8 +146,11 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con
return nil
}
// Get total traffic for this subscription
usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID)
// Resolve traffic period so we only count traffic within the current period
period := subscription.ResolveTrafficPeriod(plan, sub)
// Get total traffic for this subscription within the resolved period
usedTraffic, err := s.getTotalTrafficForSubscription(ctx, subscriptionID, period.Start)
if err != nil {
s.logger.Errorw("failed to get total traffic",
"subscription_id", subscriptionID,
@@ -234,10 +237,10 @@ func (s *NodeTrafficLimitEnforcementService) CheckAndEnforceLimitForNode(ctx con
}
// getTotalTrafficForSubscription calculates total traffic for a subscription
// by combining data from two sources:
// - Last 24 hours: from Redis HourlyTrafficCache
// - Before 24 hours: from MySQL subscription_usage_stats table
func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint) (uint64, error) {
// within the given traffic period by combining data from two sources:
// - Recent window: from Redis HourlyTrafficCache
// - Historical window: from MySQL subscription_usage_stats table
func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx context.Context, subscriptionID uint, periodStart time.Time) (uint64, error) {
now := biztime.NowUTC()
// Use start of yesterday's business day as batch/speed boundary (Lambda architecture)
@@ -246,10 +249,15 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx
var total uint64
// Get recent traffic from Redis (yesterday + today, filter by node type)
// Determine Redis query range: max(recentBoundary, periodStart) to now
// This ensures traffic before periodStart (e.g. after manual reset) is excluded.
resourceType := subscription.ResourceTypeNode.String()
redisFrom := recentBoundary
if periodStart.After(redisFrom) {
redisFrom = periodStart
}
recentTraffic, err := s.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
ctx, []uint{subscriptionID}, resourceType, recentBoundary, now,
ctx, []uint{subscriptionID}, resourceType, redisFrom, now,
)
if err != nil {
// Log warning but don't fail - Redis unavailability shouldn't block limit checks
@@ -263,16 +271,21 @@ func (s *NodeTrafficLimitEnforcementService) getTotalTrafficForSubscription(ctx
for _, traffic := range recentTraffic {
total += traffic.Total
}
s.logger.Debugw("got recent 24h traffic from Redis",
s.logger.Debugw("got recent traffic from Redis",
"subscription_id", subscriptionID,
"recent_total", total,
)
}
// Get historical traffic from MySQL subscription_usage_stats (complete days before yesterday)
// Use daily granularity for historical aggregation, filter by node type
// Get historical traffic from MySQL subscription_usage_stats (within period, before recentBoundary)
historicalEnd := recentBoundary.Add(-time.Second)
if historicalEnd.Before(periodStart) {
// Period started after recentBoundary, no historical data needed
return total, nil
}
historicalTraffic, err := s.usageStatsRepo.GetTotalBySubscriptionIDs(
ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, time.Time{}, recentBoundary.Add(-time.Second),
ctx, []uint{subscriptionID}, &resourceType, subscription.GranularityDaily, periodStart, historicalEnd,
)
if err != nil {
s.logger.Warnw("failed to get historical traffic from stats, using Redis data only",

View File

@@ -0,0 +1,35 @@
package usecases
import (
"context"
"time"
)
// CachedQuotaInfo represents the cached subscription quota information.
// This mirrors cache.CachedQuota to avoid import cycle.
type CachedQuotaInfo struct {
Limit int64 // Traffic limit in bytes
PeriodStart time.Time // Billing period start
PeriodEnd time.Time // Billing period end
PlanType string // node/forward/hybrid
Suspended bool // Whether the subscription is suspended
NotFound bool // Null marker: subscription confirmed not found/inactive in DB
}
// NodeSubscriptionQuotaCache defines the interface for subscription quota caching.
type NodeSubscriptionQuotaCache interface {
// GetQuota retrieves subscription quota information from cache.
// Returns nil if cache does not exist.
GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
// MarkSuspended marks the subscription as suspended in cache.
MarkSuspended(ctx context.Context, subscriptionID uint) error
}
// NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota.
// This is used when quota cache miss occurs to load quota from database.
type NodeSubscriptionQuotaLoader interface {
// LoadQuotaByID loads subscription quota from database and caches it.
// Returns the cached quota info, or nil if subscription/plan not found.
LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
}

View File

@@ -89,6 +89,16 @@ type SubscriptionTokenDTO struct {
CreatedAt time.Time `json:"created_at"`
}
// SubscriptionStatusCounts represents subscription status aggregation counts.
// Used by admin subscription list to return status counts alongside the list response.
type SubscriptionStatusCounts struct {
Active int64 `json:"active"`
Expired int64 `json:"expired"`
Suspended int64 `json:"suspended"`
PendingPayment int64 `json:"pending_payment"`
ExpiringIn7Days int64 `json:"expiring_in_7_days"`
}
var (
// SubscriptionMapper is a generic mapper for basic conversions.
// WARNING: Does not populate OnlineDeviceCount, DeviceLimit, DataUsedBytes, or DataLimitBytes.

View File

@@ -241,8 +241,13 @@ func (uc *CreateSubscriptionUseCase) Execute(ctx context.Context, cmd CreateSubs
}
func (uc *CreateSubscriptionUseCase) calculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time {
// Use fixed days to ensure consistent subscription periods
// This prevents "drifting" when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28)
return CalculateEndDate(startDate, billingCycle)
}
// CalculateEndDate computes the subscription end date from a start date and billing cycle.
// Uses fixed days to ensure consistent subscription periods and prevent "drifting"
// when starting on month boundaries (e.g., Jan 31 -> Feb 28 -> Mar 28).
func CalculateEndDate(startDate time.Time, billingCycle vo.BillingCycle) time.Time {
switch billingCycle {
case vo.BillingCycleWeekly:
return startDate.Add(7 * 24 * time.Hour) // 7 days

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"time"
"golang.org/x/sync/errgroup"
"github.com/orris-inc/orris/internal/application/subscription/dto"
"github.com/orris-inc/orris/internal/domain/subscription"
"github.com/orris-inc/orris/internal/domain/user"
@@ -24,13 +26,15 @@ type ListUserSubscriptionsQuery struct {
PageSize int
SortBy string
SortDesc *bool // nil means default (true = DESC)
IncludeCounts bool // When true, also return subscription status counts
}
type ListUserSubscriptionsResult struct {
Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Subscriptions []*dto.SubscriptionDTO `json:"subscriptions"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
StatusCounts *dto.SubscriptionStatusCounts `json:"status_counts,omitempty"` // Present when IncludeCounts is true
}
type ListUserSubscriptionsUseCase struct {
@@ -189,6 +193,60 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU
dtos = append(dtos, result)
}
// Query status counts if requested
var statusCounts *dto.SubscriptionStatusCounts
if query.IncludeCounts {
statusCounts = &dto.SubscriptionStatusCounts{}
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
count, err := uc.subscriptionRepo.CountByStatus(gctx, "active")
if err != nil {
return fmt.Errorf("count active: %w", err)
}
statusCounts.Active = count
return nil
})
g.Go(func() error {
count, err := uc.subscriptionRepo.CountByStatus(gctx, "expired")
if err != nil {
return fmt.Errorf("count expired: %w", err)
}
statusCounts.Expired = count
return nil
})
g.Go(func() error {
count, err := uc.subscriptionRepo.CountByStatus(gctx, "suspended")
if err != nil {
return fmt.Errorf("count suspended: %w", err)
}
statusCounts.Suspended = count
return nil
})
g.Go(func() error {
count, err := uc.subscriptionRepo.CountByStatus(gctx, "pending_payment")
if err != nil {
return fmt.Errorf("count pending_payment: %w", err)
}
statusCounts.PendingPayment = count
return nil
})
g.Go(func() error {
subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(gctx, 7)
if err != nil {
return fmt.Errorf("find expiring subscriptions: %w", err)
}
statusCounts.ExpiringIn7Days = int64(len(subs))
return nil
})
if err := g.Wait(); err != nil {
uc.logger.Warnw("failed to get subscription status counts", "error", err)
// Non-fatal: return list without counts
statusCounts = nil
}
}
uc.logger.Debugw("subscriptions listed successfully",
"user_id", query.UserID,
"total", total,
@@ -201,5 +259,6 @@ func (uc *ListUserSubscriptionsUseCase) Execute(ctx context.Context, query ListU
Total: total,
Page: query.Page,
PageSize: query.PageSize,
StatusCounts: statusCounts,
}, nil
}

View File

@@ -3,6 +3,7 @@ package usecases
import (
"context"
"fmt"
"math"
"github.com/orris-inc/orris/internal/application/subscription/dto"
"github.com/orris-inc/orris/internal/domain/subscription"
@@ -26,10 +27,13 @@ type PlanChangeNotifier interface {
}
type UpdatePlanUseCase struct {
planRepo subscription.PlanRepository
pricingRepo subscription.PlanPricingRepository
planChangeNotifier PlanChangeNotifier
logger logger.Interface
planRepo subscription.PlanRepository
pricingRepo subscription.PlanPricingRepository
subscriptionRepo subscription.SubscriptionRepository
planChangeNotifier PlanChangeNotifier
quotaCacheManager QuotaCacheManager
subscriptionNotifier SubscriptionChangeNotifier
logger logger.Interface
}
// SetPlanChangeNotifier sets the notifier for plan feature changes.
@@ -37,6 +41,21 @@ func (uc *UpdatePlanUseCase) SetPlanChangeNotifier(notifier PlanChangeNotifier)
uc.planChangeNotifier = notifier
}
// SetSubscriptionRepo sets the subscription repository for cascading plan changes.
func (uc *UpdatePlanUseCase) SetSubscriptionRepo(repo subscription.SubscriptionRepository) {
uc.subscriptionRepo = repo
}
// SetQuotaCacheManager sets the quota cache manager for invalidating cached quotas.
func (uc *UpdatePlanUseCase) SetQuotaCacheManager(manager QuotaCacheManager) {
uc.quotaCacheManager = manager
}
// SetSubscriptionNotifier sets the notifier for subscription changes.
func (uc *UpdatePlanUseCase) SetSubscriptionNotifier(notifier SubscriptionChangeNotifier) {
uc.subscriptionNotifier = notifier
}
func NewUpdatePlanUseCase(
planRepo subscription.PlanRepository,
pricingRepo subscription.PlanPricingRepository,
@@ -108,6 +127,11 @@ func (uc *UpdatePlanUseCase) Execute(
}
}
// Invalidate quota cache for all active subscriptions when limits change
if cmd.Limits != nil {
uc.invalidateQuotaCacheForPlan(ctx, planID)
}
// Sync pricing options if provided (delete old, create new)
if cmd.Pricings != nil {
uc.logger.Infow("syncing pricing options", "plan_id", planID, "count", len(*cmd.Pricings))
@@ -160,6 +184,9 @@ func (uc *UpdatePlanUseCase) Execute(
uc.logger.Infow("pricing options synced successfully",
"plan_id", planID,
"count", len(*cmd.Pricings))
// Migrate orphaned subscriptions whose billing cycle is no longer available
uc.migrateOrphanedBillingCycles(ctx, planID, *cmd.Pricings)
}
// Reload the plan from database to get the accurate state after update
@@ -183,3 +210,163 @@ func (uc *UpdatePlanUseCase) Execute(
return dto.ToPlanDTOWithPricings(updatedPlan, pricings), nil
}
// invalidateQuotaCacheForPlan invalidates Redis quota cache for all active
// subscriptions on the given plan so enforcement uses updated limits.
func (uc *UpdatePlanUseCase) invalidateQuotaCacheForPlan(ctx context.Context, planID uint) {
if uc.subscriptionRepo == nil || uc.quotaCacheManager == nil {
return
}
subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{
PlanID: &planID,
Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing)},
Page: 1,
PageSize: 10000,
})
if err != nil {
uc.logger.Warnw("failed to list subscriptions for quota cache invalidation",
"plan_id", planID, "error", err)
return
}
for _, sub := range subs {
if err := uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID()); err != nil {
uc.logger.Warnw("failed to invalidate quota cache",
"subscription_id", sub.ID(), "error", err)
}
}
if len(subs) > 0 {
uc.logger.Infow("quota cache invalidated for plan subscriptions",
"plan_id", planID, "count", len(subs))
}
}
// migrateOrphanedBillingCycles migrates subscriptions whose billing cycle
// is no longer available in the new pricing options.
// e.g., plan changed from monthly to lifetime -> existing monthly subscriptions
// are migrated to lifetime with recalculated end_date.
func (uc *UpdatePlanUseCase) migrateOrphanedBillingCycles(
ctx context.Context, planID uint, newPricings []dto.PricingOptionInput,
) {
if uc.subscriptionRepo == nil {
return
}
// Build set of active billing cycles from new pricings
availableCycles := make(map[string]bool)
for _, p := range newPricings {
if p.IsActive {
availableCycles[p.BillingCycle] = true
}
}
if len(availableCycles) == 0 {
return
}
// Query affected subscriptions
subs, _, err := uc.subscriptionRepo.List(ctx, subscription.SubscriptionFilter{
PlanID: &planID,
Statuses: []string{string(vo.StatusActive), string(vo.StatusTrialing), string(vo.StatusSuspended)},
Page: 1,
PageSize: 10000,
})
if err != nil {
uc.logger.Warnw("failed to list subscriptions for billing cycle migration",
"plan_id", planID, "error", err)
return
}
// Find orphaned subscriptions
migratedCount := 0
for _, sub := range subs {
if sub.BillingCycle() == nil {
continue
}
if availableCycles[sub.BillingCycle().String()] {
continue // billing cycle still available
}
oldCycle := sub.BillingCycle().String()
targetCycle := findClosestBillingCycle(sub.BillingCycle(), availableCycles)
newEndDate := CalculateEndDate(sub.CurrentPeriodStart(), targetCycle)
if err := sub.ChangeBillingCycle(targetCycle, newEndDate); err != nil {
uc.logger.Warnw("failed to change billing cycle",
"subscription_id", sub.ID(), "old_cycle", oldCycle, "error", err)
continue
}
if err := uc.subscriptionRepo.Update(ctx, sub); err != nil {
uc.logger.Warnw("failed to persist subscription after billing cycle migration",
"subscription_id", sub.ID(), "error", err)
continue
}
// Invalidate quota cache for migrated subscription
if uc.quotaCacheManager != nil {
_ = uc.quotaCacheManager.InvalidateQuota(ctx, sub.ID())
}
// Notify nodes of subscription update
if uc.subscriptionNotifier != nil {
_ = uc.subscriptionNotifier.NotifySubscriptionUpdate(ctx, sub)
}
uc.logger.Infow("subscription billing cycle migrated",
"subscription_id", sub.ID(),
"subscription_sid", sub.SID(),
"old_cycle", oldCycle,
"new_cycle", targetCycle.String(),
"new_end_date", newEndDate,
)
migratedCount++
}
if migratedCount > 0 {
uc.logger.Infow("billing cycle migration completed",
"plan_id", planID, "migrated_count", migratedCount)
}
}
// findClosestBillingCycle finds the billing cycle from availableCycles
// that is closest in duration to the given oldCycle.
// On tie, prefers the longer duration to minimize disruption.
func findClosestBillingCycle(oldCycle *vo.BillingCycle, availableCycles map[string]bool) vo.BillingCycle {
oldDays := 0
if oldCycle != nil {
oldDays = oldCycle.Days()
if oldCycle.IsLifetime() {
oldDays = math.MaxInt32
}
}
var bestCycle vo.BillingCycle
bestDiff := math.MaxInt32
bestDays := 0
for c := range availableCycles {
parsed, err := vo.ParseBillingCycle(c)
if err != nil {
continue
}
days := parsed.Days()
if parsed.IsLifetime() {
days = math.MaxInt32
}
diff := oldDays - days
if diff < 0 {
diff = -diff
}
// Prefer closer; on tie prefer longer duration
if diff < bestDiff || (diff == bestDiff && days > bestDays) {
bestDiff = diff
bestCycle = parsed
bestDays = days
}
}
return bestCycle
}

View File

@@ -34,6 +34,7 @@ type SubscriptionFilter struct {
UserID *uint
PlanID *uint
Status *string
Statuses []string // Multiple status filter (IN clause), takes precedence over Status
BillingCycle *string
CreatedFrom *time.Time
CreatedTo *time.Time

View File

@@ -570,6 +570,23 @@ func (s *Subscription) ChangePlan(newPlanID uint) error {
return nil
}
// ChangeBillingCycle updates the billing cycle and end date.
// Used when plan pricing changes make the current billing cycle unavailable.
func (s *Subscription) ChangeBillingCycle(newCycle vo.BillingCycle, newEndDate time.Time) error {
if !newCycle.IsValid() {
return fmt.Errorf("invalid billing cycle: %s", newCycle)
}
if newEndDate.Before(s.startDate) {
return fmt.Errorf("new end date must be after start date")
}
s.billingCycle = &newCycle
s.endDate = newEndDate
s.currentPeriodEnd = newEndDate
s.updatedAt = biztime.NowUTC()
s.version++
return nil
}
// IsExpired checks if subscription is expired
func (s *Subscription) IsExpired() bool {
return biztime.NowUTC().After(s.endDate)

View File

@@ -44,6 +44,9 @@ func GetTrafficResetMode(plan *Plan) TrafficResetMode {
// For calendar_month: uses business timezone month boundaries (backward compatible default).
// For billing_cycle: uses the subscription's CurrentPeriodStart/CurrentPeriodEnd.
// Falls back to calendar_month if sub is nil.
//
// In both modes, if the subscription's CurrentPeriodStart is after the resolved period start
// (e.g. after a manual usage reset), it is used as a floor to exclude pre-reset traffic.
func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod {
mode := GetTrafficResetMode(plan)
@@ -56,8 +59,16 @@ func ResolveTrafficPeriod(plan *Plan, sub *Subscription) TrafficPeriod {
// Fallback: calendar month (default, or billing_cycle with nil sub)
bizNow := biztime.ToBizTimezone(biztime.NowUTC())
return TrafficPeriod{
period := TrafficPeriod{
Start: biztime.StartOfMonthUTC(bizNow.Year(), bizNow.Month()),
End: biztime.EndOfMonthUTC(bizNow.Year(), bizNow.Month()),
}
// If subscription's period start is after the calendar month start (e.g. after manual
// usage reset), use it as the floor so pre-reset traffic is excluded.
if sub != nil && sub.CurrentPeriodStart().After(period.Start) {
period.Start = sub.CurrentPeriodStart()
}
return period
}

View File

@@ -1,4 +1,4 @@
package adapters
package repository
import (
"context"
@@ -10,7 +10,7 @@ import (
"github.com/orris-inc/orris/internal/domain/node"
nodevo "github.com/orris-inc/orris/internal/domain/node/valueobjects"
"github.com/orris-inc/orris/internal/domain/subscription/valueobjects"
"github.com/orris-inc/orris/internal/interfaces/adapters/nodeutil"
"github.com/orris-inc/orris/internal/infrastructure/persistence/nodeutil"
"github.com/orris-inc/orris/internal/infrastructure/persistence/models"
"github.com/orris-inc/orris/internal/shared/logger"
"github.com/orris-inc/orris/internal/shared/utils/jsonutil"
@@ -18,21 +18,19 @@ import (
"github.com/orris-inc/orris/internal/shared/utils/setutil"
)
// NodeRepository defines the interface for node persistence operations
type NodeRepository interface {
GetByToken(ctx context.Context, tokenHash string) (*node.Node, error)
}
type NodeRepositoryAdapter struct {
nodeRepo NodeRepository
// NodeSubscriptionRepository implements usecases.NodeRepository by querying
// subscriptions, plans, resource groups, and nodes to build subscription node lists.
type NodeSubscriptionRepository struct {
nodeRepo node.NodeRepository
forwardRuleRepo forward.Repository
db *gorm.DB
logger logger.Interface
configLoader *nodeutil.ConfigLoader
}
func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeRepositoryAdapter {
return &NodeRepositoryAdapter{
// NewNodeSubscriptionRepository creates a new NodeSubscriptionRepository.
func NewNodeSubscriptionRepository(nodeRepo node.NodeRepository, forwardRuleRepo forward.Repository, db *gorm.DB, logger logger.Interface) *NodeSubscriptionRepository {
return &NodeSubscriptionRepository{
nodeRepo: nodeRepo,
forwardRuleRepo: forwardRuleRepo,
db: db,
@@ -41,7 +39,7 @@ func NewNodeRepositoryAdapter(nodeRepo NodeRepository, forwardRuleRepo forward.R
}
}
func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) {
func (r *NodeSubscriptionRepository) GetBySubscriptionToken(ctx context.Context, linkToken string, mode string) ([]*usecases.Node, error) {
var subscriptionModel models.SubscriptionModel
// Query subscription by link_token
@@ -148,7 +146,7 @@ func (r *NodeRepositoryAdapter) GetBySubscriptionToken(ctx context.Context, link
}
}
func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) {
func (r *NodeSubscriptionRepository) GetByTokenHash(ctx context.Context, tokenHash string) (usecases.NodeData, error) {
nodeEntity, err := r.nodeRepo.GetByToken(ctx, tokenHash)
if err != nil {
r.logger.Warnw("failed to get node by token hash",
@@ -172,7 +170,7 @@ func (r *NodeRepositoryAdapter) GetByTokenHash(ctx context.Context, tokenHash st
// Rules are selected based on resource group membership, regardless of whether their
// target nodes are in the same resource groups.
// Uses Repository method to ensure proper scope isolation (system rules only).
func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node {
func (r *NodeSubscriptionRepository) getForwardedNodes(ctx context.Context, groupIDs []uint, originNodeMap map[uint]*usecases.Node) []*usecases.Node {
if len(groupIDs) == 0 {
return nil
}
@@ -238,7 +236,7 @@ func (r *NodeRepositoryAdapter) getForwardedNodes(ctx context.Context, groupIDs
// For forward plans, users see their own forward rules as subscription nodes
// Forward plans have no "origin" nodes - all nodes are forwarded by nature
// Uses Repository method to ensure proper scope isolation (user's own rules only).
func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) {
func (r *NodeSubscriptionRepository) getForwardPlanNodes(ctx context.Context, subscriptionID uint, userID uint, mode string) ([]*usecases.Node, error) {
// Forward plans have no origin nodes, return empty for origin mode
if mode == usecases.NodeModeOrigin {
r.logger.Debugw("forward plan has no origin nodes", "user_id", userID, "mode", mode)
@@ -262,7 +260,7 @@ func (r *NodeRepositoryAdapter) getForwardPlanNodes(ctx context.Context, subscri
// getHybridPlanNodes returns nodes for hybrid plan subscriptions
// For hybrid plans, users see both resource group nodes AND their own forward rules
func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) {
func (r *NodeSubscriptionRepository) getHybridPlanNodes(ctx context.Context, subscriptionID uint, userID uint, planID uint, mode string) ([]*usecases.Node, error) {
// Step 1: Get resource group nodes (same as node plan logic)
// Query resource group IDs for this plan
var groupIDs []uint
@@ -349,7 +347,7 @@ func (r *NodeRepositoryAdapter) getHybridPlanNodes(ctx context.Context, subscrip
}
// buildNodesWithConfigs builds use case nodes from node models with protocol configs loaded
func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node {
func (r *NodeSubscriptionRepository) buildNodesWithConfigs(ctx context.Context, nodeModels []models.NodeModel) []*usecases.Node {
configs := r.configLoader.LoadProtocolConfigs(ctx, nodeModels)
nodes := make([]*usecases.Node, 0, len(nodeModels))
@@ -365,7 +363,7 @@ func (r *NodeRepositoryAdapter) buildNodesWithConfigs(ctx context.Context, nodeM
// getUserForwardNodes retrieves user's forward rules with target nodes as subscription nodes
// Only returns forward rules where target_node_id is NOT NULL
// Uses Repository method to ensure proper scope isolation (user's own rules only).
func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) {
func (r *NodeSubscriptionRepository) getUserForwardNodes(ctx context.Context, userID uint) ([]*usecases.Node, error) {
forwardRules, err := r.forwardRuleRepo.ListUserRulesForDelivery(ctx, userID)
if err != nil {
r.logger.Errorw("failed to query user forward rules", "user_id", userID, "error", err)
@@ -395,7 +393,7 @@ func (r *NodeRepositoryAdapter) getUserForwardNodes(ctx context.Context, userID
}
// collectIDsFromRules extracts node IDs and agent IDs from forward rules
func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) {
func (r *NodeSubscriptionRepository) collectIDsFromRules(rules []*forward.ForwardRule) (nodeIDs, agentIDs []uint) {
nodeIDSet := setutil.NewUintSet()
agentIDSet := setutil.NewUintSet()
@@ -412,7 +410,7 @@ func (r *NodeRepositoryAdapter) collectIDsFromRules(rules []*forward.ForwardRule
}
// collectAgentIDsFromRules extracts agent IDs from forward rules (skipping external rules).
func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint {
func (r *NodeSubscriptionRepository) collectAgentIDsFromRules(rules []*forward.ForwardRule) []uint {
agentIDSet := setutil.NewUintSet()
for _, rule := range rules {
if rule.AgentID() > 0 {
@@ -423,7 +421,7 @@ func (r *NodeRepositoryAdapter) collectAgentIDsFromRules(rules []*forward.Forwar
}
// queryActiveNodes queries active nodes by IDs and returns both slice and map
func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) {
func (r *NodeSubscriptionRepository) queryActiveNodes(ctx context.Context, nodeIDs []uint) ([]models.NodeModel, map[uint]*models.NodeModel, error) {
nodeMap := make(map[uint]*models.NodeModel)
if len(nodeIDs) == 0 {
return nil, nodeMap, nil
@@ -447,7 +445,7 @@ func (r *NodeRepositoryAdapter) queryActiveNodes(ctx context.Context, nodeIDs []
}
// loadForwardAgents loads forward agents by IDs and returns a map
func (r *NodeRepositoryAdapter) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel {
func (r *NodeSubscriptionRepository) loadForwardAgents(ctx context.Context, agentIDs []uint) map[uint]*models.ForwardAgentModel {
agentMap := make(map[uint]*models.ForwardAgentModel)
if len(agentIDs) == 0 {
return agentMap

View File

@@ -460,7 +460,9 @@ func (r *SubscriptionRepositoryImpl) List(ctx context.Context, filter subscripti
if filter.PlanID != nil {
query = query.Where("plan_id = ?", *filter.PlanID)
}
if filter.Status != nil {
if len(filter.Statuses) > 0 {
query = query.Where("status IN ?", filter.Statuses)
} else if filter.Status != nil {
query = query.Where("status = ?", *filter.Status)
}
if filter.BillingCycle != nil {

View File

@@ -1,4 +1,4 @@
package adapters
package repository
import (
"context"
@@ -13,13 +13,15 @@ import (
"github.com/orris-inc/orris/internal/shared/logger"
)
type SubscriptionTokenValidatorAdapter struct {
// SubscriptionTokenValidator validates subscription tokens by querying the database directly.
type SubscriptionTokenValidator struct {
db *gorm.DB
logger logger.Interface
}
func NewSubscriptionTokenValidatorAdapter(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidatorAdapter {
return &SubscriptionTokenValidatorAdapter{
// NewSubscriptionTokenValidator creates a new SubscriptionTokenValidator.
func NewSubscriptionTokenValidator(db *gorm.DB, logger logger.Interface) *SubscriptionTokenValidator {
return &SubscriptionTokenValidator{
db: db,
logger: logger,
}
@@ -54,7 +56,7 @@ func subscriptionStatusError(status string) error {
}
}
func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkToken string) error {
func (v *SubscriptionTokenValidator) Validate(ctx context.Context, linkToken string) error {
var subscriptionModel models.SubscriptionModel
if err := v.db.WithContext(ctx).
Where("link_token = ?", linkToken).
@@ -82,7 +84,7 @@ func (v *SubscriptionTokenValidatorAdapter) Validate(ctx context.Context, linkTo
return nil
}
func (v *SubscriptionTokenValidatorAdapter) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) {
func (v *SubscriptionTokenValidator) ValidateAndGetSubscription(ctx context.Context, linkToken string) (*nodeusecases.SubscriptionValidationResult, error) {
var subscriptionModel models.SubscriptionModel
if err := v.db.WithContext(ctx).
Where("link_token = ?", linkToken).

View File

@@ -4,9 +4,9 @@ import (
"context"
"time"
nodeUsecases "github.com/orris-inc/orris/internal/application/node/usecases"
"github.com/orris-inc/orris/internal/domain/subscription"
"github.com/orris-inc/orris/internal/infrastructure/cache"
nodeHandlers "github.com/orris-inc/orris/internal/interfaces/http/handlers/node"
"github.com/orris-inc/orris/internal/shared/biztime"
"github.com/orris-inc/orris/internal/shared/logger"
)
@@ -29,7 +29,7 @@ func NewNodeSubscriptionQuotaCacheAdapter(
}
// GetQuota retrieves subscription quota from cache
func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) {
func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) {
cached, err := a.cache.GetQuota(ctx, subscriptionID)
if err != nil {
return nil, err
@@ -38,7 +38,7 @@ func (a *NodeSubscriptionQuotaCacheAdapter) GetQuota(ctx context.Context, subscr
return nil, nil
}
return &nodeHandlers.CachedQuotaInfo{
return &nodeUsecases.CachedQuotaInfo{
Limit: cached.Limit,
PeriodStart: cached.PeriodStart,
PeriodEnd: cached.PeriodEnd,
@@ -79,7 +79,7 @@ func NewNodeSubscriptionQuotaLoaderAdapter(
// LoadQuotaByID loads subscription quota from database and caches it.
// When subscription is not found or inactive, a null marker is cached to prevent
// repeated DB lookups (cache penetration protection).
func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeHandlers.CachedQuotaInfo, error) {
func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context, subscriptionID uint) (*nodeUsecases.CachedQuotaInfo, error) {
// Get subscription from database
sub, err := a.subscriptionRepo.GetByID(ctx, subscriptionID)
if err != nil {
@@ -134,7 +134,7 @@ func (a *NodeSubscriptionQuotaLoaderAdapter) LoadQuotaByID(ctx context.Context,
)
}
return &nodeHandlers.CachedQuotaInfo{
return &nodeUsecases.CachedQuotaInfo{
Limit: cachedQuota.Limit,
PeriodStart: cachedQuota.PeriodStart,
PeriodEnd: cachedQuota.PeriodEnd,
@@ -186,10 +186,15 @@ func (a *NodeSubscriptionUsageReaderAdapter) GetCurrentPeriodUsage(
var total int64
// Get recent traffic from Redis (yesterday + today, filter by node type)
// Get recent traffic from Redis (yesterday + today, filter by node type).
// Use max(recentBoundary, periodStart) so traffic before periodStart (e.g. after reset) is excluded.
resourceType := subscription.ResourceTypeNode.String()
redisFrom := recentBoundary
if periodStart.After(redisFrom) {
redisFrom = periodStart
}
recentTraffic, err := a.hourlyTrafficCache.GetTotalTrafficBySubscriptionIDs(
ctx, []uint{subscriptionID}, resourceType, recentBoundary, now,
ctx, []uint{subscriptionID}, resourceType, redisFrom, now,
)
if err != nil {
a.logger.Warnw("failed to get recent traffic from Redis",

View File

@@ -22,6 +22,7 @@ import (
"github.com/orris-inc/orris/internal/infrastructure/email"
infraPayment "github.com/orris-inc/orris/internal/infrastructure/payment"
"github.com/orris-inc/orris/internal/infrastructure/pubsub"
"github.com/orris-inc/orris/internal/infrastructure/repository"
"github.com/orris-inc/orris/internal/infrastructure/scheduler"
"github.com/orris-inc/orris/internal/infrastructure/services"
telegramInfra "github.com/orris-inc/orris/internal/infrastructure/telegram"
@@ -100,8 +101,8 @@ type Container struct {
adminNotificationServiceDDD *telegramAdminApp.ServiceDDD
// Cross-cutting adapters and services (created in one section, used in another)
nodeRepoAdapter *adapters.NodeRepositoryAdapter
tokenValidator *adapters.SubscriptionTokenValidatorAdapter
nodeRepoAdapter *repository.NodeSubscriptionRepository
tokenValidator *repository.SubscriptionTokenValidator
templateLoader *template.SubscriptionTemplateLoader
nodeStatusQuerier *adapters.NodeSystemStatusQuerierAdapter
forwardAgentReleaseService *services.GitHubReleaseService

View File

@@ -270,6 +270,9 @@ func (h *Handler) List(c *gin.Context) {
sortDesc = &desc
}
// Parse include_counts parameter
includeCounts := c.Query("include_counts") == "true"
query := usecases.ListUserSubscriptionsQuery{
UserID: userID,
PlanID: planID,
@@ -282,6 +285,7 @@ func (h *Handler) List(c *gin.Context) {
PageSize: p.PageSize,
SortBy: sortBy,
SortDesc: sortDesc,
IncludeCounts: includeCounts,
}
result, err := h.listUseCase.Execute(c.Request.Context(), query)
@@ -291,6 +295,20 @@ func (h *Handler) List(c *gin.Context) {
return
}
// Build response with optional status_counts
if result.StatusCounts != nil {
response := map[string]any{
"items": result.Subscriptions,
"total": result.Total,
"page": result.Page,
"page_size": result.PageSize,
"total_pages": utils.TotalPages(result.Total, result.PageSize),
"status_counts": result.StatusCounts,
}
utils.SuccessResponse(c, http.StatusOK, "", response)
return
}
utils.ListSuccessResponse(c, result.Subscriptions, result.Total, result.Page, result.PageSize)
}

View File

@@ -90,34 +90,13 @@ type NodeSubscriptionUsageReader interface {
}
// NodeSubscriptionQuotaCache defines the interface for subscription quota caching.
// This is a local interface to avoid import cycle with cache package.
type NodeSubscriptionQuotaCache interface {
// GetQuota retrieves subscription quota information from cache.
// Returns nil if cache does not exist.
GetQuota(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
// MarkSuspended marks the subscription as suspended in cache.
MarkSuspended(ctx context.Context, subscriptionID uint) error
}
type NodeSubscriptionQuotaCache = usecases.NodeSubscriptionQuotaCache
// CachedQuotaInfo represents the cached subscription quota information.
// This mirrors cache.CachedQuota to avoid import cycle.
type CachedQuotaInfo struct {
Limit int64 // Traffic limit in bytes
PeriodStart time.Time // Billing period start
PeriodEnd time.Time // Billing period end
PlanType string // node/forward/hybrid
Suspended bool // Whether the subscription is suspended
NotFound bool // Null marker: subscription confirmed not found/inactive in DB
}
type CachedQuotaInfo = usecases.CachedQuotaInfo
// NodeSubscriptionQuotaLoader defines the interface for lazy loading subscription quota.
// This is used when quota cache miss occurs to load quota from database.
type NodeSubscriptionQuotaLoader interface {
// LoadQuotaByID loads subscription quota from database and caches it.
// Returns the cached quota info, or nil if subscription/plan not found.
LoadQuotaByID(ctx context.Context, subscriptionID uint) (*CachedQuotaInfo, error)
}
type NodeSubscriptionQuotaLoader = usecases.NodeSubscriptionQuotaLoader
// NodeHubHandler handles WebSocket connections for node agents.
type NodeHubHandler struct {

View File

@@ -311,8 +311,8 @@ func (c *Container) initNode() {
hdlrs := c.hdlrs
// Initialize adapters
c.nodeRepoAdapter = adapters.NewNodeRepositoryAdapter(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log)
c.tokenValidator = adapters.NewSubscriptionTokenValidatorAdapter(db, log)
c.nodeRepoAdapter = repository.NewNodeSubscriptionRepository(repos.nodeRepoImpl, repos.forwardRuleRepo, db, log)
c.tokenValidator = repository.NewSubscriptionTokenValidator(db, log)
c.nodeStatusQuerier = adapters.NewNodeSystemStatusQuerierAdapter(c.redis, log)
// Initialize GitHub release services for version checking
@@ -1359,6 +1359,9 @@ func (c *Container) initCallbacksAndNotifiers() {
// Set plan change notifier to propagate plan feature changes (e.g. device_limit) to nodes
ucs.updatePlanUC.SetPlanChangeNotifier(c.subscriptionSyncService)
ucs.updatePlanUC.SetSubscriptionRepo(repos.subscriptionRepo)
ucs.updatePlanUC.SetQuotaCacheManager(c.quotaCacheSyncService)
ucs.updatePlanUC.SetSubscriptionNotifier(c.subscriptionSyncService)
}
// ============================================================