mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-04-09 02:40:47 +00:00
130 lines
5.3 KiB
Python
130 lines
5.3 KiB
Python
import logging
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, Optional, Type
|
|
|
|
from mediaflow_proxy.utils.cache_utils import get_cached_speedtest, set_cache_speedtest
|
|
from mediaflow_proxy.utils.http_utils import Streamer, create_httpx_client
|
|
from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider
|
|
from .providers.all_debrid import AllDebridSpeedTest
|
|
from .providers.base import BaseSpeedTestProvider
|
|
from .providers.real_debrid import RealDebridSpeedTest
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SpeedTestService:
    """Service for managing speed tests across different providers.

    A test is created with :meth:`create_test`, which stores a
    ``SpeedTestTask`` in the speed-test cache. It is executed with
    :meth:`run_speedtest`, which re-caches the task as each location is
    measured, so :meth:`get_test_results` can report progress while it runs.
    """

    def __init__(self) -> None:
        # Maps each supported provider enum value to its implementation class.
        self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = {
            SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest,
            SpeedTestProvider.ALL_DEBRID: AllDebridSpeedTest,
        }

    def _get_provider(self, provider: SpeedTestProvider, api_key: Optional[str] = None) -> BaseSpeedTestProvider:
        """Instantiate the implementation for ``provider``.

        Args:
            provider: Which debrid service to test.
            api_key: Required for AllDebrid, where it is passed to the
                constructor; ignored for other providers.

        Returns:
            A new provider instance.

        Raises:
            ValueError: If the provider is unsupported, or if AllDebrid is
                requested without an API key.
        """
        provider_class = self._providers.get(provider)
        if not provider_class:
            raise ValueError(f"Unsupported provider: {provider}")

        if provider == SpeedTestProvider.ALL_DEBRID:
            if not api_key:
                raise ValueError("API key required for AllDebrid")
            # Only the AllDebrid implementation takes an API key.
            return provider_class(api_key)
        return provider_class()

    async def create_test(
        self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None
    ) -> SpeedTestTask:
        """Create a new speed test task and store it in the cache.

        Only the user info returned by the provider's ``get_test_urls`` is kept
        on the task. The test URLs are obtained again from ``get_config`` when
        :meth:`run_speedtest` executes.

        Raises:
            ValueError: Propagated from :meth:`_get_provider`.
        """
        provider_impl = self._get_provider(provider, api_key)

        # The URLs are not needed here; only the user info is stored.
        _urls, user_info = await provider_impl.get_test_urls()

        task = SpeedTestTask(
            task_id=task_id, provider=provider, started_at=datetime.now(tz=timezone.utc), user_info=user_info
        )

        await set_cache_speedtest(task_id, task)
        return task

    @staticmethod
    async def get_test_results(task_id: str) -> Optional[SpeedTestTask]:
        """Return the cached task for ``task_id``, or None if it is not cached."""
        return await get_cached_speedtest(task_id)

    async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None) -> None:
        """Run the speed test for ``task_id``, publishing progress to the cache.

        Locations from the provider config are tested one after another. Before
        each location is tested, ``task.current_location`` is set and the task
        is re-cached. Afterwards, the result (or the error) is stored in
        ``task.results`` and the task is re-cached again. A failure at one
        location is recorded against that location and does not stop the
        remaining ones.

        Any other failure (task missing, bad provider, config or client
        errors) is logged, and the cached task, if present, is marked
        ``"failed"``. Nothing is raised to the caller in that case.
        """
        try:
            task = await get_cached_speedtest(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")

            provider_impl = self._get_provider(provider, api_key)
            config = await provider_impl.get_config()

            async with create_httpx_client() as client:
                # One streamer, bound to one client, is reused for every location.
                streamer = Streamer(client)

                for location, url in config.test_urls.items():
                    try:
                        task.current_location = location
                        await set_cache_speedtest(task_id, task)
                        result = await self._test_location(location, url, streamer, config.test_duration, provider_impl)
                        task.results[location] = result
                        await set_cache_speedtest(task_id, task)
                    except Exception as e:
                        # A single location failing must not abort the whole test.
                        logger.exception("Error testing %s: %s", location, e)
                        task.results[location] = LocationResult(error=str(e), server_name=location, server_url=url)
                        await set_cache_speedtest(task_id, task)

            # Mark task as completed.
            task.completed_at = datetime.now(tz=timezone.utc)
            task.status = "completed"
            task.current_location = None
            await set_cache_speedtest(task_id, task)

        except Exception as e:
            logger.exception("Error in speed test task %s: %s", task_id, e)
            # Re-read from the cache: the local `task` may be unbound if the
            # first lookup is what failed.
            if task := await get_cached_speedtest(task_id):
                task.status = "failed"
                await set_cache_speedtest(task_id, task)

    async def _test_location(
        self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider
    ) -> LocationResult:
        """Measure download throughput from ``url`` for up to ``test_duration`` seconds.

        Args:
            location: Location key; reported as ``server_name``.
            url: URL of the test file to download.
            streamer: Streamer bound to the shared HTTP client.
            test_duration: Maximum number of seconds to spend downloading.
            provider: Provider instance. If it exposes a ``servers`` mapping
                with an entry for ``location``, that entry's ``url`` is
                reported as ``server_url`` instead of ``url``.

        Returns:
            A ``LocationResult`` with:
                - speed in megabits per second (bits / 10**6 / seconds),
                - elapsed seconds,
                - bytes received.

        Raises:
            Exception: Any download error propagates unchanged. The caller
                (``run_speedtest``) logs it and records it for the location.
        """
        # perf_counter is monotonic. time.time() can jump under clock
        # adjustments, which would corrupt the measured duration.
        start_time = time.perf_counter()
        total_bytes = 0

        await streamer.create_streaming_response(url, headers={})

        # NOTE(review): after an early `break` the streaming response is not
        # explicitly closed here; presumably it is released when the shared
        # client is closed in run_speedtest. Confirm Streamer's cleanup behaviour.
        async for chunk in streamer.stream_content():
            # Count the chunk before checking the deadline. The time spent
            # waiting for it is already included in the elapsed duration.
            total_bytes += len(chunk)
            if time.perf_counter() - start_time >= test_duration:
                break

        duration = time.perf_counter() - start_time
        # Guard against a zero-length interval, which would divide by zero.
        speed_mbps = (total_bytes * 8) / (duration * 1_000_000) if duration > 0 else 0.0

        # Prefer the provider's own server URL when it publishes one (e.g. AllDebrid).
        # `or {}` also covers a provider whose `servers` attribute is None.
        servers = getattr(provider, "servers", None) or {}
        server_info = servers.get(location)
        server_url = server_info.url if server_info else url

        return LocationResult(
            result=SpeedTestResult(
                speed_mbps=round(speed_mbps, 2), duration=round(duration, 2), data_transferred=total_bytes
            ),
            server_name=location,
            server_url=server_url,
        )