73 lines
3.2 KiB
Python
73 lines
3.2 KiB
Python
import calendar
import hashlib
from decimal import Decimal

from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.test import Client, TestCase
from django.urls import reverse
from django.utils import timezone

from plans.models import Plan
from domains.models import Domain, DomainTrafficDaily
from billing.models import Invoice
|
|
|
|
|
|
# Resolve the project's configured user model via Django's get_user_model(),
# once at import time, so the tests below never hard-code a concrete User class.
User = get_user_model()
|
|
|
|
|
|
class BillingFlowTests(TestCase):
    """Tests for invoice generation, invoice policies, and marking invoices paid.

    Every test starts from one staff user owning one active domain on a plan
    that includes 1 GB of traffic per domain, with 2 GiB of traffic recorded on
    the first day of the current month (i.e. usage above the included allowance).
    """

    # Credentials shared by setUp() and the login in the admin-view test.
    USERNAME = 'u2'
    PASSWORD = 'p'

    def setUp(self):
        """Create the user, plan, domain (cycle = current calendar month) and traffic."""
        self.user = User.objects.create_user(
            username=self.USERNAME,
            email='u2@example.com',
            password=self.PASSWORD,
            is_staff=True,
        )
        self.plan = Plan.objects.create(
            name='P',
            base_price_per_domain=Decimal('10.00'),
            included_traffic_gb_per_domain=1,
            overage_price_per_gb=Decimal('1.00'),
        )
        today = timezone.now().date()
        month_start = today.replace(day=1)
        # Last calendar day of the current month. Uses the stdlib instead of
        # `timezone.timedelta`, which is only an incidental re-export of
        # datetime.timedelta inside django.utils.timezone.
        days_in_month = calendar.monthrange(today.year, today.month)[1]
        cycle_end = today.replace(day=days_in_month)
        self.domain = Domain.objects.create(
            user=self.user,
            name='ex2.com',
            status=Domain.STATUS_ACTIVE,
            current_plan=self.plan,
            current_cycle_start=month_start,
            current_cycle_end=cycle_end,
        )
        # 2 GiB on the first day of the cycle: above the plan's 1 GB allowance.
        DomainTrafficDaily.objects.create(
            domain=self.domain, day=month_start, bytes=2 * (1024 ** 3),
        )

    def _generate_invoice(self):
        """Run ``generate_invoices`` and return this user's invoice.

        Fails the test with a clear message (rather than a later AttributeError
        on ``None``) if the command created no invoice for the user.
        """
        call_command('generate_invoices')
        inv = Invoice.objects.filter(user=self.user).first()
        self.assertIsNotNone(inv, 'generate_invoices created no invoice for the test user')
        return inv

    @staticmethod
    def _epay_sign(params, key):
        """Return the upper-case hex MD5 signature for an epay notification.

        Parameters are sorted by name, ``sign`` and ``sign_type`` are excluded,
        the rest are joined as ``k=v`` with ``&``, and ``&key=<key>`` is appended.
        NOTE(review): this must stay in sync with the scheme the
        ``billing:notify`` view verifies.
        """
        payload = '&'.join(
            f'{name}={params[name]}'
            for name in sorted(params)
            if name not in ('sign', 'sign_type')
        )
        return hashlib.md5(f'{payload}&key={key}'.encode('utf-8')).hexdigest().upper()

    def test_generate_invoice_and_apply_policy_dry(self):
        """A generated invoice starts unpaid, and a dry-run policy pass leaves it unpaid."""
        inv = self._generate_invoice()
        self.assertEqual(inv.status, Invoice.STATUS_UNPAID)

        call_command('apply_invoice_policies', dry_run=True)

        inv.refresh_from_db()
        self.assertEqual(inv.status, Invoice.STATUS_UNPAID)

    def test_mark_paid_via_admin_view(self):
        """POSTing ``action=mark_paid`` to the admin billing detail view marks the invoice paid."""
        inv = self._generate_invoice()
        # Assert the login succeeded: otherwise the POST would typically be
        # redirected to a login page (also a 302) and the failure would only
        # surface later as a confusing status mismatch.
        self.assertTrue(self.client.login(username=self.USERNAME, password=self.PASSWORD))

        resp = self.client.post(
            reverse('admin_panel:billing_detail', kwargs={'invoice_id': inv.id}),
            {'action': 'mark_paid'},
        )

        self.assertEqual(resp.status_code, 302)
        inv.refresh_from_db()
        self.assertEqual(inv.status, Invoice.STATUS_PAID)

    def test_epay_notify_marks_paid(self):
        """A correctly signed epay notification for the invoice marks it paid."""
        inv = self._generate_invoice()

        from core.models import SystemSettings

        # Reuse the first settings row if one exists, otherwise create it.
        site_settings = SystemSettings.objects.order_by('id').first()
        if site_settings is None:
            site_settings = SystemSettings.objects.create()
        site_settings.epay_api_base_url = 'https://api.example.com'
        site_settings.epay_pid = 'pid1'
        site_settings.epay_key = 'k123'
        site_settings.save()

        params = {
            'pid': site_settings.epay_pid,
            'out_trade_no': f'INV{inv.id}',
            'money': str(inv.amount_total),
            'type': 'alipay',
            'trade_status': 'SUCCESS',
            'sign_type': 'MD5',
        }
        params['sign'] = self._epay_sign(params, site_settings.epay_key)

        # Anonymous request: the notify endpoint is called by the payment provider.
        resp = self.client.get(reverse('billing:notify'), params)

        self.assertEqual(resp.status_code, 200)
        inv.refresh_from_db()
        self.assertEqual(inv.status, Invoice.STATUS_PAID)
|