Video post-processing: per-person move data saved and consumed by the TS frontend

2025-05-14 21:17:41 -03:00
parent ffd9775445
commit c4291c8759
15 changed files with 1262 additions and 203 deletions

.gitignore

@@ -301,4 +301,7 @@ dist
.pnp.*
# Yolo
*.pt
*.pt
# Downloads
downloaded_videos/

jd-clone/index.json

@@ -0,0 +1,18 @@
{
"songs": {
"Butter": {
"Name": "Butter",
"Artist": "BTS",
"Album": "Butter",
"Year": 2021,
"Genre": "K-Pop",
"GameData": {
"moves": "https://cdn.ovosimpatico.com/jdlo/maps/Butter/poses.json",
"video": "https://cdn.ovosimpatico.com/jdlo/maps/Butter/Butter_ULTRA.webm",
"audio": "https://cdn.ovosimpatico.com/jdlo/maps/Butter/Butter.ogg",
"cover": "https://cdn.ovosimpatico.com/jdlo/maps/Butter/butter_cover_generic.png",
"audio_preview": "https://cdn.ovosimpatico.com/jdlo/maps/Butter/butter_audiopreview.ogg"
}
}
}
}


@@ -5,7 +5,7 @@
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"build": "tsc -b && vite build && cp ../index.json dist/",
"lint": "eslint .",
"preview": "vite preview"
},


@@ -0,0 +1,61 @@
import { useEffect, useRef, useState } from 'react';
import useAppStore from '../../store/app-store';
interface AudioPreviewProps {
src: string;
autoPlay?: boolean;
}
const AudioPreview = ({ src, autoPlay = false }: AudioPreviewProps) => {
const audioRef = useRef<HTMLAudioElement | null>(null);
const [isPlaying, setIsPlaying] = useState(false);
const masterVolume = useAppStore(state => state.settings.volume.master);
const musicVolume = useAppStore(state => state.settings.volume.music);
useEffect(() => {
if (!audioRef.current) return;
// Set volume based on app settings
const volume = masterVolume * musicVolume;
audioRef.current.volume = volume;
// Autoplay if needed
if (autoPlay && audioRef.current) {
audioRef.current.play().catch(error => {
console.warn('Autoplay prevented:', error);
});
}
return () => {
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.currentTime = 0;
}
};
}, [src, autoPlay, masterVolume, musicVolume]);
useEffect(() => {
const audio = audioRef.current;
if (!audio) return;
const handlePlay = () => setIsPlaying(true);
const handlePause = () => setIsPlaying(false);
const handleEnded = () => setIsPlaying(false);
audio.addEventListener('play', handlePlay);
audio.addEventListener('pause', handlePause);
audio.addEventListener('ended', handleEnded);
return () => {
audio.removeEventListener('play', handlePlay);
audio.removeEventListener('pause', handlePause);
audio.removeEventListener('ended', handleEnded);
};
}, []);
return (
<audio ref={audioRef} src={src} loop={false} preload="auto" />
);
};
export default AudioPreview;


@@ -0,0 +1,14 @@
.video-player {
position: relative;
overflow: hidden;
background-color: black;
border-radius: 8px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
video {
display: block;
object-fit: contain;
width: 100%;
height: 100%;
}
}


@@ -0,0 +1,168 @@
import { useEffect, useRef } from 'react';
import useAppStore from '../../store/app-store';
import './VideoPlayer.scss';
interface VideoPlayerProps {
src: string;
audioSrc?: string;
playing: boolean;
muted?: boolean;
width?: number | string;
height?: number | string;
onEnded?: () => void;
}
const VideoPlayer = ({
src,
audioSrc,
playing,
muted = false,
width = '100%',
height = '100%',
onEnded
}: VideoPlayerProps) => {
const videoRef = useRef<HTMLVideoElement | null>(null);
const audioRef = useRef<HTMLAudioElement | null>(null);
const masterVolume = useAppStore(state => state.settings.volume.master);
const musicVolume = useAppStore(state => state.settings.volume.music);
// Manage play/pause state for video
useEffect(() => {
const video = videoRef.current;
if (!video) return;
if (playing) {
video.play().catch(error => {
console.warn('Video playback prevented:', error);
});
} else {
video.pause();
}
}, [playing]);
// Manage play/pause state for audio
useEffect(() => {
const audio = audioRef.current;
if (!audio || !audioSrc) return;
if (playing) {
// Sync with video if needed
if (videoRef.current) {
audio.currentTime = videoRef.current.currentTime;
}
audio.play().catch(error => {
console.warn('Audio playback prevented:', error);
});
} else {
audio.pause();
}
}, [playing, audioSrc]);
// Sync audio with video when video seeks or loads
useEffect(() => {
const video = videoRef.current;
const audio = audioRef.current;
if (!video || !audio || !audioSrc) return;
const handleTimeUpdate = () => {
// Only sync if the difference is significant (more than 0.1 seconds)
if (Math.abs(video.currentTime - audio.currentTime) > 0.1) {
audio.currentTime = video.currentTime;
}
};
const handlePlay = () => {
if (playing) {
audio.play().catch(error => {
console.warn('Audio playback prevented:', error);
});
}
};
const handlePause = () => {
audio.pause();
};
video.addEventListener('seeked', handleTimeUpdate);
video.addEventListener('play', handlePlay);
video.addEventListener('pause', handlePause);
return () => {
video.removeEventListener('seeked', handleTimeUpdate);
video.removeEventListener('play', handlePlay);
video.removeEventListener('pause', handlePause);
};
}, [playing, audioSrc]);
// Manage volume for video
useEffect(() => {
const video = videoRef.current;
if (!video) return;
// If we have a separate audio source, mute the video
if (audioSrc) {
video.muted = true;
} else {
// Otherwise use the video's audio
if (muted) {
video.muted = true;
} else {
video.muted = false;
video.volume = masterVolume * musicVolume;
}
}
}, [masterVolume, musicVolume, muted, audioSrc]);
// Manage volume for audio
useEffect(() => {
const audio = audioRef.current;
if (!audio || !audioSrc) return;
if (muted) {
audio.muted = true;
} else {
audio.muted = false;
audio.volume = masterVolume * musicVolume;
}
}, [masterVolume, musicVolume, muted, audioSrc]);
// Setup event handlers for video
useEffect(() => {
const video = videoRef.current;
if (!video) return;
const handleEnded = () => {
if (onEnded) onEnded();
};
video.addEventListener('ended', handleEnded);
return () => {
video.removeEventListener('ended', handleEnded);
};
}, [onEnded]);
return (
<div className="video-player" style={{ width, height }}>
<video
ref={videoRef}
src={src}
preload="auto"
playsInline
width="100%"
height="100%"
muted={!!audioSrc || muted} // Always mute video if using separate audio
/>
{audioSrc && (
<audio
ref={audioRef}
src={audioSrc}
preload="auto"
/>
)}
</div>
);
};
export default VideoPlayer;


@@ -249,6 +249,47 @@
align-items: center;
gap: 1rem;
}
// Completion overlay
&__completion-overlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: rgba(0, 0, 0, 0.7);
display: flex;
align-items: center;
justify-content: center;
z-index: 100;
animation: fade-in 0.5s ease-out;
}
&__completion-content {
background-color: rgba(26, 26, 26, 0.9);
border-radius: 12px;
padding: 3rem;
text-align: center;
box-shadow: 0 0 30px rgba(255, 215, 0, 0.3);
animation: scale-in 0.5s ease-out;
h2 {
font-size: 3.5rem;
margin-top: 0;
margin-bottom: 1.5rem;
color: #ffca3a;
}
p {
font-size: 1.8rem;
margin-bottom: 1rem;
&:last-child {
color: #8ac926;
font-weight: bold;
}
}
}
}
@keyframes feedback-pulse {
@@ -263,4 +304,24 @@
transform: scale(1);
opacity: 1;
}
}
@keyframes fade-in {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes scale-in {
from {
transform: scale(0.8);
opacity: 0;
}
to {
transform: scale(1);
opacity: 1;
}
}


@@ -1,6 +1,7 @@
import { useEffect, useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { PoseRenderer } from '../../components/game/PoseRenderer';
import VideoPlayer from '../../components/game/VideoPlayer';
import usePoseDetection from '../../hooks/usePoseDetection';
import useInputDetection from '../../hooks/useInputDetection';
import useControllerDetection from '../../hooks/useControllerDetection';
@@ -19,6 +20,7 @@ function GameplayPage() {
const [gameStarted, setGameStarted] = useState(false);
const [countdown, setCountdown] = useState(3);
const [error, setError] = useState<string | null>(null);
const [gameCompleted, setGameCompleted] = useState(false);
// Get app state
const selectedSongId = useAppStore(state => state.selectedSongId);
@@ -101,6 +103,7 @@ function GameplayPage() {
// Reset game state
resetGameState();
setGameCompleted(false);
loadGameData();
@@ -142,6 +145,16 @@ function GameplayPage() {
navigate('/results');
};
// Handle video ended
const handleVideoEnded = () => {
setGameCompleted(true);
endGame();
// Wait a short moment before navigating to results
setTimeout(() => {
navigate('/results');
}, 1500);
};
// If we don't have a selected song, redirect to setup
if (!selectedSongId) {
useEffect(() => {
@@ -192,6 +205,21 @@ function GameplayPage() {
);
};
// Helper to render game completion overlay
const renderCompletionOverlay = () => {
if (!gameCompleted) return null;
return (
<div className="gameplay-page__completion-overlay">
<div className="gameplay-page__completion-content">
<h2>Song Complete!</h2>
<p>Great job dancing!</p>
<p>Redirecting to results...</p>
</div>
</div>
);
};
return (
<div className="gameplay-page gameplay-page--tv-mode">
{loading ? (
@@ -271,11 +299,19 @@ function GameplayPage() {
<div className="gameplay-page__content">
<div className="gameplay-page__video-container">
{/* Video will go here */}
<div className="gameplay-page__video-placeholder">
<p>Dance video will play here</p>
<p>Song: {song?.title} by {song?.artist}</p>
</div>
{song && song.GameData ? (
<VideoPlayer
src={song.GameData.video}
audioSrc={song.GameData.audio}
playing={isPlaying && !isPaused}
onEnded={handleVideoEnded}
/>
) : (
<div className="gameplay-page__video-placeholder">
<p>Dance video will play here</p>
<p>Song information not available</p>
</div>
)}
</div>
<div className="gameplay-page__camera-container">
@@ -302,6 +338,7 @@ function GameplayPage() {
</div>
{renderPauseOverlay()}
{renderCompletionOverlay()}
</>
)}
</div>


@@ -46,8 +46,8 @@ function ResultsPage() {
<h1 className="results-page__title">Results</h1>
<div className="results-page__song-info">
<h2>{selectedSong?.title || 'Unknown Song'}</h2>
<p>{selectedSong?.artist || 'Unknown Artist'}</p>
<h2>{selectedSong?.Name || 'Unknown Song'}</h2>
<p>{selectedSong?.Artist || 'Unknown Artist'}</p>
</div>
<div className="results-page__score-container">


@@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { Button } from '../../components/common/Button';
import { PoseRenderer } from '../../components/game/PoseRenderer';
import AudioPreview from '../../components/common/AudioPreview';
import usePoseDetection from '../../hooks/usePoseDetection';
import useAppStore from '../../store/app-store';
import songService from '../../services/song-service';
@@ -13,6 +14,7 @@ function GameSetupPage() {
const [songs, setSongs] = useState<Song[]>([]);
const [loading, setLoading] = useState(true);
const [selectedSongId, setSelectedSongId] = useState<string | null>(null);
const [currentPreviewUrl, setCurrentPreviewUrl] = useState<string | null>(null);
const difficulty = useAppStore(state => state.settings.difficulty);
const setDifficulty = useAppStore(state => state.setDifficulty);
@@ -31,6 +33,8 @@ function GameSetupPage() {
// Select the first song by default
if (songList.length > 0 && !selectedSongId) {
setSelectedSongId(songList[0].id);
// Set the preview URL for the first song
setCurrentPreviewUrl(songList[0].GameData.audio_preview);
}
} catch (error) {
console.error('Failed to load songs:', error);
@@ -48,6 +52,12 @@ function GameSetupPage() {
const handleSongSelect = (songId: string) => {
setSelectedSongId(songId);
// Update the audio preview URL when selecting a new song
const selectedSong = songs.find(song => song.id === songId);
if (selectedSong) {
setCurrentPreviewUrl(selectedSong.GameData.audio_preview);
}
};
const handleStartGame = () => {
@@ -103,27 +113,32 @@ function GameSetupPage() {
className={`setup-page__song-item ${selectedSongId === song.id ? 'setup-page__song-item--selected' : ''}`}
onClick={() => handleSongSelect(song.id)}
>
<div className="setup-page__song-cover" style={{ backgroundImage: `url(${song.coverUrl})` }} />
<div className="setup-page__song-cover" style={{ backgroundImage: `url(${song.GameData.cover})` }} />
<div className="setup-page__song-info">
<h3 className="setup-page__song-title">{song.title}</h3>
<p className="setup-page__song-artist">{song.artist}</p>
<p className="setup-page__song-duration">{Math.floor(song.duration / 60)}:{(song.duration % 60).toString().padStart(2, '0')}</p>
<h3 className="setup-page__song-title">{song.Name}</h3>
<p className="setup-page__song-artist">{song.Artist}</p>
<p className="setup-page__song-genre">{song.Genre || 'Unknown genre'}</p>
</div>
</div>
))}
</div>
)}
{/* Audio Preview Player */}
{currentPreviewUrl && (
<AudioPreview src={currentPreviewUrl} autoPlay={true} />
)}
</div>
<div className="setup-page__difficulty-selection">
<h2>Select Difficulty</h2>
<div className="setup-page__difficulty-buttons">
{selectedSong?.difficulty.map((diff) => (
{(['easy', 'medium', 'hard', 'extreme'] as DifficultyLevel[]).map((diff) => (
<button
key={diff}
className={`setup-page__difficulty-button setup-page__difficulty-button--${diff} ${difficulty === diff ? 'setup-page__difficulty-button--selected' : ''}`}
onClick={() => handleDifficultyChange(diff)}
disabled={!selectedSong.difficulty.includes(diff)}
disabled={selectedSong?.difficulty ? !selectedSong.difficulty.includes(diff) : false}
>
{diff.toUpperCase()}
</button>


@@ -1,140 +1,101 @@
import { Song, Choreography, DifficultyLevel, Move } from '../types';
// Mock song data for development purposes
const MOCK_SONGS: Song[] = [
{
id: 'song1',
title: 'Dance The Night',
artist: 'Dua Lipa',
bpm: 120,
duration: 176,
coverUrl: 'https://example.com/covers/dance-the-night.jpg',
audioUrl: 'https://example.com/songs/dance-the-night.mp3',
videoUrl: 'https://example.com/videos/dance-the-night.mp4',
difficulty: ['easy', 'medium', 'hard'],
tags: ['pop', 'upbeat', 'disco']
},
{
id: 'song2',
title: 'Levitating',
artist: 'Dua Lipa ft. DaBaby',
bpm: 103,
duration: 203,
coverUrl: 'https://example.com/covers/levitating.jpg',
audioUrl: 'https://example.com/songs/levitating.mp3',
videoUrl: 'https://example.com/videos/levitating.mp4',
difficulty: ['easy', 'medium', 'hard', 'extreme'],
tags: ['pop', 'upbeat', 'disco']
},
{
id: 'song3',
title: 'Physical',
artist: 'Dua Lipa',
bpm: 124,
duration: 183,
coverUrl: 'https://example.com/covers/physical.jpg',
audioUrl: 'https://example.com/songs/physical.mp3',
videoUrl: 'https://example.com/videos/physical.mp4',
difficulty: ['medium', 'hard', 'extreme'],
tags: ['pop', 'dance', 'workout']
}
];
// Mock choreography data with placeholder moves
const MOCK_CHOREOGRAPHIES: Record<string, Record<DifficultyLevel, Choreography>> = {
song1: {
easy: {
songId: 'song1',
difficulty: 'easy',
moves: Array(20).fill(null).map((_, index) => ({
id: `song1-easy-move-${index}`,
startTime: index * 8000,
duration: 4000,
keyPosePoints: [], // This would contain actual pose landmarks
difficulty: 'easy',
score: 100
}))
},
medium: {
songId: 'song1',
difficulty: 'medium',
moves: Array(30).fill(null).map((_, index) => ({
id: `song1-medium-move-${index}`,
startTime: index * 6000,
duration: 3000,
keyPosePoints: [],
difficulty: 'medium',
score: 150
}))
},
hard: {
songId: 'song1',
difficulty: 'hard',
moves: Array(40).fill(null).map((_, index) => ({
id: `song1-hard-move-${index}`,
startTime: index * 4000,
duration: 2000,
keyPosePoints: [],
difficulty: 'hard',
score: 200
}))
},
extreme: {
songId: 'song1',
difficulty: 'extreme',
moves: Array(50).fill(null).map((_, index) => ({
id: `song1-extreme-move-${index}`,
startTime: index * 3000,
duration: 1500,
keyPosePoints: [],
difficulty: 'extreme',
score: 300
}))
}
}
};
class SongService {
/**
* Get all available songs
*/
async getSongs(): Promise<Song[]> {
// In a real app, this would fetch from an API
return new Promise((resolve) => {
setTimeout(() => {
resolve(MOCK_SONGS);
}, 500);
});
try {
const response = await fetch('/index.json');
if (!response.ok) {
throw new Error(`Failed to fetch songs: ${response.statusText}`);
}
const data = await response.json();
// Transform the data to match our Song interface
return Object.entries(data.songs).map(([id, songData]: [string, any]) => ({
id,
...songData,
// Add default difficulty levels since they're not in the JSON
difficulty: ['medium', 'hard']
}));
} catch (error) {
console.error('Error fetching songs:', error);
return [];
}
}
/**
* Get a specific song by ID
*/
async getSongById(id: string): Promise<Song | null> {
return new Promise((resolve) => {
setTimeout(() => {
const song = MOCK_SONGS.find(s => s.id === id) || null;
resolve(song);
}, 300);
});
try {
const songs = await this.getSongs();
return songs.find(song => song.id === id) || null;
} catch (error) {
console.error(`Error fetching song ${id}:`, error);
return null;
}
}
/**
* Get choreography for a song at a specific difficulty
*/
async getChoreography(songId: string, difficulty: DifficultyLevel): Promise<Choreography | null> {
return new Promise((resolve) => {
setTimeout(() => {
const songChoreographies = MOCK_CHOREOGRAPHIES[songId];
if (!songChoreographies) {
resolve(null);
return;
}
try {
const song = await this.getSongById(songId);
if (!song) {
return null;
}
const choreography = songChoreographies[difficulty];
resolve(choreography || null);
}, 500);
});
// Fetch the moves data from the URL in song.GameData.moves
const response = await fetch(song.GameData.moves);
if (!response.ok) {
throw new Error(`Failed to fetch choreography: ${response.statusText}`);
}
const choreographyData = await response.json();
// Process the choreography data based on difficulty
// This would need to be adjusted based on the actual format of the moves data
return {
songId,
difficulty,
moves: this.processChoreographyData(choreographyData, difficulty)
};
} catch (error) {
console.error(`Error fetching choreography for song ${songId}:`, error);
return null;
}
}
/**
* Process choreography data from the JSON file
* Note: This would need to be adjusted based on the actual data structure
*/
private processChoreographyData(data: any, difficulty: DifficultyLevel): Move[] {
// This is a placeholder implementation
// You'll need to adapt this based on the actual format of your poses.json files
const moves: Move[] = [];
// Example implementation assuming data has a moves array
if (Array.isArray(data.moves)) {
data.moves.forEach((moveData: any, index: number) => {
moves.push({
id: `${difficulty}-move-${index}`,
startTime: moveData.startTime || index * 3000,
duration: moveData.duration || 2000,
keyPosePoints: moveData.keyPoints || [],
difficulty,
score: difficulty === 'easy' ? 100 :
difficulty === 'medium' ? 150 :
difficulty === 'hard' ? 200 : 300
});
});
}
return moves;
}
/**
@@ -144,12 +105,9 @@ class SongService {
// This would use the pose detection API to analyze a video and generate choreography data
console.log(`Generating choreography for ${videoUrl}, song ${songId}, difficulty ${difficulty}`);
// For now, just return a mock choreography
return new Promise((resolve) => {
setTimeout(() => {
resolve(MOCK_CHOREOGRAPHIES.song1[difficulty]);
}, 2000);
});
// This is a placeholder - in a real implementation,
// this would call an API to process the video and generate choreography
return null;
}
}
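Note: processChoreographyData above is still a placeholder. The pose tool added later in this commit writes a compressed layout (frames under a data key, each with f/t/p fields and per-person k keypoint tuples), so the adaptation will likely look roughly like the sketch below. This is only an illustration under that assumption: the helper name framesToMoves is hypothetical, the Move field set is inferred from the mock data that was removed above, and the landmark mapping is left open.

import { DifficultyLevel, Move } from '../types';

// Shape of one compressed frame as written by the pose tool:
// 'f' = frame number, 't' = timestamp in seconds, 'p' = people,
// each person carrying 'id', a 'b' bbox and 'k' tuples of [idx, x, y, confidence].
interface CompressedFrame {
  f: number;
  t: number;
  p: { id: number; b: number[]; k: [number, number, number, number][] }[];
}

// Hypothetical helper: turn sampled frames into Move entries, one per frame window.
function framesToMoves(frames: CompressedFrame[], difficulty: DifficultyLevel): Move[] {
  return frames.map((frame, index) => ({
    id: `${difficulty}-move-${index}`,
    startTime: Math.round(frame.t * 1000), // 't' is seconds; Move times are assumed to be ms
    duration: 2000,                        // fixed window, matching the current placeholder
    keyPosePoints: [],                     // mapping 'k' tuples onto pose landmarks is still open
    difficulty,
    score: difficulty === 'easy' ? 100 :
           difficulty === 'medium' ? 150 :
           difficulty === 'hard' ? 200 : 300
  }));
}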


@@ -21,17 +21,23 @@ export interface PoseData {
}
// Game content types
export interface SongGameData {
moves: string;
video: string;
audio: string;
cover: string;
audio_preview: string;
}
export interface Song {
id: string;
title: string;
artist: string;
bpm: number;
duration: number;
coverUrl: string;
audioUrl: string;
videoUrl: string;
difficulty: DifficultyLevel[];
tags: string[];
Name: string;
Artist: string;
Album?: string;
Year?: number;
Genre?: string;
GameData: SongGameData;
difficulty?: DifficultyLevel[];
}
export type DifficultyLevel = 'easy' | 'medium' | 'hard' | 'extreme';


@@ -1,5 +1,6 @@
import argparse
import json
import math
import os
import time
import urllib.request
@@ -77,7 +78,66 @@ def download_video(url: str, output_dir: str = "downloaded_videos") -> str:
print(f"✅ Video downloaded successfully to {output_path}")
return output_path
def normalize_landmarks(landmarks: List[Dict], window_size: int = 5, poly_order: int = 4) -> List[Dict]:
def normalize_landmarks_per_person(people_landmarks: List[Dict], window_size: int = 5, poly_order: int = 4) -> List[Dict]:
"""Normalize landmarks over time for each person using Savitzky-Golay filter"""
if not people_landmarks:
return people_landmarks
# Reorganize by person ID
person_data = {}
for frame_data in people_landmarks:
frame_num = frame_data['frame']
timestamp = frame_data['timestamp']
for person in frame_data['people']:
person_id = person['person_id']
if person_id not in person_data:
person_data[person_id] = {
'frames': [],
'timestamps': [],
'landmarks': []
}
person_data[person_id]['frames'].append(frame_num)
person_data[person_id]['timestamps'].append(timestamp)
person_data[person_id]['landmarks'].append(person['landmarks'])
# Normalize each person's landmarks
for person_id, data in person_data.items():
if len(data['landmarks']) >= window_size:
data['landmarks'] = normalize_landmarks(
data['landmarks'],
window_size=window_size,
poly_order=poly_order
)
# Reconstruct the frame data structure
normalized_data = []
for frame_data in people_landmarks:
frame_num = frame_data['frame']
timestamp = frame_data['timestamp']
new_people = []
for person in frame_data['people']:
person_id = person['person_id']
idx = person_data[person_id]['frames'].index(frame_num)
new_people.append({
'person_id': person_id,
'bbox': person['bbox'],
'landmarks': person_data[person_id]['landmarks'][idx]
})
normalized_data.append({
'frame': frame_num,
'timestamp': timestamp,
'people': new_people
})
return normalized_data
def normalize_landmarks(landmarks: List[List[Dict]], window_size: int = 5, poly_order: int = 4) -> List[List[Dict]]:
"""Normalize landmarks over time using Savitzky-Golay filter to smooth motion"""
if not landmarks or len(landmarks) < window_size:
return landmarks
@@ -86,6 +146,12 @@ def normalize_landmarks(landmarks: List[Dict], window_size: int = 5, poly_order:
if window_size % 2 == 0:
window_size += 1
# Check if all frames have the same number of landmarks
if not all(len(frame) == len(landmarks[0]) for frame in landmarks):
# If inconsistent landmark counts, use a simpler approach (frame by frame smoothing)
print("⚠️ Warning: Inconsistent landmark counts across frames. Using simplified smoothing.")
return landmarks
# Extract x, y values for each landmark
landmark_count = len(landmarks[0])
x_values = np.zeros((len(landmarks), landmark_count))
@@ -117,40 +183,226 @@ def normalize_landmarks(landmarks: List[Dict], window_size: int = 5, poly_order:
return normalized_landmarks
def calculate_iou(box1, box2):
"""Calculate IoU (Intersection over Union) between two bounding boxes"""
# Extract coordinates
x1_1, y1_1, x2_1, y2_1 = box1
x1_2, y1_2, x2_2, y2_2 = box2
# Calculate intersection area
x_left = max(x1_1, x1_2)
y_top = max(y1_1, y1_2)
x_right = min(x2_1, x2_2)
y_bottom = min(y2_1, y2_2)
if x_right < x_left or y_bottom < y_top:
return 0.0
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# Calculate union area
box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
union_area = box1_area + box2_area - intersection_area
return intersection_area / union_area if union_area > 0 else 0
def calculate_keypoint_distance(landmarks1, landmarks2):
"""Calculate average distance between corresponding keypoints"""
if not landmarks1 or not landmarks2:
return float('inf')
# Create dictionary for fast lookup
kps1 = {lm['idx']: (lm['x'], lm['y']) for lm in landmarks1}
kps2 = {lm['idx']: (lm['x'], lm['y']) for lm in landmarks2}
# Find common keypoints
common_idx = set(kps1.keys()) & set(kps2.keys())
if not common_idx:
return float('inf')
# Calculate distance between corresponding keypoints
total_dist = 0
for idx in common_idx:
x1, y1 = kps1[idx]
x2, y2 = kps2[idx]
dist = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
total_dist += dist
return total_dist / len(common_idx)
def assign_person_ids(current_people, previous_people, iou_threshold=0.3, distance_threshold=0.2):
"""Assign stable IDs to people across frames based on IOU and keypoint distance"""
if not previous_people:
# First frame, assign new IDs to everyone
next_id = 0
for person in current_people:
person['person_id'] = next_id
next_id += 1
return current_people
# Create copy of current people to modify
assigned_people = []
unassigned_current = current_people.copy()
# Try to match current detections with previous ones
matched_prev_ids = set()
# Sort previous people by ID to maintain consistency in matching
sorted_prev = sorted(previous_people, key=lambda x: x['person_id'])
for prev_person in sorted_prev:
prev_id = prev_person['person_id']
prev_box = prev_person['bbox']
prev_landmarks = prev_person['landmarks']
best_match = None
best_score = float('inf') # Lower is better for distance
for curr_person in unassigned_current:
curr_box = curr_person['bbox']
curr_landmarks = curr_person['landmarks']
# Calculate IoU between bounding boxes
iou = calculate_iou(prev_box, curr_box)
# Calculate keypoint distance
kp_dist = calculate_keypoint_distance(prev_landmarks, curr_landmarks)
# Combined score (lower is better)
score = kp_dist * (1.5 - iou) # Favor high IoU and low distance
if (iou >= iou_threshold or kp_dist <= distance_threshold) and score < best_score:
best_match = curr_person
best_score = score
if best_match:
# Assign the previous ID to this person
best_match['person_id'] = prev_id
matched_prev_ids.add(prev_id)
assigned_people.append(best_match)
unassigned_current.remove(best_match)
# Find the next available ID
next_id = 0
existing_ids = {p['person_id'] for p in previous_people}
while next_id in existing_ids:
next_id += 1
# Assign new IDs to unmatched current detections
for person in unassigned_current:
person['person_id'] = next_id
assigned_people.append(person)
next_id += 1
return assigned_people
def compress_pose_data(all_frame_data, frame_sampling=1, precision=3):
"""Compress pose data to reduce JSON file size by reducing precision and sampling frames"""
compressed_data = []
# Process only every nth frame based on sampling rate
for i, frame_data in enumerate(all_frame_data):
if i % frame_sampling != 0:
continue
# Compress frame data
compressed_frame = {
'f': frame_data['frame'], # Short key name
't': round(frame_data['timestamp'], 2), # Reduce timestamp precision
'p': [] # Short key for people
}
# Process each person
for person in frame_data['people']:
# Only keep essential bbox info (we only need width/height for visualization)
x1, y1, x2, y2 = person['bbox']
width = x2 - x1
height = y2 - y1
compressed_person = {
'id': person['person_id'], # Keep ID as is
'b': [round(x1, 1), round(y1, 1), round(width, 1), round(height, 1)], # Simplified bbox with less precision
'k': [] # Short key for keypoints/landmarks
}
# Process each landmark with reduced precision
for lm in person['landmarks']:
compressed_person['k'].append([
lm['idx'], # Keep index as is (small integer)
round(lm['x'], precision), # Reduce coordinate precision
round(lm['y'], precision), # Reduce coordinate precision
round(lm['confidence'], 2) # Reduce confidence precision
])
compressed_frame['p'].append(compressed_person)
compressed_data.append(compressed_frame)
return compressed_data
def process_frame(frame: np.ndarray, model, detection_threshold: float = 0.5, show_preview: bool = False):
"""Process a single frame with YOLOv11-pose"""
"""Process a single frame with YOLOv11-pose, handling multiple people"""
# Process with YOLO
try:
results = model.predict(frame, verbose=False, conf=detection_threshold)
# Extract keypoints if available
landmarks_data = None
processed_frame = None
people_data = []
# Get frame dimensions
h, w = frame.shape[:2]
if results and len(results[0].keypoints.data) > 0:
# Get keypoints from the first detection
keypoints = results[0].keypoints.data[0] # [17, 3] - (x, y, confidence)
# Get all keypoints and bounding boxes
keypoints = results[0].keypoints.data # [num_people, 17, 3] - (x, y, confidence)
boxes = results[0].boxes.xyxy.cpu() # [num_people, 4] - (x1, y1, x2, y2)
# Extract keypoints to landmarks_data
landmarks_data = []
for idx, kp in enumerate(keypoints):
x, y, conf = kp.tolist()
if conf >= detection_threshold:
landmarks_data.append({
'idx': idx,
'x': x / w, # Normalize to 0-1 range
'y': y / h, # Normalize to 0-1 range
'confidence': conf
for i, (kps, box) in enumerate(zip(keypoints, boxes)):
# Extract keypoints to landmarks_data
landmarks_data = []
for idx, kp in enumerate(kps):
x, y, conf = kp.tolist()
if conf >= detection_threshold:
landmarks_data.append({
'idx': idx,
'x': round(x / w, 4), # Normalize to 0-1 range with 4 decimal precision
'y': round(y / h, 4), # Normalize to 0-1 range with 4 decimal precision
'confidence': round(conf, 2) # Reduce confidence to 2 decimal places
})
if landmarks_data: # Only add if we have valid landmarks
# Add bounding box and landmarks for this person
people_data.append({
'bbox': box.tolist(), # Store unnormalized for IoU calculation
'landmarks': landmarks_data # Store normalized for consistency
})
# Create visualization if preview is enabled
if show_preview:
processed_frame = results[0].plot()
return processed_frame, landmarks_data
# Add person IDs to the visualization if they're already assigned
for person in people_data:
if 'person_id' in person:
# Get center of bounding box
x1, y1, x2, y2 = person['bbox']
center_x = int((x1 + x2) / 2)
center_y = int(y1) # Top of the bbox
# Draw ID text
cv2.putText(
processed_frame,
f"ID: {person['person_id']}",
(center_x, center_y - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
(0, 255, 255),
2
)
return processed_frame, people_data
except RuntimeError as e:
# Check if this is an NMS backend error
@@ -170,7 +422,9 @@ def run_pose_detection(
model_size='n',
device='auto',
show_preview=True,
batch_size=1
batch_size=1,
frame_sampling=1, # New parameter to control frame sampling rate
precision=3 # New parameter to control coordinate precision
):
"""YOLOv11 pose detection with CUDA acceleration, properly handling NMS issues"""
start_time = time.time()
@@ -240,12 +494,13 @@ def run_pose_detection(
window_name = "YOLOv11 Pose"
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
# Initialize variables for batch processing
all_landmarks = []
# Initialize variables for tracking
all_frame_data = []
processed_frames = 0
frames_buffer = []
last_people_data = []
last_fps_update = time.time()
current_fps = 0
total_people_detected = 0
# Main processing loop
print("⏳ Processing frames...")
@@ -256,18 +511,23 @@ def run_pose_detection(
try:
# Process the frame
processed_frame, landmarks_data = process_frame(
processed_frame, people_data = process_frame(
frame, model, detection_threshold, show_preview
)
# Store landmark data with timestamp
if landmarks_data:
# Assign stable person IDs
if people_data:
people_data = assign_person_ids(people_data, last_people_data)
last_people_data = people_data.copy()
# Store frame data with people
frame_data = {
'frame': processed_frames,
'timestamp': processed_frames / fps if fps > 0 else time.time() - start_time,
'landmarks': landmarks_data
'people': people_data
}
all_landmarks.append(frame_data)
all_frame_data.append(frame_data)
total_people_detected += len(people_data)
except RuntimeError as e:
if str(e) == "CUDA NMS Error":
@@ -295,7 +555,7 @@ def run_pose_detection(
# Show CUDA status
cv2.putText(
processed_frame,
f"Device: {model.device} (Full GPU processing)",
f"Device: {model.device} | People: {len(people_data) if people_data else 0}",
(10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
)
@@ -320,50 +580,62 @@ def run_pose_detection(
print(f"⏱️ Processed {processed_frames} frames in {elapsed_time:.2f}s ({effective_fps:.2f} fps)")
if all_landmarks:
print(f"🧮 Detected poses in {len(all_landmarks)} frames ({(len(all_landmarks)/max(1, processed_frames))*100:.1f}%)")
if all_frame_data:
unique_people = set()
for frame in all_frame_data:
for person in frame['people']:
unique_people.add(person['person_id'])
print(f"🧮 Detected {len(all_frame_data)} frames with poses ({len(all_frame_data)/max(1, processed_frames)*100:.1f}%)")
print(f"👥 Detected {len(unique_people)} unique people with {total_people_detected} total detections")
else:
print(f"⚠️ No poses detected. Try adjusting detection threshold or check the video content.")
# Save results if output file is specified
if output_file and all_landmarks:
if output_file and all_frame_data:
output_dir = os.path.dirname(output_file)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Apply normalization if requested
if normalize and len(all_landmarks) > filter_window_size:
print(f"🔄 Normalizing data...")
landmarks_only = [frame_data['landmarks'] for frame_data in all_landmarks]
normalized_landmarks = normalize_landmarks(
landmarks_only,
if normalize and len(all_frame_data) > filter_window_size:
print(f"🔄 Normalizing data for each person...")
all_frame_data = normalize_landmarks_per_person(
all_frame_data,
window_size=filter_window_size,
poly_order=filter_poly_order
)
# Put normalized landmarks back
for i, frame_data in enumerate(all_landmarks):
if i < len(normalized_landmarks):
all_landmarks[i]['landmarks'] = normalized_landmarks[i]
# Compress data to reduce file size
print(f"🗜️ Compressing data (frame sampling: {frame_sampling}, precision: {precision})...")
compressed_frames = compress_pose_data(all_frame_data, frame_sampling, precision)
actual_frames_saved = len(compressed_frames)
# Create output in compatible format
# Calculate compression ratio
original_frame_count = len(all_frame_data)
compression_ratio = (original_frame_count - actual_frames_saved) / original_frame_count * 100
print(f"📊 Compression: {original_frame_count} frames reduced to {actual_frames_saved} ({compression_ratio:.1f}% reduction)")
# Create output in compatible format with compressed frames
json_data = {
'source': source_name,
'frame_width': frame_width,
'frame_height': frame_height,
'src': source_name, # Shortened key
'w': frame_width, # Shortened key
'h': frame_height, # Shortened key
'fps': fps,
'total_frames': processed_frames,
'keypoint_names': KEYPOINT_NAMES,
'connections': [{'start': c[0], 'end': c[1]} for c in POSE_CONNECTIONS],
'frames': all_landmarks,
'metadata': {
'frames': processed_frames,
'keypoints': KEYPOINT_NAMES, # More descriptive key
'connections': [{'s': c[0], 'e': c[1]} for c in POSE_CONNECTIONS], # Shortened keys
'data': compressed_frames, # Use compressed data
'meta': { # Shortened key
'model': f"YOLOv11-{model_size}-pose",
'device': str(model.device),
'normalized': normalize,
'detection_threshold': detection_threshold,
'filter_window_size': filter_window_size if normalize else None,
'filter_poly_order': filter_poly_order if normalize else None,
'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
'threshold': detection_threshold,
'filter_size': filter_window_size if normalize else None,
'filter_order': filter_poly_order if normalize else None,
'frame_sampling': frame_sampling,
'precision': precision,
'created': time.strftime('%Y-%m-%d %H:%M:%S')
}
}
@@ -371,7 +643,8 @@ def run_pose_detection(
with open(output_file, 'w') as f:
json.dump(json_data, f)
print(f"💾 Saved tracking data to {output_file}")
file_size_mb = os.path.getsize(output_file) / (1024 * 1024)
print(f"💾 Saved tracking data to {output_file} ({file_size_mb:.2f} MB)")
elif output_file:
print(f"⚠️ No pose data to save. Output file was not created.")
@@ -383,7 +656,7 @@ def run_pose_detection(
# Restore original NMS function
torchvision.ops.nms = original_nms
return all_landmarks
return all_frame_data
def main():
# Set up simple argument parser
@@ -411,8 +684,12 @@ def main():
help='Window size for smoothing filter (must be odd, larger = smoother)')
parser.add_argument('--filter-order', type=int, default=4,
help='Polynomial order for smoothing filter (1-4)')
parser.add_argument('--batch-size', type=int, default=1,
parser.add_argument('--batch-size', type=int, default=4,
help='Batch size for processing (higher uses more VRAM but can be faster)')
parser.add_argument('--frame-sampling', type=int, default=2,
help='Save only every Nth frame (1=all frames, 2=half, 4=quarter, etc.)')
parser.add_argument('--precision', type=int, default=3, choices=[2, 3, 4],
help='Decimal precision for coordinates (2-4, lower=smaller file)')
args = parser.parse_args()
@@ -430,6 +707,8 @@ def main():
print(f"• Device: {args.device}")
print(f"• Preview: {'Disabled' if args.no_preview else 'Enabled'}")
print(f"• Normalization: {'Disabled' if args.no_normalize else 'Enabled'}")
print(f"• Frame sampling: Every {args.frame_sampling} frame(s)")
print(f"• Coordinate precision: {args.precision} decimal places")
print("="*50 + "\n")
# Run pose detection
@@ -444,7 +723,9 @@ def main():
model_size=args.model,
device=args.device,
show_preview=not args.no_preview,
batch_size=args.batch_size
batch_size=args.batch_size,
frame_sampling=args.frame_sampling,
precision=args.precision
)
except KeyboardInterrupt:
print("\n⏹️ Process interrupted by user")

pose_viewer.py

@@ -0,0 +1,436 @@
import argparse
import json
import os
import sys
import time
import urllib.request
from pathlib import Path
import cv2
import numpy as np
import pygame
from pygame.locals import *
# Define colors
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
CYAN = (0, 255, 255)
MAGENTA = (255, 0, 255)
# Define keypoint colors (custom palette)
KEYPOINT_COLORS = [
(255, 0, 0), # nose (red)
(255, 85, 0), # left_eye (orange-red)
(255, 170, 0), # right_eye (orange)
(255, 255, 0), # left_ear (yellow)
(170, 255, 0), # right_ear (yellow-green)
(85, 255, 0), # left_shoulder (green-yellow)
(0, 255, 0), # right_shoulder (green)
(0, 255, 85), # left_elbow (green-cyan)
(0, 255, 170), # right_elbow (cyan-green)
(0, 255, 255), # left_wrist (cyan)
(0, 170, 255), # right_wrist (cyan-blue)
(0, 85, 255), # left_hip (blue-cyan)
(0, 0, 255), # right_hip (blue)
(85, 0, 255), # left_knee (blue-purple)
(170, 0, 255), # right_knee (purple-blue)
(255, 0, 255), # left_ankle (magenta)
(255, 0, 170) # right_ankle (magenta-pink)
]
# Person ID colors
PERSON_COLORS = [
(255, 0, 0), # red
(0, 255, 0), # green
(0, 0, 255), # blue
(255, 255, 0), # yellow
(255, 0, 255), # magenta
(0, 255, 255), # cyan
(255, 128, 0), # orange
(128, 0, 255), # purple
(0, 255, 128), # mint
(255, 255, 255) # white
]
def download_video(url, output_dir="downloaded_videos"):
"""Download a video from a URL and return the local file path"""
os.makedirs(output_dir, exist_ok=True)
video_name = os.path.basename(url).split("?")[0]
if not video_name or "." not in video_name:
video_name = f"video_{int(time.time())}.mp4"
output_path = os.path.join(output_dir, video_name)
if os.path.exists(output_path):
print(f"✅ Video already downloaded: {output_path}")
return output_path
print(f"⬇️ Downloading video from {url} to {output_path}...")
urllib.request.urlretrieve(url, output_path)
print(f"✅ Video downloaded successfully to {output_path}")
return output_path
def load_pose_data(json_file):
"""Load pose data from a JSON file"""
print(f"📂 Loading pose data from {json_file}...")
with open(json_file, 'r') as f:
data = json.load(f)
# Extract metadata
width = data.get('w', 1280)
height = data.get('h', 720)
fps = data.get('fps', 30)
total_frames = data.get('frames', 0)
# Get frame sampling and precision from metadata if available
metadata = data.get('meta', {})
frame_sampling = metadata.get('frame_sampling', 1)
precision = metadata.get('precision', 3)
# Extract connections
connections = []
for conn in data.get('connections', []):
start = conn.get('s', 0)
end = conn.get('e', 0)
connections.append((start, end))
# Extract keypoint names
keypoint_names = data.get('keypoints', [])
# Extract frame data
frames = data.get('data', [])
print(f"✅ Loaded {len(frames)} frames of pose data")
print(f"📊 Video: {width}x{height}@{fps}fps, {total_frames} total frames")
print(f"🔍 Frame sampling: {frame_sampling}, Precision: {precision}")
return {
'width': width,
'height': height,
'fps': fps,
'total_frames': total_frames,
'frame_sampling': frame_sampling,
'precision': precision,
'connections': connections,
'keypoint_names': keypoint_names,
'frames': frames
}
def create_pygame_window(width, height, title="Pose Viewer"):
"""Create a PyGame window"""
pygame.init()
window = pygame.display.set_mode((width, height))
pygame.display.set_caption(title)
return window
def draw_pose(frame, pose_data, frame_idx, original_width, original_height, prev_frame_idx=None):
"""Draw pose data on a given frame"""
# Clone the frame to avoid modifying the original
pose_frame = frame.copy()
# Find the closest pose frame to the current video frame (should be first and only in temp_pose_data)
if pose_data['frames']:
closest_frame = pose_data['frames'][0]
connections = pose_data['connections']
# Draw each person
for person_idx, person in enumerate(closest_frame['p']):
person_id = person['id']
person_color = PERSON_COLORS[person_id % len(PERSON_COLORS)]
# Get keypoints
keypoints = person['k']
# Create a dictionary to store keypoints by index
kp_dict = {}
for kp in keypoints:
kp_dict[kp[0]] = (
int(kp[1] * original_width),
int(kp[2] * original_height),
kp[3]
)
# Draw connections
for conn in connections:
if conn[0] in kp_dict and conn[1] in kp_dict:
start_point = kp_dict[conn[0]][:2]
end_point = kp_dict[conn[1]][:2]
# Use average confidence to determine line thickness
avg_conf = (kp_dict[conn[0]][2] + kp_dict[conn[1]][2]) / 2
thickness = int(avg_conf * 3) + 1
cv2.line(pose_frame, start_point, end_point, person_color, thickness)
# Draw keypoints
for kp_idx, (x, y, conf) in kp_dict.items():
# Circle size based on confidence
radius = int(conf * 5) + 2
cv2.circle(pose_frame, (x, y), radius, KEYPOINT_COLORS[kp_idx % len(KEYPOINT_COLORS)], -1)
# Draw person ID
bbox = person['b']
x, y = int(bbox[0]), int(bbox[1])
cv2.putText(
pose_frame,
f"ID: {person_id}",
(x, y - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
person_color,
2
)
return pose_frame
def draw_ui_controls(surface, width, height, playing, current_frame, total_frames):
"""Draw UI controls on the PyGame surface"""
# Background for controls
control_height = 50
control_surface = pygame.Surface((width, control_height))
control_surface.fill(BLACK)
# Draw play/pause button
button_width = 80
button_height = 30
button_x = 20
button_y = (control_height - button_height) // 2
pygame.draw.rect(control_surface, BLUE, (button_x, button_y, button_width, button_height))
font = pygame.font.SysFont(None, 24)
text = font.render("Pause" if playing else "Play", True, WHITE)
text_rect = text.get_rect(center=(button_x + button_width//2, button_y + button_height//2))
control_surface.blit(text, text_rect)
# Draw stop button
stop_button_x = button_x + button_width + 20
pygame.draw.rect(control_surface, RED, (stop_button_x, button_y, button_width, button_height))
stop_text = font.render("Stop", True, WHITE)
stop_text_rect = stop_text.get_rect(center=(stop_button_x + button_width//2, button_y + button_height//2))
control_surface.blit(stop_text, stop_text_rect)
# Draw seek bar
seekbar_x = stop_button_x + button_width + 40
seekbar_y = button_y + button_height // 2
seekbar_width = width - seekbar_x - 40
seekbar_height = 10
# Background bar
pygame.draw.rect(control_surface, (100, 100, 100),
(seekbar_x, seekbar_y - seekbar_height//2, seekbar_width, seekbar_height))
# Progress bar
progress = current_frame / total_frames if total_frames > 0 else 0
progress_width = int(seekbar_width * progress)
pygame.draw.rect(control_surface, GREEN,
(seekbar_x, seekbar_y - seekbar_height//2, progress_width, seekbar_height))
# Display current time / total time
time_text = font.render(f"Frame: {current_frame} / {total_frames}", True, WHITE)
time_rect = time_text.get_rect(center=(seekbar_x + seekbar_width//2, seekbar_y - 20))
control_surface.blit(time_text, time_rect)
# Blit the control surface to the main surface
surface.blit(control_surface, (0, height - control_height))
# Return button regions for click handling
play_button_rect = pygame.Rect(button_x, height - control_height + button_y, button_width, button_height)
stop_button_rect = pygame.Rect(stop_button_x, height - control_height + button_y, button_width, button_height)
seekbar_rect = pygame.Rect(seekbar_x, height - control_height + seekbar_y - seekbar_height//2,
seekbar_width, seekbar_height)
return play_button_rect, stop_button_rect, seekbar_rect
def run_viewer(video_path, json_path):
"""Main function to run the pose viewer"""
# Load pose data
pose_data = load_pose_data(json_path)
# Open video
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"❌ Error: Could not open video {video_path}")
return
# Get video properties
video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Create window - now only showing single visualization
window_width = video_width
window_height = video_height + 50 # Additional space for controls
window = create_pygame_window(window_width, window_height, f"Pose Viewer - {os.path.basename(video_path)}")
# Setup clock
clock = pygame.time.Clock()
# State variables
playing = False
current_frame = 0
prev_frame_idx = None
last_valid_pose_frame = None
# Create a frame lookup for fast access - maps video frame number to pose data frame
frame_lookup = {}
max_pose_frame = 0
for pose_frame in pose_data['frames']:
frame_num = pose_frame['f']
frame_lookup[frame_num] = pose_frame
max_pose_frame = max(max_pose_frame, frame_num)
# Initial render
ret, frame = cap.read()
if not ret:
print("❌ Error: Could not read the first frame")
return
# Main loop
while True:
# Handle events
for event in pygame.event.get():
if event.type == QUIT:
pygame.quit()
sys.exit()
elif event.type == KEYDOWN:
if event.key == K_ESCAPE:
pygame.quit()
sys.exit()
elif event.key == K_SPACE:
playing = not playing
elif event.type == MOUSEBUTTONDOWN:
# Check if any buttons were clicked
mouse_pos = pygame.mouse.get_pos()
play_button_rect, stop_button_rect, seekbar_rect = draw_ui_controls(
window, window_width, window_height, playing, current_frame, total_frames
)
if play_button_rect.collidepoint(mouse_pos):
playing = not playing
elif stop_button_rect.collidepoint(mouse_pos):
playing = False
current_frame = 0
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
ret, frame = cap.read()
prev_frame_idx = None
last_valid_pose_frame = None
elif seekbar_rect.collidepoint(mouse_pos):
# Calculate position ratio
x_offset = mouse_pos[0] - seekbar_rect.x
ratio = x_offset / seekbar_rect.width
# Set frame position
current_frame = int(ratio * total_frames)
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
ret, frame = cap.read()
prev_frame_idx = None # Reset previous frame index after seeking
# Handle playback
if playing:
ret, frame = cap.read()
if not ret:
# End of video, loop back to start
playing = False
current_frame = 0
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
ret, frame = cap.read()
prev_frame_idx = None
last_valid_pose_frame = None
if not ret:
break
current_frame += 1
# Find the appropriate pose frame for the current video frame
frame_sampling = pose_data['frame_sampling']
# Try to find the exact frame in lookup
pose_frame = frame_lookup.get(current_frame)
# If not found, find the closest previous frame based on sampling
if not pose_frame:
# Calculate what the nearest pose frame should be
# This searches for the most recent pose frame
nearest_frame = current_frame
while nearest_frame > 0 and nearest_frame not in frame_lookup:
nearest_frame -= 1
if nearest_frame in frame_lookup:
pose_frame = frame_lookup[nearest_frame]
# Update the last valid pose frame if we found one
if pose_frame:
last_valid_pose_frame = pose_frame
# Draw pose on frame - use the most recent valid pose frame
if last_valid_pose_frame:
# Create a special frame dict with only the current pose for draw_pose
temp_pose_data = pose_data.copy()
temp_pose_data['frames'] = [last_valid_pose_frame]
pose_frame = draw_pose(frame, temp_pose_data, current_frame, video_width, video_height)
else:
# If no pose data found yet, just show the original frame
pose_frame = frame.copy()
prev_frame_idx = current_frame
# Convert frame from BGR to RGB for PyGame
rgb_pose_frame = cv2.cvtColor(pose_frame, cv2.COLOR_BGR2RGB)
pygame_pose_frame = pygame.surfarray.make_surface(rgb_pose_frame.swapaxes(0, 1))
# Draw frame
window.blit(pygame_pose_frame, (0, 0))
# Draw UI controls
play_button_rect, stop_button_rect, seekbar_rect = draw_ui_controls(
window, window_width, window_height, playing, current_frame, total_frames
)
# Draw metadata
font = pygame.font.SysFont(None, 20)
metadata_text = f"Frame Sampling: {pose_data['frame_sampling']}, Precision: {pose_data['precision']}"
metadata_surface = font.render(metadata_text, True, WHITE)
window.blit(metadata_surface, (10, 10))
# Update display
pygame.display.flip()
# Cap framerate
clock.tick(fps)
# Clean up
cap.release()
pygame.quit()
def main():
parser = argparse.ArgumentParser(description='Pose Viewer for JSON pose data with video')
parser.add_argument('--video', '-v', required=True, help='Video file path or URL')
parser.add_argument('--json', '-j', required=True, help='JSON pose data file path')
args = parser.parse_args()
# Handle URL input for video
video_path = args.video
if video_path.startswith('http://') or video_path.startswith('https://'):
video_path = download_video(video_path)
if not os.path.exists(video_path):
print(f"❌ Error: Video file not found: {video_path}")
return
if not os.path.exists(args.json):
print(f"❌ Error: JSON file not found: {args.json}")
return
# Run the viewer
run_viewer(video_path, args.json)
if __name__ == "__main__":
main()


@@ -6,4 +6,5 @@ flask-cors>=3.0.10
numpy>=1.19.0
scipy>=1.7.0
pillow>=9.0.0
mediapipe>=0.8.9
mediapipe>=0.8.9
pygame>=2.0.0