better precision

.gitignore (vendored): +2 lines
@@ -300,3 +300,5 @@ dist
 .yarn/install-state.gz
 .pnp.*
 
+# Yolo
+*.pt
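(The new `*.pt` pattern keeps PyTorch weight files out of version control; the script added below loads models such as yolo11n-pose.pt, which ultralytics downloads into the working directory.)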
@@ -1,6 +1,5 @@
-import React, { useEffect, useState } from 'react';
+import { useEffect, useState } from 'react';
 import { useNavigate } from 'react-router-dom';
 import { Button } from '../../components/common/Button';
 import { PoseRenderer } from '../../components/game/PoseRenderer';
 import usePoseDetection from '../../hooks/usePoseDetection';
 import useInputDetection from '../../hooks/useInputDetection';
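(Dropping the default `React` import while keeping the hooks matches the automatic JSX runtime of React 17+, where `React` no longer needs to be in scope for JSX.)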
@@ -91,6 +91,7 @@
 
   &__slider {
     -webkit-appearance: none;
+    appearance: none;
     height: 8px;
     background: rgba($background-dark, 0.6);
     border-radius: 4px;
@@ -135,7 +136,7 @@
     flex: 1;
 
     &:hover {
-      background-color: lighten($background-light, 5%);
+      background-color: color.adjust($background-light, $lightness: 5%);
     }
 
     &--active {
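(`color.adjust()` is the Dart Sass module replacement for the deprecated `lighten()`; it requires `@use "sass:color";` at the top of the stylesheet, which is presumably part of this change or already present.)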
pose_detector_window.py (new file, 460 lines)
@@ -0,0 +1,460 @@
import argparse
import json
import os
import time
import urllib.request
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torchvision
from scipy.signal import savgol_filter
from ultralytics import YOLO

# Define COCO keypoint names
KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle"
]

# Define skeleton connections
POSE_CONNECTIONS = [
    (0, 1), (0, 2),      # nose to eyes
    (1, 3), (2, 4),      # eyes to ears
    (5, 6),              # shoulders
    (5, 7), (7, 9),      # left arm
    (6, 8), (8, 10),     # right arm
    (5, 11), (6, 12),    # shoulders to hips
    (11, 12),            # hips
    (11, 13), (13, 15),  # left leg
    (12, 14), (14, 16)   # right leg
]

# Monkey patch torchvision NMS to handle CUDA compatibility issues
original_nms = torchvision.ops.nms

def patched_nms(boxes, scores, iou_threshold):
    """
    Custom NMS implementation that handles the CUDA compatibility issue
    by temporarily moving tensors to CPU, running NMS, and moving back
    to the original device.
    """
    device = boxes.device
    if device.type == 'cuda':
        try:
            # Try to run NMS on CUDA directly
            return original_nms(boxes, scores, iou_threshold)
        except RuntimeError as e:
            if "Could not run 'torchvision::nms'" in str(e):
                # If CUDA NMS fails, temporarily move to CPU, run NMS, then back to GPU
                cpu_boxes = boxes.cpu()
                cpu_scores = scores.cpu()
                keep = original_nms(cpu_boxes, cpu_scores, iou_threshold)
                # Move result back to original device
                return keep.to(device)
            else:
                raise
    else:
        # For non-CUDA devices, just run the original NMS
        return original_nms(boxes, scores, iou_threshold)

# Apply the monkey patch
torchvision.ops.nms = patched_nms

def download_video(url: str, output_dir: str = "downloaded_videos") -> str:
    """Download a video from a URL and return the local file path"""
    os.makedirs(output_dir, exist_ok=True)
    video_name = os.path.basename(url).split("?")[0]
    if not video_name or "." not in video_name:
        video_name = f"video_{int(time.time())}.mp4"

    output_path = os.path.join(output_dir, video_name)
    print(f"⬇️ Downloading video from {url} to {output_path}...")
    urllib.request.urlretrieve(url, output_path)
    print(f"✅ Video downloaded successfully to {output_path}")
    return output_path

def normalize_landmarks(landmarks: List[Dict], window_size: int = 5, poly_order: int = 4) -> List[Dict]:
    """Normalize landmarks over time using a Savitzky-Golay filter to smooth motion.

    Note: assumes every frame carries the same number of landmarks in the same
    order; frames with missing keypoints should be filtered or padded by the
    caller before smoothing.
    """
    if not landmarks or len(landmarks) < window_size:
        return landmarks

    # Ensure window_size is odd (required by savgol_filter)
    if window_size % 2 == 0:
        window_size += 1

    # Extract x, y values for each landmark
    landmark_count = len(landmarks[0])
    x_values = np.zeros((len(landmarks), landmark_count))
    y_values = np.zeros((len(landmarks), landmark_count))
    conf_values = np.zeros((len(landmarks), landmark_count))

    for i, frame_landmarks in enumerate(landmarks):
        for j, landmark in enumerate(frame_landmarks):
            x_values[i, j] = landmark['x']
            y_values[i, j] = landmark['y']
            conf_values[i, j] = landmark['confidence']

    # Apply Savitzky-Golay filter to smooth x, y trajectories
    x_smooth = savgol_filter(x_values, window_size, poly_order, axis=0)
    y_smooth = savgol_filter(y_values, window_size, poly_order, axis=0)

    # Reconstruct normalized landmarks
    normalized_landmarks = []
    for i in range(len(landmarks)):
        frame_landmarks = []
        for j in range(landmark_count):
            frame_landmarks.append({
                'idx': j,
                'x': float(x_smooth[i, j]),
                'y': float(y_smooth[i, j]),
                'confidence': float(conf_values[i, j])
            })
        normalized_landmarks.append(frame_landmarks)

    return normalized_landmarks

def process_frame(frame: np.ndarray, model, detection_threshold: float = 0.5, show_preview: bool = False):
    """Process a single frame with YOLOv11-pose"""
    # Process with YOLO
    try:
        results = model.predict(frame, verbose=False, conf=detection_threshold)

        # Extract keypoints if available
        landmarks_data = None
        processed_frame = None

        # Get frame dimensions
        h, w = frame.shape[:2]

        if results and len(results[0].keypoints.data) > 0:
            # Get keypoints from the first detection
            keypoints = results[0].keypoints.data[0]  # [17, 3] - (x, y, confidence)

            # Extract keypoints to landmarks_data
            landmarks_data = []
            for idx, kp in enumerate(keypoints):
                x, y, conf = kp.tolist()
                if conf >= detection_threshold:
                    landmarks_data.append({
                        'idx': idx,
                        'x': x / w,  # Normalize to 0-1 range
                        'y': y / h,  # Normalize to 0-1 range
                        'confidence': conf
                    })

        # Create visualization if preview is enabled
        if show_preview:
            processed_frame = results[0].plot()

        return processed_frame, landmarks_data

    except RuntimeError as e:
        # Check if this is an NMS backend error
        if "Could not run 'torchvision::nms'" in str(e):
            raise RuntimeError("CUDA NMS Error")
        else:
            # Re-raise if it's a different error
            raise

def run_pose_detection(
    input_source,
    output_file=None,
    normalize=True,
    detection_threshold=0.5,
    filter_window_size=7,
    filter_poly_order=4,
    model_size='n',
    device='auto',
    show_preview=True,
    batch_size=1
):
    """YOLOv11 pose detection with CUDA acceleration, properly handling NMS issues"""
    start_time = time.time()

    # Handle URL input
    if input_source and isinstance(input_source, str) and (
        input_source.startswith('http://') or
        input_source.startswith('https://') or
        input_source.startswith('rtsp://')
    ):
        input_source = download_video(input_source)

    # Check if CUDA is available when requested
    if 'cuda' in device and not torch.cuda.is_available():
        print("⚠️ CUDA requested but not available. Falling back to CPU.")
        device = 'cpu'

    # Check if MPS is available when requested
    if device == 'mps' and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
        print("⚠️ MPS (Apple Silicon) requested but not available. Falling back to CPU.")
        device = 'cpu'

    # Load YOLOv11-pose model with specified device
    model_name = f"yolo11{model_size.lower()}-pose.pt"
    print(f"🔍 Loading {model_name} on {device}...")

    # Apply NMS patch for CUDA device
    if 'cuda' in device:
        print("💪 Applying CUDA-compatible NMS patch (keeping all processing on GPU)")

    try:
        # Load model with specified device
        model = YOLO(model_name)
        if device != 'auto':
            model.to(device)
        print(f"✅ Model loaded on {model.device}")
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        return

    # Initialize video capture
    if isinstance(input_source, int) or (isinstance(input_source, str) and input_source.isdigit()):
        cap = cv2.VideoCapture(int(input_source))
        source_name = f"Webcam {input_source}"
    else:
        if not os.path.isfile(input_source):
            print(f"❌ Error: Video file '{input_source}' not found")
            return
        cap = cv2.VideoCapture(input_source)
        source_name = f"Video: {os.path.basename(input_source)}"

    if not cap.isOpened():
        print(f"❌ Error: Could not open {source_name}")
        return

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps <= 0:
        fps = 30  # Fall back to a sane default when the source reports no FPS

    print(f"▶️ Processing {source_name}: {frame_width}x{frame_height}@{fps:.2f}fps")

    # Create window if preview is enabled
    if show_preview:
        window_name = "YOLOv11 Pose"
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    # Initialize variables for batch processing
    all_landmarks = []
    processed_frames = 0
    frames_buffer = []  # Reserved for batch processing; unused while batch_size == 1
    last_fps_update = time.time()
    current_fps = 0

    # Main processing loop
    print("⏳ Processing frames...")
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break

        try:
            # Process the frame
            processed_frame, landmarks_data = process_frame(
                frame, model, detection_threshold, show_preview
            )

            # Store landmark data with timestamp
            if landmarks_data:
                frame_data = {
                    'frame': processed_frames,
                    'timestamp': processed_frames / fps if fps > 0 else time.time() - start_time,
                    'landmarks': landmarks_data
                }
                all_landmarks.append(frame_data)

        except RuntimeError as e:
            if str(e) == "CUDA NMS Error":
                # The patched NMS falls back to CPU internally; skip this frame and continue
                print("⚠️ CUDA NMS error detected. Skipping frame; patched NMS retries with CPU fallback.")
                continue
            else:
                # Re-raise if it's a different error
                raise

        # Show preview if enabled
        if show_preview and processed_frame is not None:
            # Calculate FPS
            if time.time() - last_fps_update > 1.0:  # Update FPS every second
                current_fps = int(1.0 / ((time.time() - last_fps_update) / max(1, processed_frames % 30)))
                last_fps_update = time.time()

            # Add FPS and progress info
            cv2.putText(
                processed_frame,
                f"FPS: {current_fps} | Frame: {processed_frames}/{total_frames}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2
            )

            # Show CUDA status
            cv2.putText(
                processed_frame,
                f"Device: {model.device} (Full GPU processing)",
                (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
            )

            # Show frame
            cv2.imshow(window_name, processed_frame)

            # Exit on 'q' or ESC
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q') or key == 27:
                break

        processed_frames += 1

        # Print progress
        if processed_frames % 100 == 0:
            percent_done = (processed_frames / total_frames * 100) if total_frames > 0 else 0
            print(f"Progress: {processed_frames} frames ({percent_done:.1f}%)")

    # Calculate performance metrics
    elapsed_time = time.time() - start_time
    effective_fps = processed_frames / elapsed_time if elapsed_time > 0 else 0

    print(f"⏱️ Processed {processed_frames} frames in {elapsed_time:.2f}s ({effective_fps:.2f} fps)")

    if all_landmarks:
        print(f"🧮 Detected poses in {len(all_landmarks)} frames ({(len(all_landmarks) / max(1, processed_frames)) * 100:.1f}%)")
    else:
        print("⚠️ No poses detected. Try adjusting the detection threshold or check the video content.")

    # Save results if output file is specified
    if output_file and all_landmarks:
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Apply normalization if requested
        if normalize and len(all_landmarks) > filter_window_size:
            print("🔄 Normalizing data...")
            landmarks_only = [frame_data['landmarks'] for frame_data in all_landmarks]
            normalized_landmarks = normalize_landmarks(
                landmarks_only,
                window_size=filter_window_size,
                poly_order=filter_poly_order
            )

            # Put normalized landmarks back
            for i, frame_data in enumerate(all_landmarks):
                if i < len(normalized_landmarks):
                    all_landmarks[i]['landmarks'] = normalized_landmarks[i]

        # Create output in compatible format
        json_data = {
            'source': source_name,
            'frame_width': frame_width,
            'frame_height': frame_height,
            'fps': fps,
            'total_frames': processed_frames,
            'keypoint_names': KEYPOINT_NAMES,
            'connections': [{'start': c[0], 'end': c[1]} for c in POSE_CONNECTIONS],
            'frames': all_landmarks,
            'metadata': {
                'model': f"YOLOv11-{model_size}-pose",
                'device': str(model.device),
                'normalized': normalize,
                'detection_threshold': detection_threshold,
                'filter_window_size': filter_window_size if normalize else None,
                'filter_poly_order': filter_poly_order if normalize else None,
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
        }

        # Save to file
        with open(output_file, 'w') as f:
            json.dump(json_data, f)

        print(f"💾 Saved tracking data to {output_file}")
    elif output_file:
        print("⚠️ No pose data to save. Output file was not created.")

    # Release resources
    cap.release()
    if show_preview:
        cv2.destroyAllWindows()

    # Restore original NMS function
    torchvision.ops.nms = original_nms

    return all_landmarks

def main():
    # Set up simple argument parser
    parser = argparse.ArgumentParser(
        description='YOLOv11 Pose Detection for JD-Clone with CUDA acceleration',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Essential arguments
    parser.add_argument('--input', '-i', required=True,
                        help='Input source (path to video file or URL)')
    parser.add_argument('--output', '-o', required=True,
                        help='Output JSON file to save pose data')
    parser.add_argument('--model', type=str, default='n', choices=['n', 's', 'm', 'l', 'x'],
                        help='YOLOv11 model size (n=nano, s=small, m=medium, l=large, x=xlarge)')
    parser.add_argument('--device', type=str, default='auto',
                        help='Computation device (cpu, cuda:0, auto, mps)')

    # Additional options
    parser.add_argument('--no-preview', action='store_true', help='Disable video preview')
    parser.add_argument('--no-normalize', action='store_true', help='Disable pose normalization')
    parser.add_argument('--detection-threshold', type=float, default=0.5,
                        help='Threshold for pose detection confidence (0.0-1.0)')
    parser.add_argument('--filter-window', type=int, default=7,
                        help='Window size for smoothing filter (must be odd, larger = smoother)')
    parser.add_argument('--filter-order', type=int, default=4,
                        help='Polynomial order for smoothing filter (1-4)')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Batch size for processing (higher uses more VRAM but can be faster)')

    args = parser.parse_args()

    # Validate filter window size (savgol_filter requires an odd window)
    if args.filter_window % 2 == 0:
        args.filter_window += 1

    # Print configuration
    print("\n" + "=" * 50)
    print("📹 JD-Clone YOLOv11 Pose Detector")
    print("=" * 50)
    print(f"• Input: {args.input}")
    print(f"• Output: {args.output}")
    print(f"• Model: YOLOv11-{args.model}")
    print(f"• Device: {args.device}")
    print(f"• Preview: {'Disabled' if args.no_preview else 'Enabled'}")
    print(f"• Normalization: {'Disabled' if args.no_normalize else 'Enabled'}")
    print("=" * 50 + "\n")

    # Run pose detection
    try:
        run_pose_detection(
            input_source=args.input,
            output_file=args.output,
            normalize=not args.no_normalize,
            detection_threshold=args.detection_threshold,
            filter_window_size=args.filter_window,
            filter_poly_order=args.filter_order,
            model_size=args.model,
            device=args.device,
            show_preview=not args.no_preview,
            batch_size=args.batch_size
        )
    except KeyboardInterrupt:
        print("\n⏹️ Process interrupted by user")
    except Exception as e:
        print(f"\n❌ Error: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        print("👋 Done!")
        cv2.destroyAllWindows()

if __name__ == "__main__":
    main()
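A plausible invocation of the new script, built from the flags defined in main() above (the file names are illustrative, not from the commit):

    python pose_detector_window.py --input dance.mp4 --output poses/dance.json --model s --device cuda:0 --no-preview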
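For downstream consumers, a minimal sketch of reading the JSON that run_pose_detection() writes; the field names come from the json_data structure above, while the file path is illustrative:

    import json

    with open("poses/dance.json") as f:  # illustrative path
        data = json.load(f)

    names = data["keypoint_names"]
    for frame in data["frames"]:
        for lm in frame["landmarks"]:
            # x and y were normalized to the 0-1 range in process_frame()
            print(f"{frame['timestamp']:.2f}s {names[lm['idx']]}: "
                  f"({lm['x']:.3f}, {lm['y']:.3f}) conf={lm['confidence']:.2f}")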
@@ -1,6 +1,9 @@
 opencv-python>=4.5.0
-mediapipe>=0.8.9
+ultralytics>=8.3.0
 flask>=2.0.0
 flask-socketio>=5.1.0
 flask-cors>=3.0.10
 numpy>=1.19.0
+scipy>=1.7.0
+pillow>=9.0.0
+mediapipe>=0.8.9
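(The dependency changes track the new script: ultralytics supplies the YOLO class and scipy supplies savgol_filter used above; pillow appears to be added for image handling, and mediapipe appears to move to the end of the file rather than being dropped.)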