355 lines
12 KiB
Python
355 lines
12 KiB
Python
import concurrent.futures
import os
import shutil
import threading
import time
from pathlib import Path
from urllib.parse import urljoin

import m3u8
import requests
from Crypto.Cipher import AES
|
||
|
||
|
||
class M3U8Downloader:
    """Download HLS (M3U8) media playlists in background threads.

    Each call to :meth:`download` starts one daemon thread that fetches the
    playlist, downloads every segment through a thread pool, optionally
    decrypts the segments with AES-128-CBC, and concatenates them into a
    single output file inside ``output_dir``.

    Task state lives in ``self.tasks`` (task id -> dict) and is guarded by
    ``self.lock``. Public accessors return copies, so callers never observe a
    dict that a worker thread is mutating.

    Limitations visible in this implementation: only the first ``EXT-X-KEY``
    entry of the playlist is used for every segment (no key rotation), and
    each segment is held fully in memory while it is downloaded.
    """

    # Browser-like headers sent with every HTTP request.
    _HEADERS = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    # Seconds to wait on any single HTTP request before giving up.
    _REQUEST_TIMEOUT = 30
    # AES block size in bytes; also the required IV length.
    _AES_BLOCK_SIZE = 16

    def __init__(self, max_workers=5, output_dir="downloads"):
        """Create a downloader.

        Args:
            max_workers: Number of segments downloaded concurrently per task.
            output_dir: Directory for finished files and temporary segments.
                It is created (including missing parents) if absent.
        """
        self.max_workers = max_workers
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Task registry: task_id -> mutable task-state dict.
        self.tasks = {}
        self.lock = threading.Lock()
        self.task_counter = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _new_task_info(task_id, m3u8_url, output_filename):
        """Build the initial state dict for a freshly created task."""
        return {
            'task_id': task_id,
            'm3u8_url': m3u8_url,
            'output_filename': output_filename,
            'status': 'preparing',
            'total_segments': 0,
            'downloaded_segments': 0,
            'progress': 0.0,
            'output_file': '',
            'start_time': time.time(),
            'key': None,
            'iv': None,
            'media_sequence': 0,
        }

    @staticmethod
    def _compute_progress(task_info):
        """Map a task-state dict to a progress fraction in ``[0, 1]``."""
        status = task_info['status']
        if status in ('merging', 'completed'):
            return 1.0
        if status == 'downloading':
            total = task_info['total_segments']
            if total > 0:
                return task_info['downloaded_segments'] / total
        # 'preparing', 'failed', unknown statuses, or an empty playlist.
        return 0.0

    @classmethod
    def _strip_pkcs7(cls, data):
        """Remove PKCS7 padding from decrypted data when it is well-formed.

        Data whose tail is not valid padding is returned unchanged rather than
        rejected, so a non-conforming stream is not turned into a failure.
        """
        if not data:
            return data
        pad_len = data[-1]
        if 1 <= pad_len <= cls._AES_BLOCK_SIZE and data.endswith(bytes([pad_len]) * pad_len):
            return data[:-pad_len]
        return data

    def _update_task(self, task_info, **fields):
        """Apply ``fields`` to a task-state dict while holding the lock."""
        with self.lock:
            task_info.update(fields)

    # ------------------------------------------------------------------
    # Task inspection
    # ------------------------------------------------------------------
    def get_task_info(self, task_id):
        """Return a snapshot of one task's state.

        Returns:
            dict: A copy of the task-state dict, or ``{"status": "not_found"}``
            when ``task_id`` is unknown.
        """
        with self.lock:
            task_info = self.tasks.get(task_id)
            if task_info is None:
                return {"status": "not_found"}
            return dict(task_info)

    def list_tasks(self):
        """Return a snapshot of every task, keyed by task id."""
        with self.lock:
            return {task_id: dict(info) for task_id, info in self.tasks.items()}

    def get_all_tasks(self):
        """Return a summary of every task, newest first.

        Returns:
            list: One dict per task with the keys ``task_id``, ``filename``,
            ``status``, ``progress`` (0-1, rounded to 4 places) and
            ``start_time`` (local time, ``YYYY-MM-DD HH:MM:SS``), e.g.
            ``[{'task_id': 'task_1', 'filename': 'video1.mp4',
            'status': 'downloading', 'progress': 0.56, ...}, ...]``
        """
        with self.lock:
            snapshot = [(task_id, dict(info)) for task_id, info in self.tasks.items()]

        now = time.time()
        # Sort on the raw timestamp: the formatted string only has one-second
        # resolution and would leave same-second tasks in insertion order.
        snapshot.sort(key=lambda item: item[1].get('start_time', now), reverse=True)

        return [
            {
                'task_id': task_id,
                'filename': info['output_filename'],
                'status': info['status'],
                'progress': round(self._compute_progress(info), 4),
                'start_time': time.strftime(
                    '%Y-%m-%d %H:%M:%S',
                    time.localtime(info.get('start_time', now)),
                ),
            }
            for task_id, info in snapshot
        ]

    def get_tasks_summary(self):
        """Count tasks per status.

        Returns:
            dict: ``total`` plus one counter for each of ``preparing``,
            ``downloading``, ``merging``, ``completed`` and ``failed``.
        """
        all_tasks = self.get_all_tasks()

        summary = {
            'total': len(all_tasks),
            'preparing': 0,
            'downloading': 0,
            'merging': 0,
            'completed': 0,
            'failed': 0,
        }
        for task in all_tasks:
            status = task['status']
            # Statuses outside the known set only count towards 'total'.
            if status in summary:
                summary[status] += 1
        return summary

    # ------------------------------------------------------------------
    # Download machinery
    # ------------------------------------------------------------------
    def download_ts_segment(self, task_info, ts_url, output_path, segment_index):
        """Download one segment, decrypt it if needed, and write it to disk.

        Args:
            task_info: Task-state dict; ``task_id``, ``key``, ``iv`` and
                ``media_sequence`` are read from it.
            ts_url: Absolute URL of the segment.
            output_path: Destination path of the segment file.
            segment_index: Zero-based position of the segment in the playlist.

        Returns:
            bool: ``True`` on success, ``False`` if anything failed (the error
            is printed, never raised).
        """
        try:
            response = requests.get(ts_url, headers=self._HEADERS, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()
            ts_data = response.content

            key = task_info.get('key')
            if key:
                iv = task_info.get('iv')
                if not iv:
                    # Without an explicit IV, HLS uses the segment's media
                    # sequence number as a big-endian 128-bit value.
                    sequence = task_info.get('media_sequence', 0) + segment_index
                    iv = sequence.to_bytes(self._AES_BLOCK_SIZE, 'big')
                cipher = AES.new(key, AES.MODE_CBC, iv)
                # HLS AES-128 segments are PKCS7-padded; drop the padding so
                # it does not end up inside the merged stream.
                ts_data = self._strip_pkcs7(cipher.decrypt(ts_data))

            with open(output_path, 'wb') as f:
                f.write(ts_data)

            # Record progress on the registered task entry.
            with self.lock:
                entry = self.tasks.get(task_info['task_id'])
                if entry is not None:
                    entry['downloaded_segments'] += 1
                    total = entry.get('total_segments', 0)
                    if total > 0:
                        entry['progress'] = entry['downloaded_segments'] / total

            return True

        except Exception as e:
            print(f"下载片段 {segment_index} 失败: {e}")
            return False

    def get_decryption_key(self, key_uri, iv=None):
        """Fetch the AES key and normalise the IV.

        Args:
            key_uri: Absolute URL of the key file.
            iv: IV as a hex string (optionally prefixed with ``0x``/``0X``),
                as bytes, or ``None``.

        Returns:
            tuple: ``(key, iv)`` as bytes. When ``iv`` is empty, sixteen zero
            bytes are returned for it. ``(None, None)`` if the key could not
            be fetched or the IV could not be parsed (the error is printed).
        """
        try:
            response = requests.get(key_uri, headers=self._HEADERS, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()
            key = response.content

            if iv and isinstance(iv, str):
                hex_digits = iv[2:] if iv[:2].lower() == '0x' else iv
                # Left-pad short values so the IV is always a full AES block.
                iv = bytes.fromhex(hex_digits.zfill(2 * self._AES_BLOCK_SIZE))
            elif not iv:
                iv = b'\x00' * self._AES_BLOCK_SIZE  # default IV

            return key, iv
        except Exception as e:
            print(f"获取解密密钥失败: {e}")
            return None, None

    def _download_m3u8(self, m3u8_url, output_filename, task_id):
        """Run one download task from start to finish (thread target).

        Fetches and parses the playlist, downloads all segments into a
        temporary directory, merges them into ``output_dir/output_filename``
        and records the outcome in the task state. Errors are never raised;
        they set the task status to ``failed`` and fill its ``error`` field.
        """
        with self.lock:
            task_info = self.tasks.get(task_id)
            # Reuse the entry registered by download(); create one when this
            # method is invoked directly or the id belongs to a finished task.
            if task_info is None or task_info.get('status') != 'preparing':
                task_info = self._new_task_info(task_id, m3u8_url, output_filename)
                self.tasks[task_id] = task_info

        temp_dir = None
        try:
            # Fetch and parse the playlist.
            response = requests.get(m3u8_url, headers=self._HEADERS, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()
            m3u8_obj = m3u8.loads(response.text)

            # Resolve the decryption key (first EXT-X-KEY entry only).
            key = None
            iv = None
            key_obj = m3u8_obj.keys[0] if m3u8_obj.keys else None
            if key_obj and key_obj.uri and getattr(key_obj, 'method', None) != 'NONE':
                key_uri = urljoin(m3u8_url, key_obj.uri)
                key, iv = self.get_decryption_key(key_uri, key_obj.iv)
                if not key:
                    # Continuing would write still-encrypted segments and
                    # report success, so fail the task instead.
                    raise RuntimeError('获取解密密钥失败')
                if not key_obj.iv:
                    # No IV attribute: let each segment derive its IV from
                    # its media sequence number.
                    iv = None

            # Collect absolute segment URLs (urljoin keeps absolute ones).
            ts_segments = [urljoin(m3u8_url, segment.uri) for segment in m3u8_obj.segments]
            if not ts_segments:
                # E.g. a master playlist that only lists variant streams.
                raise ValueError('播放列表中没有可下载的片段')

            output_path = self.output_dir / output_filename
            self._update_task(
                task_info,
                key=key,
                iv=iv,
                media_sequence=getattr(m3u8_obj, 'media_sequence', None) or 0,
                total_segments=len(ts_segments),
                output_file=str(output_path),
                status='downloading',
            )

            # Temporary directory holding the individual segments.
            temp_dir = self.output_dir / f"temp_{task_id}"
            temp_dir.mkdir(exist_ok=True)
            segment_paths = [
                temp_dir / f"segment_{i:05d}.ts" for i in range(len(ts_segments))
            ]

            print(f"开始下载任务 {task_id}: {len(ts_segments)} 个片段")

            # Download every segment through the thread pool.
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [
                    executor.submit(self.download_ts_segment, task_info, ts_url, ts_path, i)
                    for i, (ts_url, ts_path) in enumerate(zip(ts_segments, segment_paths))
                ]
                results = [
                    future.result()
                    for future in concurrent.futures.as_completed(futures)
                ]

            if not all(results):
                self._update_task(
                    task_info,
                    status='failed',
                    error='部分片段下载失败',
                    progress=0.0,
                )
                print(f"任务 {task_id} 下载失败,部分片段下载失败")
                return

            # Concatenate the segments in playlist order.
            print("开始合并TS文件...")
            self._update_task(task_info, status='merging', progress=1.0)

            with open(output_path, 'wb') as outfile:
                for ts_path in segment_paths:
                    with open(ts_path, 'rb') as infile:
                        shutil.copyfileobj(infile, outfile)

            self._update_task(
                task_info,
                status='completed',
                progress=1.0,
                end_time=time.time(),
            )
            print(f"任务 {task_id} 完成: {output_path}")

        except Exception as e:
            self._update_task(task_info, status='failed', error=str(e), progress=0.0)
            print(f"任务 {task_id} 失败: {e}")

        finally:
            # Remove temporary segments on success and on failure alike.
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def download(self, output_filename, m3u8_url):
        """Start downloading an M3U8 stream in a background daemon thread.

        Args:
            output_filename: Name of the output file (e.g. ``video.mp4``).
            m3u8_url: URL of the M3U8 playlist.

        Returns:
            str: The id of the new task.
        """
        with self.lock:
            self.task_counter += 1
            task_id = f"task_{self.task_counter}"
            # Register before the thread starts so the task is visible to
            # get_progress()/wait_for_completion() as soon as we return.
            self.tasks[task_id] = self._new_task_info(task_id, m3u8_url, output_filename)

        thread = threading.Thread(
            target=self._download_m3u8,
            args=(m3u8_url, output_filename, task_id),
            daemon=True,
        )
        thread.start()

        return task_id

    def get_progress(self, task_id):
        """Return the progress of one task.

        Args:
            task_id: Id returned by :meth:`download`.

        Returns:
            dict: ``filename``, ``progress`` (0-1 float), ``status``,
            ``task_id``, ``output_file``, ``downloaded_segments`` and
            ``total_segments``. For an unknown id only ``filename`` (empty),
            ``progress`` (0.0) and ``status`` (``'not_found'``) are returned.
        """
        task_info = self.get_task_info(task_id)

        if task_info['status'] == 'not_found':
            return {'filename': '', 'progress': 0.0, 'status': 'not_found'}

        return {
            'filename': task_info['output_filename'],
            'progress': round(self._compute_progress(task_info), 4),
            'status': task_info['status'],
            'task_id': task_id,
            'output_file': task_info.get('output_file', ''),
            'downloaded_segments': task_info.get('downloaded_segments', 0),
            'total_segments': task_info.get('total_segments', 0),
        }

    def wait_for_completion(self, task_id, timeout=None):
        """Block until the task finishes, polling about once per second.

        Args:
            task_id: Id returned by :meth:`download`.
            timeout: Maximum number of seconds to wait; ``None`` waits
                indefinitely.

        Returns:
            bool: ``True`` if the task completed; ``False`` if it failed, the
            id is unknown, or the timeout elapsed first.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            status = self.get_task_info(task_id)['status']

            if status == 'completed':
                return True
            # An unknown id can never complete, so do not wait on it.
            if status in ('failed', 'not_found'):
                return False

            if deadline is None:
                time.sleep(1)
                continue

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            time.sleep(min(1, remaining))