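"""
YOLO11 pose detection with a CUDA-safe NMS fallback, optional Savitzky-Golay
smoothing of keypoint trajectories, and JSON export of per-frame landmarks.

Example invocation (script and file names are illustrative):

    python pose_detect.py --input dance.mp4 --output dance_pose.json \
        --model n --device auto --no-preview
"""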

import argparse
import json
import os
import time
import urllib.request
from typing import Dict, List

import cv2
import numpy as np
import torch
import torchvision
from scipy.signal import savgol_filter
from ultralytics import YOLO

# COCO keypoint names, in the order YOLO pose models emit them
KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle"
]

# Skeleton connections between keypoint indices (used for export)
POSE_CONNECTIONS = [
    (0, 1), (0, 2),      # nose to eyes
    (1, 3), (2, 4),      # eyes to ears
    (5, 6),              # shoulders
    (5, 7), (7, 9),      # left arm
    (6, 8), (8, 10),     # right arm
    (5, 11), (6, 12),    # shoulders to hips
    (11, 12),            # hips
    (11, 13), (13, 15),  # left leg
    (12, 14), (14, 16)   # right leg
]

# Monkey-patch torchvision NMS to handle CUDA compatibility issues
original_nms = torchvision.ops.nms


def patched_nms(boxes, scores, iou_threshold):
    """
    NMS wrapper that handles CUDA/torchvision build mismatches: if the CUDA
    NMS kernel is unavailable, run NMS on CPU and move the result back to the
    original device.
    """
    device = boxes.device
    if device.type == 'cuda':
        try:
            # Try to run NMS on CUDA directly
            return original_nms(boxes, scores, iou_threshold)
        except RuntimeError as e:
            if "Could not run 'torchvision::nms'" in str(e):
                # CUDA NMS kernel missing: run on CPU, then move the kept
                # indices back to the original device
                keep = original_nms(boxes.cpu(), scores.cpu(), iou_threshold)
                return keep.to(device)
            raise
    # For non-CUDA devices, just run the original NMS
    return original_nms(boxes, scores, iou_threshold)


# Apply the monkey patch
torchvision.ops.nms = patched_nms

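# Quick sanity check for the patch (hypothetical tensors; works on CPU or GPU).
# The two boxes overlap with IoU ≈ 0.68, so at a 0.5 threshold the lower-score
# box is suppressed and `keep` comes back as tensor([0]):
#
#   boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
#   scores = torch.tensor([0.9, 0.8])
#   keep = torchvision.ops.nms(boxes, scores, 0.5)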

def download_video(url: str, output_dir: str = "downloaded_videos") -> str:
    """Download a video from an HTTP(S) URL and return the local file path."""
    os.makedirs(output_dir, exist_ok=True)
    video_name = os.path.basename(url).split("?")[0]
    if not video_name or "." not in video_name:
        video_name = f"video_{int(time.time())}.mp4"

    output_path = os.path.join(output_dir, video_name)
    print(f"⬇️ Downloading video from {url} to {output_path}...")
    urllib.request.urlretrieve(url, output_path)
    print(f"✅ Video downloaded successfully to {output_path}")
    return output_path

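# Usage sketch (URL is hypothetical):
#
#   local_path = download_video("https://example.com/clips/dance.mp4")
#   # -> "downloaded_videos/dance.mp4"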

def normalize_landmarks(landmarks: List[List[Dict]], window_size: int = 5, poly_order: int = 4) -> List[List[Dict]]:
    """Smooth keypoint trajectories over time with a Savitzky-Golay filter."""
    if not landmarks or len(landmarks) < window_size:
        return landmarks

    # savgol_filter needs an odd window and poly_order < window_size
    if window_size % 2 == 0:
        window_size += 1
    poly_order = min(poly_order, window_size - 1)

    # Collect x, y, confidence per keypoint. Index columns by the stored COCO
    # 'idx' so trajectories stay aligned even when frames differ in which
    # keypoints passed the confidence threshold (missing ones default to 0).
    landmark_count = len(KEYPOINT_NAMES)
    x_values = np.zeros((len(landmarks), landmark_count))
    y_values = np.zeros((len(landmarks), landmark_count))
    conf_values = np.zeros((len(landmarks), landmark_count))

    for i, frame_landmarks in enumerate(landmarks):
        for landmark in frame_landmarks:
            j = landmark['idx']
            x_values[i, j] = landmark['x']
            y_values[i, j] = landmark['y']
            conf_values[i, j] = landmark['confidence']

    # Apply the Savitzky-Golay filter along the time axis
    x_smooth = savgol_filter(x_values, window_size, poly_order, axis=0)
    y_smooth = savgol_filter(y_values, window_size, poly_order, axis=0)

    # Reconstruct the smoothed landmarks
    normalized_landmarks = []
    for i in range(len(landmarks)):
        frame_landmarks = []
        for j in range(landmark_count):
            frame_landmarks.append({
                'idx': j,
                'x': float(x_smooth[i, j]),
                'y': float(y_smooth[i, j]),
                'confidence': float(conf_values[i, j])
            })
        normalized_landmarks.append(frame_landmarks)

    return normalized_landmarks

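# Smoothing sketch: a jittery x-coordinate for the nose (idx 0) over 30 frames
# (synthetic values; keypoints absent from a frame default to zero in the
# filter arrays):
#
#   frames = [[{'idx': 0, 'x': 0.5 + 0.02 * ((i % 2) - 0.5), 'y': 0.4,
#               'confidence': 0.9}] for i in range(30)]
#   smooth = normalize_landmarks(frames, window_size=7, poly_order=3)
#   # smooth[15][0]['x'] is pulled toward 0.5, damping frame-to-frame jitter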

def process_frame(frame: np.ndarray, model, detection_threshold: float = 0.5, show_preview: bool = False):
    """Run YOLO pose estimation on a single frame.

    Returns (annotated frame or None, list of landmark dicts or None).
    """
    try:
        results = model.predict(frame, verbose=False, conf=detection_threshold)

        landmarks_data = None
        processed_frame = None

        # Frame dimensions, used to normalize keypoints to the 0-1 range
        h, w = frame.shape[:2]

        if results and len(results[0].keypoints.data) > 0:
            # Keypoints from the first detection: [17, 3] as (x, y, confidence)
            keypoints = results[0].keypoints.data[0]

            landmarks_data = []
            for idx, kp in enumerate(keypoints):
                x, y, conf = kp.tolist()
                if conf >= detection_threshold:
                    landmarks_data.append({
                        'idx': idx,
                        'x': x / w,  # normalize to 0-1 range
                        'y': y / h,  # normalize to 0-1 range
                        'confidence': conf
                    })

            # Create a visualization only if the preview is enabled
            if show_preview:
                processed_frame = results[0].plot()

        return processed_frame, landmarks_data

    except RuntimeError as e:
        # Surface NMS backend failures under a sentinel message the caller can match
        if "Could not run 'torchvision::nms'" in str(e):
            raise RuntimeError("CUDA NMS Error") from e
        # Re-raise anything else unchanged
        raise

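# Single-image sketch (file names are assumptions; requires a loaded model):
#
#   model = YOLO("yolo11n-pose.pt")
#   frame = cv2.imread("person.jpg")
#   _, landmarks = process_frame(frame, model, detection_threshold=0.5)
#   # landmarks -> [{'idx': 0, 'x': 0.51, 'y': 0.22, 'confidence': 0.93}, ...] or None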

def run_pose_detection(
    input_source,
    output_file=None,
    normalize=True,
    detection_threshold=0.5,
    filter_window_size=7,
    filter_poly_order=4,
    model_size='n',
    device='auto',
    show_preview=True,
    batch_size=1  # accepted for CLI compatibility; frames are currently processed one at a time
):
    """Run YOLO pose detection with CUDA acceleration, handling NMS backend failures."""
    start_time = time.time()

    # Download HTTP(S) sources to a local file first; RTSP streams are passed
    # straight to cv2.VideoCapture below (urllib cannot fetch rtsp:// URLs)
    if isinstance(input_source, str) and input_source.startswith(('http://', 'https://')):
        input_source = download_video(input_source)

    # Fall back to CPU if the requested accelerator is unavailable
    if 'cuda' in device and not torch.cuda.is_available():
        print("⚠️ CUDA requested but not available. Falling back to CPU.")
        device = 'cpu'

    if device == 'mps' and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
        print("⚠️ MPS (Apple Silicon) requested but not available. Falling back to CPU.")
        device = 'cpu'

    # Load the YOLO pose model on the requested device
    model_name = f"yolo11{model_size.lower()}-pose.pt"
    print(f"🔍 Loading {model_name} on {device}...")

    # The NMS patch was installed at import time; note it for CUDA runs
    if 'cuda' in device:
        print("💪 Applying CUDA-compatible NMS patch (keeping all processing on GPU)")

    try:
        model = YOLO(model_name)
        if device != 'auto':
            model.to(device)
        print(f"✅ Model loaded on {model.device}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return

    # Open the input: webcam index, RTSP stream, or local video file
    if isinstance(input_source, int) or (isinstance(input_source, str) and input_source.isdigit()):
        cap = cv2.VideoCapture(int(input_source))
        source_name = f"Webcam {input_source}"
    elif isinstance(input_source, str) and input_source.startswith('rtsp://'):
        cap = cv2.VideoCapture(input_source)
        source_name = f"Stream: {input_source}"
    else:
        if not os.path.isfile(input_source):
            print(f"❌ Error: Video file '{input_source}' not found")
            return
        cap = cv2.VideoCapture(input_source)
        source_name = f"Video: {os.path.basename(input_source)}"

    if not cap.isOpened():
        print(f"❌ Error: Could not open {source_name}")
        return

    # Get video properties (fps can be 0 for streams; default to 30)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps <= 0:
        fps = 30

    print(f"▶️ Processing {source_name}: {frame_width}x{frame_height}@{fps:.2f}fps")

    # Create a preview window if enabled
    if show_preview:
        window_name = "YOLOv11 Pose"
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    # Per-run state
    all_landmarks = []
    processed_frames = 0
    frames_since_fps_update = 0  # frames shown since the FPS overlay was last refreshed
    last_fps_update = time.time()
    current_fps = 0

    # Main processing loop
    print("⏳ Processing frames...")
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break

        try:
            # Run pose detection on this frame
            processed_frame, landmarks_data = process_frame(
                frame, model, detection_threshold, show_preview
            )

            # Store landmark data with a timestamp derived from the frame index
            if landmarks_data:
                all_landmarks.append({
                    'frame': processed_frames,
                    'timestamp': processed_frames / fps if fps > 0 else time.time() - start_time,
                    'landmarks': landmarks_data
                })

        except RuntimeError as e:
            if str(e) == "CUDA NMS Error":
                # The patched NMS falls back to CPU on later frames; skip this one
                print("⚠️ CUDA NMS error detected. Skipping this frame.")
                continue
            # Re-raise if it's a different error
            raise

        # Show the preview if enabled
        if show_preview and processed_frame is not None:
            # Refresh the FPS overlay roughly once per second
            frames_since_fps_update += 1
            now = time.time()
            if now - last_fps_update > 1.0:
                current_fps = int(frames_since_fps_update / (now - last_fps_update))
                frames_since_fps_update = 0
                last_fps_update = now

            # Overlay FPS and progress info
            cv2.putText(
                processed_frame,
                f"FPS: {current_fps} | Frame: {processed_frames}/{total_frames}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
            )

            # Overlay the compute device
            cv2.putText(
                processed_frame,
                f"Device: {model.device}",
                (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
            )

            # Show the frame
            cv2.imshow(window_name, processed_frame)

            # Exit on 'q' or ESC
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q') or key == 27:
                break

        processed_frames += 1

        # Print progress every 100 frames
        if processed_frames % 100 == 0:
            percent_done = (processed_frames / total_frames * 100) if total_frames > 0 else 0
            print(f"Progress: {processed_frames} frames ({percent_done:.1f}%)")

    # Report performance metrics
    elapsed_time = time.time() - start_time
    effective_fps = processed_frames / elapsed_time if elapsed_time > 0 else 0
    print(f"⏱️ Processed {processed_frames} frames in {elapsed_time:.2f}s ({effective_fps:.2f} fps)")

    if all_landmarks:
        print(f"🧮 Detected poses in {len(all_landmarks)} frames "
              f"({len(all_landmarks) / max(1, processed_frames) * 100:.1f}%)")
    else:
        print("⚠️ No poses detected. Try adjusting the detection threshold or check the video content.")

    # Save results if an output file was specified
    if output_file and all_landmarks:
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Apply temporal smoothing if requested
        if normalize and len(all_landmarks) > filter_window_size:
            print("🔄 Normalizing data...")
            landmarks_only = [frame_data['landmarks'] for frame_data in all_landmarks]
            normalized_landmarks = normalize_landmarks(
                landmarks_only,
                window_size=filter_window_size,
                poly_order=filter_poly_order
            )

            # Write the smoothed landmarks back onto the per-frame records
            for i in range(min(len(all_landmarks), len(normalized_landmarks))):
                all_landmarks[i]['landmarks'] = normalized_landmarks[i]

        # Assemble the output document
        json_data = {
            'source': source_name,
            'frame_width': frame_width,
            'frame_height': frame_height,
            'fps': fps,
            'total_frames': processed_frames,
            'keypoint_names': KEYPOINT_NAMES,
            'connections': [{'start': c[0], 'end': c[1]} for c in POSE_CONNECTIONS],
            'frames': all_landmarks,
            'metadata': {
                'model': f"YOLOv11-{model_size}-pose",
                'device': str(model.device),
                'normalized': normalize,
                'detection_threshold': detection_threshold,
                'filter_window_size': filter_window_size if normalize else None,
                'filter_poly_order': filter_poly_order if normalize else None,
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
        }

        # Save to file
        with open(output_file, 'w') as f:
            json.dump(json_data, f)

        print(f"💾 Saved tracking data to {output_file}")
    elif output_file:
        print("⚠️ No pose data to save. Output file was not created.")
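    # Shape of the saved JSON (values illustrative):
    #
    #   {
    #     "source": "Video: dance.mp4",
    #     "fps": 30.0,
    #     "keypoint_names": ["nose", ...],
    #     "frames": [{"frame": 0, "timestamp": 0.0,
    #                 "landmarks": [{"idx": 0, "x": 0.51, "y": 0.22, "confidence": 0.93}, ...]}],
    #     "metadata": {"model": "YOLOv11-n-pose", ...}
    #   }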

    # Release resources
    cap.release()
    if show_preview:
        cv2.destroyAllWindows()

    # Restore the original NMS function
    torchvision.ops.nms = original_nms

    return all_landmarks


def main():
    parser = argparse.ArgumentParser(
        description='YOLOv11 Pose Detection for JD-Clone with CUDA acceleration',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Essential arguments
    parser.add_argument('--input', '-i', required=True,
                        help='Input source (path to video file or URL)')
    parser.add_argument('--output', '-o', required=True,
                        help='Output JSON file to save pose data')
    parser.add_argument('--model', type=str, default='n', choices=['n', 's', 'm', 'l', 'x'],
                        help='YOLOv11 model size (n=nano, s=small, m=medium, l=large, x=xlarge)')
    parser.add_argument('--device', type=str, default='auto',
                        help='Computation device (cpu, cuda:0, auto, mps)')

    # Additional options
    parser.add_argument('--no-preview', action='store_true', help='Disable video preview')
    parser.add_argument('--no-normalize', action='store_true', help='Disable pose smoothing')
    parser.add_argument('--detection-threshold', type=float, default=0.5,
                        help='Threshold for pose detection confidence (0.0-1.0)')
    parser.add_argument('--filter-window', type=int, default=7,
                        help='Window size for smoothing filter (must be odd; larger = smoother)')
    parser.add_argument('--filter-order', type=int, default=4,
                        help='Polynomial order for smoothing filter (must be less than the window size)')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Batch size for processing (reserved; currently unused)')

    args = parser.parse_args()

    # The smoothing filter requires an odd window size; bump even values up by one
    if args.filter_window % 2 == 0:
        args.filter_window += 1

    # Print configuration
    print("\n" + "=" * 50)
    print("📹 JD-Clone YOLOv11 Pose Detector")
    print("=" * 50)
    print(f"• Input: {args.input}")
    print(f"• Output: {args.output}")
    print(f"• Model: YOLOv11-{args.model}")
    print(f"• Device: {args.device}")
    print(f"• Preview: {'Disabled' if args.no_preview else 'Enabled'}")
    print(f"• Normalization: {'Disabled' if args.no_normalize else 'Enabled'}")
    print("=" * 50 + "\n")

    # Run pose detection
    try:
        run_pose_detection(
            input_source=args.input,
            output_file=args.output,
            normalize=not args.no_normalize,
            detection_threshold=args.detection_threshold,
            filter_window_size=args.filter_window,
            filter_poly_order=args.filter_order,
            model_size=args.model,
            device=args.device,
            show_preview=not args.no_preview,
            batch_size=args.batch_size
        )
    except KeyboardInterrupt:
        print("\n⏹️ Process interrupted by user")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        print("👋 Done!")
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()