Files
JDClone/pose_viewer.py

436 lines
15 KiB
Python
Raw Permalink Normal View History

import argparse
import json
import os
import sys
import time
import urllib.request
from pathlib import Path
import cv2
import numpy as np
import pygame
from pygame.locals import *
# Basic named colors, written as (R, G, B) tuples.
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
CYAN = (0, 255, 255)
MAGENTA = (255, 0, 255)

# One color per keypoint index (custom rainbow palette, 17 entries).
# Consumers index this list modulo its length.
KEYPOINT_COLORS = [
    RED,             # nose (red)
    (255, 85, 0),    # left_eye (orange-red)
    (255, 170, 0),   # right_eye (orange)
    YELLOW,          # left_ear (yellow)
    (170, 255, 0),   # right_ear (yellow-green)
    (85, 255, 0),    # left_shoulder (green-yellow)
    GREEN,           # right_shoulder (green)
    (0, 255, 85),    # left_elbow (green-cyan)
    (0, 255, 170),   # right_elbow (cyan-green)
    CYAN,            # left_wrist (cyan)
    (0, 170, 255),   # right_wrist (cyan-blue)
    (0, 85, 255),    # left_hip (blue-cyan)
    BLUE,            # right_hip (blue)
    (85, 0, 255),    # left_knee (blue-purple)
    (170, 0, 255),   # right_knee (purple-blue)
    MAGENTA,         # left_ankle (magenta)
    (255, 0, 170),   # right_ankle (magenta-pink)
]

# One color per tracked person; consumers index this list modulo its length.
PERSON_COLORS = [
    RED,             # red
    GREEN,           # green
    BLUE,            # blue
    YELLOW,          # yellow
    MAGENTA,         # magenta
    CYAN,            # cyan
    (255, 128, 0),   # orange
    (128, 0, 255),   # purple
    (0, 255, 128),   # mint
    WHITE,           # white
]
def download_video(url, output_dir="downloaded_videos"):
    """Download a video from a URL and return the local file path.

    The local file name is the last segment of the URL *path* (query string
    and fragment are ignored). If that segment is empty or has no extension,
    a timestamp-based ``video_<epoch>.mp4`` name is used instead.

    If the target file already exists it is reused without downloading.
    Otherwise the data is written to a temporary ``.part`` file that is only
    renamed to the final name after the transfer finished, so an interrupted
    download is never mistaken for a complete video on the next run.

    Args:
        url: Address of the video to fetch.
        output_dir: Directory that receives the file; created if missing.

    Returns:
        Path of the downloaded (or previously downloaded) file.

    Raises:
        OSError: Propagated from ``urllib.request.urlretrieve`` (including
            ``urllib.error.URLError``) when the transfer fails.
    """
    from urllib.parse import urlsplit  # local import keeps the block self-contained

    os.makedirs(output_dir, exist_ok=True)

    # Only the path component names the file; a "/" inside the query string
    # must not influence the result.
    video_name = urlsplit(url).path.rsplit("/", 1)[-1]
    if not video_name.strip(".") or "." not in video_name:
        video_name = f"video_{int(time.time())}.mp4"

    output_path = os.path.join(output_dir, video_name)
    if os.path.exists(output_path):
        print(f"✅ Video already downloaded: {output_path}")
        return output_path

    print(f"⬇️ Downloading video from {url} to {output_path}...")
    partial_path = output_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, output_path)
    finally:
        # After a successful rename the partial file no longer exists; if it
        # still does, the transfer failed and the leftover must be removed.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"✅ Video downloaded successfully to {output_path}")
    return output_path
def load_pose_data(json_file):
    """Load pose data from a JSON file and normalize it into a plain dict.

    Keys read from the file (with the default used when absent):
    ``w`` (1280), ``h`` (720), ``fps`` (30), ``frames`` (0),
    ``meta.frame_sampling`` (1), ``meta.precision`` (3),
    ``connections`` (list of ``{"s": start, "e": end}`` objects),
    ``keypoints`` (list) and ``data`` (list of per-frame entries).

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        A dict with the keys ``width``, ``height``, ``fps``, ``total_frames``,
        ``frame_sampling``, ``precision``, ``connections`` (list of
        ``(start, end)`` tuples), ``keypoint_names`` and ``frames``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print(f"📂 Loading pose data from {json_file}...")
    # JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Video-level metadata.
    width = data.get('w', 1280)
    height = data.get('h', 720)
    fps = data.get('fps', 30)
    total_frames = data.get('frames', 0)

    # Optional sampling/precision information; tolerate an explicit null.
    metadata = data.get('meta') or {}
    frame_sampling = metadata.get('frame_sampling', 1)
    precision = metadata.get('precision', 3)

    # Skeleton edges as (start_index, end_index) pairs.
    connections = [
        (conn.get('s', 0), conn.get('e', 0))
        for conn in (data.get('connections') or [])
    ]

    keypoint_names = data.get('keypoints', [])
    frames = data.get('data', [])

    print(f"✅ Loaded {len(frames)} frames of pose data")
    print(f"📊 Video: {width}x{height}@{fps}fps, {total_frames} total frames")
    print(f"🔍 Frame sampling: {frame_sampling}, Precision: {precision}")

    return {
        'width': width,
        'height': height,
        'fps': fps,
        'total_frames': total_frames,
        'frame_sampling': frame_sampling,
        'precision': precision,
        'connections': connections,
        'keypoint_names': keypoint_names,
        'frames': frames,
    }
def create_pygame_window(width, height, title="Pose Viewer"):
    """Initialize PyGame and open a titled display window.

    Args:
        width: Window width passed to ``pygame.display.set_mode``.
        height: Window height passed to ``pygame.display.set_mode``.
        title: Caption shown in the window's title bar.

    Returns:
        The display surface returned by ``pygame.display.set_mode``.
    """
    pygame.init()
    size = (width, height)
    display_surface = pygame.display.set_mode(size)
    pygame.display.set_caption(title)
    return display_surface
def draw_pose(frame, pose_data, frame_idx, original_width, original_height, prev_frame_idx=None):
    """Draw the skeletons of one pose frame on a copy of a video frame.

    Only the first entry of ``pose_data['frames']`` is drawn; the caller is
    expected to put the pose frame it wants rendered there.

    Colors: the module palettes are (R, G, B) tuples, while the drawing is
    done with OpenCV on a frame that the viewer later converts with
    ``COLOR_BGR2RGB``. Each palette color is therefore reversed to (B, G, R)
    before drawing so that the displayed color matches the palette.
    NOTE(review): this assumes ``frame`` is in BGR channel order, as produced
    by ``cv2.VideoCapture.read``.

    Args:
        frame: Image array to draw on; it is copied, never modified.
        pose_data: Dict with ``'frames'`` (list of pose frames) and
            ``'connections'`` (list of ``(start, end)`` keypoint index pairs).
        frame_idx: Unused; kept for interface compatibility.
        original_width: Width used to scale keypoint x coordinates to pixels.
        original_height: Height used to scale keypoint y coordinates to pixels.
        prev_frame_idx: Unused; kept for interface compatibility.

    Returns:
        A new image with the poses drawn, or an unmodified copy of ``frame``
        when ``pose_data['frames']`` is empty.
    """
    # Work on a copy so the caller's frame stays untouched.
    pose_frame = frame.copy()
    if not pose_data['frames']:
        return pose_frame

    closest_frame = pose_data['frames'][0]
    connections = pose_data['connections']

    for person in closest_frame['p']:
        person_id = person['id']
        # Palette is RGB; OpenCV draws in the frame's BGR channel order.
        person_color = PERSON_COLORS[person_id % len(PERSON_COLORS)][::-1]

        # Map keypoint index -> (pixel x, pixel y, confidence). Each keypoint
        # is [index, x, y, confidence]; x and y are scaled by the frame size.
        kp_dict = {
            kp[0]: (
                int(kp[1] * original_width),
                int(kp[2] * original_height),
                kp[3],
            )
            for kp in person['k']
        }

        # Skeleton edges: only drawn when both end keypoints are present.
        for start_idx, end_idx in connections:
            if start_idx in kp_dict and end_idx in kp_dict:
                start_point = kp_dict[start_idx][:2]
                end_point = kp_dict[end_idx][:2]
                # Higher average confidence -> thicker line.
                avg_conf = (kp_dict[start_idx][2] + kp_dict[end_idx][2]) / 2
                thickness = int(avg_conf * 3) + 1
                cv2.line(pose_frame, start_point, end_point, person_color, thickness)

        # Keypoints: filled circles whose radius grows with confidence.
        for kp_idx, (x, y, conf) in kp_dict.items():
            radius = int(conf * 5) + 2
            kp_color = KEYPOINT_COLORS[kp_idx % len(KEYPOINT_COLORS)][::-1]
            cv2.circle(pose_frame, (x, y), radius, kp_color, -1)

        # Person label just above the first two bounding-box values.
        # NOTE(review): unlike keypoints, bbox values are used as pixel
        # coordinates without scaling - confirm the data stores them that way.
        bbox = person['b']
        label_x, label_y = int(bbox[0]), int(bbox[1])
        cv2.putText(
            pose_frame,
            f"ID: {person_id}",
            (label_x, label_y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            person_color,
            2,
        )

    return pose_frame
def draw_ui_controls(surface, width, height, playing, current_frame, total_frames):
    """Draw the playback control strip along the bottom of ``surface``.

    The strip is 50 pixels high and contains a Play/Pause button, a Stop
    button, a seek bar with a progress fill and a "Frame: x / y" label.

    Args:
        surface: PyGame surface to draw on.
        width: Width of ``surface`` in pixels.
        height: Height of ``surface`` in pixels; the strip occupies its
            bottom 50 pixels.
        playing: Whether playback is running (selects the "Pause" or "Play"
            button label).
        current_frame: Index of the frame currently shown.
        total_frames: Number of frames in the video; 0 or less yields an
            empty progress bar.

    Returns:
        A tuple ``(play_button_rect, stop_button_rect, seekbar_rect)`` of
        ``pygame.Rect`` objects in ``surface`` coordinates, for hit testing.
    """
    # Background for controls
    control_height = 50
    control_surface = pygame.Surface((width, control_height))
    control_surface.fill(BLACK)

    # Play/pause button
    button_width = 80
    button_height = 30
    button_x = 20
    button_y = (control_height - button_height) // 2
    pygame.draw.rect(control_surface, BLUE, (button_x, button_y, button_width, button_height))
    font = pygame.font.SysFont(None, 24)
    text = font.render("Pause" if playing else "Play", True, WHITE)
    text_rect = text.get_rect(center=(button_x + button_width // 2, button_y + button_height // 2))
    control_surface.blit(text, text_rect)

    # Stop button
    stop_button_x = button_x + button_width + 20
    pygame.draw.rect(control_surface, RED, (stop_button_x, button_y, button_width, button_height))
    stop_text = font.render("Stop", True, WHITE)
    stop_text_rect = stop_text.get_rect(
        center=(stop_button_x + button_width // 2, button_y + button_height // 2)
    )
    control_surface.blit(stop_text, stop_text_rect)

    # Seek bar geometry
    seekbar_x = stop_button_x + button_width + 40
    seekbar_y = button_y + button_height // 2
    seekbar_width = width - seekbar_x - 40
    seekbar_height = 10
    seekbar_top = seekbar_y - seekbar_height // 2

    # Background bar
    pygame.draw.rect(control_surface, (100, 100, 100),
                     (seekbar_x, seekbar_top, seekbar_width, seekbar_height))

    # Progress fill. The ratio is clamped to 0..1 because the reported frame
    # count can be smaller than the frame index actually reached, which would
    # otherwise draw the fill past the end of the bar.
    if total_frames > 0:
        progress = min(max(current_frame / total_frames, 0.0), 1.0)
    else:
        progress = 0
    progress_width = int(seekbar_width * progress)
    pygame.draw.rect(control_surface, GREEN,
                     (seekbar_x, seekbar_top, progress_width, seekbar_height))

    # Frame counter label, centered above the seek bar
    time_text = font.render(f"Frame: {current_frame} / {total_frames}", True, WHITE)
    time_rect = time_text.get_rect(center=(seekbar_x + seekbar_width // 2, seekbar_y - 20))
    control_surface.blit(time_text, time_rect)

    # Blit the control strip onto the bottom of the main surface
    strip_top = height - control_height
    surface.blit(control_surface, (0, strip_top))

    # Hit-test regions, translated into main-surface coordinates
    play_button_rect = pygame.Rect(button_x, strip_top + button_y, button_width, button_height)
    stop_button_rect = pygame.Rect(stop_button_x, strip_top + button_y, button_width, button_height)
    seekbar_rect = pygame.Rect(seekbar_x, strip_top + seekbar_top, seekbar_width, seekbar_height)
    return play_button_rect, stop_button_rect, seekbar_rect
def _seek_frame(cap, target, current_frame, frame):
    """Jump ``cap`` to ``target``; return ``(frame, index)`` actually shown.

    If the frame at ``target`` cannot be decoded, the capture is repositioned
    right after ``current_frame`` and the previous frame/index are returned.
    """
    cap.set(cv2.CAP_PROP_POS_FRAMES, target)
    ok, new_frame = cap.read()
    if ok:
        return new_frame, target
    cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame + 1)
    return frame, current_frame


def _nearest_pose_frame(frame_numbers, frame_lookup, frame_idx):
    """Return the pose frame recorded at or most recently before ``frame_idx``.

    ``frame_numbers`` must be the sorted keys of ``frame_lookup``. Returns
    ``None`` when no pose frame exists at or before ``frame_idx``.
    """
    import bisect  # local import keeps the block self-contained

    pos = bisect.bisect_right(frame_numbers, frame_idx)
    return frame_lookup[frame_numbers[pos - 1]] if pos else None


def run_viewer(video_path, json_path):
    """Run the interactive pose viewer until the user quits.

    Opens ``video_path`` with OpenCV, loads the pose data from ``json_path``
    and shows the video with the poses drawn on top in a PyGame window that
    has Play/Pause, Stop and seek controls.

    Controls: SPACE toggles playback, ESC or closing the window exits (via
    ``sys.exit``), mouse clicks operate the buttons and the seek bar. When
    the end of the video is reached, playback stops and rewinds to frame 0.

    For every displayed video frame, the pose frame whose ``'f'`` number is
    the greatest one not exceeding the video frame index is drawn; before the
    first pose frame the plain video frame is shown.

    Args:
        video_path: Path of the video file to play.
        json_path: Path of the JSON pose data file.

    Raises:
        SystemExit: When the user closes the window or presses ESC.
    """
    pose_data = load_pose_data(json_path)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"❌ Error: Could not open video {video_path}")
            return

        # Video properties
        video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Unknown rate (0 or NaN): clock.tick(0) would not limit the
            # frame rate at all, so fall back to the pose file's rate.
            fps = pose_data.get('fps') or 30

        # Window: the video plus a 50 px strip for the controls
        window_width = video_width
        window_height = video_height + 50
        window = create_pygame_window(
            window_width, window_height, f"Pose Viewer - {os.path.basename(video_path)}"
        )
        clock = pygame.time.Clock()

        # The metadata overlay never changes, so render it once.
        overlay_font = pygame.font.SysFont(None, 20)
        metadata_text = (
            f"Frame Sampling: {pose_data['frame_sampling']}, Precision: {pose_data['precision']}"
        )
        metadata_surface = overlay_font.render(metadata_text, True, WHITE)

        # Map video frame number -> pose frame, plus sorted keys for the
        # "most recent pose frame" lookup.
        frame_lookup = {pose_frame['f']: pose_frame for pose_frame in pose_data['frames']}
        frame_numbers = sorted(frame_lookup)

        # Invariant from here on: `frame` holds video frame `current_frame`
        # and the next cap.read() yields frame `current_frame + 1`.
        playing = False
        current_frame = 0
        ret, frame = cap.read()
        if not ret:
            print("❌ Error: Could not read the first frame")
            return

        while True:
            # Handle events
            for event in pygame.event.get():
                if event.type == QUIT:
                    sys.exit()  # the finally clause releases cap and PyGame
                elif event.type == KEYDOWN:
                    if event.key == K_ESCAPE:
                        sys.exit()
                    elif event.key == K_SPACE:
                        playing = not playing
                elif event.type == MOUSEBUTTONDOWN:
                    mouse_pos = event.pos
                    # Called here only to obtain the current hit-test regions.
                    play_button_rect, stop_button_rect, seekbar_rect = draw_ui_controls(
                        window, window_width, window_height, playing, current_frame, total_frames
                    )
                    if play_button_rect.collidepoint(mouse_pos):
                        playing = not playing
                    elif stop_button_rect.collidepoint(mouse_pos):
                        playing = False
                        frame, current_frame = _seek_frame(cap, 0, current_frame, frame)
                    elif seekbar_rect.collidepoint(mouse_pos):
                        ratio = (mouse_pos[0] - seekbar_rect.x) / seekbar_rect.width
                        target = int(ratio * total_frames)
                        target = max(0, min(target, max(total_frames - 1, 0)))
                        frame, current_frame = _seek_frame(cap, target, current_frame, frame)

            # Handle playback
            if playing:
                ret, next_frame = cap.read()
                if ret:
                    frame = next_frame
                    current_frame += 1
                else:
                    # End of video: stop and rewind to the first frame.
                    playing = False
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    ret, first_frame = cap.read()
                    if not ret:
                        break
                    frame = first_frame
                    current_frame = 0

            # Draw the most recent pose frame at or before the current frame.
            pose_entry = _nearest_pose_frame(frame_numbers, frame_lookup, current_frame)
            if pose_entry is not None:
                # draw_pose renders the first entry of 'frames', so hand it a
                # shallow copy of the pose data containing only this entry.
                single_pose_data = {**pose_data, 'frames': [pose_entry]}
                pose_frame = draw_pose(
                    frame, single_pose_data, current_frame, video_width, video_height
                )
            else:
                # No pose data yet: show the plain frame (cvtColor below
                # returns a new array, so no copy is needed).
                pose_frame = frame

            # OpenCV frames are BGR with shape (rows, cols, 3); PyGame expects
            # RGB indexed as (x, y), hence the conversion and axis swap.
            rgb_pose_frame = cv2.cvtColor(pose_frame, cv2.COLOR_BGR2RGB)
            pygame_pose_frame = pygame.surfarray.make_surface(rgb_pose_frame.swapaxes(0, 1))
            window.blit(pygame_pose_frame, (0, 0))

            draw_ui_controls(
                window, window_width, window_height, playing, current_frame, total_frames
            )
            window.blit(metadata_surface, (10, 10))

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Runs on normal exit, early return and sys.exit() alike.
        cap.release()
        pygame.quit()
def main():
    """Parse the command line and start the pose viewer.

    Options:
        --video / -v: Video file path, or an http(s) URL that is downloaded
            first.
        --json / -j: Path of the JSON pose data file.

    Problems (failed download, missing video or JSON file) are reported on
    stdout in the "❌ Error: ..." style used throughout this file, and the
    function then returns without starting the viewer.
    """
    parser = argparse.ArgumentParser(description='Pose Viewer for JSON pose data with video')
    parser.add_argument('--video', '-v', required=True, help='Video file path or URL')
    parser.add_argument('--json', '-j', required=True, help='JSON pose data file path')
    args = parser.parse_args()

    # Handle URL input for video
    video_path = args.video
    if video_path.startswith(('http://', 'https://')):
        try:
            video_path = download_video(video_path)
        except OSError as exc:
            # urllib's URLError/HTTPError are OSError subclasses; report them
            # like every other failure instead of dumping a traceback.
            print(f"❌ Error: Could not download video: {exc}")
            return

    if not os.path.exists(video_path):
        print(f"❌ Error: Video file not found: {video_path}")
        return
    if not os.path.exists(args.json):
        print(f"❌ Error: JSON file not found: {args.json}")
        return

    # Run the viewer
    run_viewer(video_path, args.json)


if __name__ == "__main__":
    main()