# JDClone/pose_detector_api.py
# Webcam pose-detection API: captures frames, runs MediaPipe Pose, and serves
# raw/annotated MJPEG streams plus live landmark data over Flask + SocketIO.
import base64
import json
import threading
import time
from io import BytesIO
import cv2
import mediapipe as mp
import numpy as np
from flask import Flask, Response, render_template
from flask_cors import CORS
from flask_socketio import SocketIO
# Initialize MediaPipe Pose helpers (drawing utilities + the Pose solution).
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose
# Drawing colors in OpenCV's BGR channel order (not RGB).
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
YELLOW = (0, 255, 255)
# Initialize Flask app and SocketIO; CORS is fully open for cross-origin clients.
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")
# Shared state: written by the pose-detection background thread, read by the
# Flask route handlers / frame generators.
# NOTE(review): access is unsynchronized; safe in practice because each write
# rebinds the name to a fresh object, but confirm if mutation is ever added.
latest_raw_frame = None
latest_annotated_frame = None
latest_landmarks_data = None
processing_active = True
def process_landmarks(landmarks, width, height):
    """Convert MediaPipe landmarks into a JSON-serializable list of dicts.

    Normalized x/y coordinates are scaled to pixel space using the given
    image dimensions; z and visibility are passed through unchanged.

    Args:
        landmarks: iterable of landmark objects with x/y/z/visibility, or falsy.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        List of dicts with keys 'idx', 'x', 'y', 'z', 'visibility';
        empty list when *landmarks* is falsy.
    """
    if not landmarks:
        return []
    return [
        {
            'idx': index,
            'x': point.x * width,
            'y': point.y * height,
            'z': point.z,
            'visibility': point.visibility,
        }
        for index, point in enumerate(landmarks)
    ]
def get_body_connections():
    """Return MediaPipe's POSE_CONNECTIONS as a list of {'start', 'end'} dicts.

    Each entry describes one edge of the pose skeleton by the landmark
    indices at its two ends.
    """
    return [
        {'start': start_idx, 'end': end_idx}
        for start_idx, end_idx in mp_pose.POSE_CONNECTIONS
    ]
def _landmark_px(landmarks, landmark_id, width, height):
    """Convert one normalized landmark to integer (x, y) pixel coordinates."""
    point = landmarks[landmark_id]
    return (int(point.x * width), int(point.y * height))


def _draw_custom_skeleton(frame, landmarks):
    """Overlay color-coded limb lines on *frame* (mutated in place).

    Shoulders are yellow, torso red, arms blue, legs green. Best-effort:
    any drawing failure is logged and skipped rather than crashing the
    capture loop (the original swallowed all errors with a bare except).
    """
    h, w, _ = frame.shape
    try:
        left_shoulder = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER, w, h)
        right_shoulder = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_SHOULDER, w, h)
        cv2.line(frame, left_shoulder, right_shoulder, YELLOW, 4)
        # Torso
        left_hip = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_HIP, w, h)
        right_hip = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_HIP, w, h)
        cv2.line(frame, left_shoulder, left_hip, RED, 4)
        cv2.line(frame, right_shoulder, right_hip, RED, 4)
        cv2.line(frame, left_hip, right_hip, RED, 4)
        # Arms
        left_elbow = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_ELBOW, w, h)
        right_elbow = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_ELBOW, w, h)
        left_wrist = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_WRIST, w, h)
        right_wrist = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_WRIST, w, h)
        cv2.line(frame, left_shoulder, left_elbow, BLUE, 4)
        cv2.line(frame, left_elbow, left_wrist, BLUE, 4)
        cv2.line(frame, right_shoulder, right_elbow, BLUE, 4)
        cv2.line(frame, right_elbow, right_wrist, BLUE, 4)
        # Legs
        left_knee = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_KNEE, w, h)
        right_knee = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_KNEE, w, h)
        left_ankle = _landmark_px(landmarks, mp_pose.PoseLandmark.LEFT_ANKLE, w, h)
        right_ankle = _landmark_px(landmarks, mp_pose.PoseLandmark.RIGHT_ANKLE, w, h)
        cv2.line(frame, left_hip, left_knee, GREEN, 4)
        cv2.line(frame, left_knee, left_ankle, GREEN, 4)
        cv2.line(frame, right_hip, right_knee, GREEN, 4)
        cv2.line(frame, right_knee, right_ankle, GREEN, 4)
    except Exception as exc:  # narrowed from bare except; keep loop alive
        print(f"Warning: custom skeleton drawing failed: {exc}")


def pose_detection_thread():
    """Capture webcam frames, run MediaPipe Pose, and publish the results.

    Runs until *processing_active* is cleared or the camera stops producing
    frames. Each iteration updates the module-level latest_raw_frame,
    latest_annotated_frame, and latest_landmarks_data, and emits the frame's
    landmark payload over SocketIO on the 'landmarks' event.
    """
    global latest_raw_frame, latest_annotated_frame, latest_landmarks_data

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # POSE_CONNECTIONS never changes, so compute it once instead of per frame.
    body_connections = get_body_connections()
    try:
        with mp_pose.Pose(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5) as pose:
            last_time = time.time()
            while cap.isOpened() and processing_active:
                success, frame = cap.read()
                if not success:
                    print("Error: Could not read frame")
                    break
                # Mirror horizontally for a selfie-view display.
                frame = cv2.flip(frame, 1)
                # One copy is enough: the global gets its own buffer and the
                # working frame is annotated separately below.
                latest_raw_frame = frame.copy()
                # Mark read-only during inference (MediaPipe performance hint).
                frame.flags.writeable = False
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(frame_rgb)
                frame.flags.writeable = True
                landmarks_data = {
                    'landmarks': [],
                    'connections': body_connections,
                    'image_width': frame_width,
                    'image_height': frame_height,
                    'timestamp': time.time(),
                }
                if results.pose_landmarks:
                    landmarks_data['landmarks'] = process_landmarks(
                        results.pose_landmarks.landmark,
                        frame_width,
                        frame_height,
                    )
                    # Standard MediaPipe landmark rendering...
                    mp_drawing.draw_landmarks(
                        frame,
                        results.pose_landmarks,
                        mp_pose.POSE_CONNECTIONS,
                        landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())
                    # ...plus our custom color-coded limb overlay.
                    _draw_custom_skeleton(frame, results.pose_landmarks.landmark)
                # FPS counter based on wall-clock time between iterations.
                current_time = time.time()
                time_diff = current_time - last_time
                fps = int(1 / time_diff) if time_diff > 0 else 0
                last_time = current_time
                cv2.putText(frame, f"FPS: {fps}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                # Publish the annotated frame and landmark data.
                latest_annotated_frame = frame.copy()
                latest_landmarks_data = landmarks_data
                socketio.emit('landmarks', json.dumps(landmarks_data))
                # Brief sleep to avoid pegging a CPU core.
                time.sleep(0.01)
    finally:
        # Guarantee the camera is released even if the loop raises.
        cap.release()
def generate_frames(get_annotated=False):
    """Yield JPEG frames formatted for multipart/x-mixed-replace streaming.

    Args:
        get_annotated: when True, prefer the pose-annotated frame; otherwise
            stream the raw webcam frame.

    Yields:
        bytes: one MJPEG part ('--frame' boundary + JPEG payload) per frame.
    """
    while True:
        candidate = latest_annotated_frame if get_annotated else None
        if candidate is None:
            candidate = latest_raw_frame
        if candidate is None:
            # The capture thread hasn't produced anything yet; wait and retry.
            time.sleep(0.1)
            continue
        snapshot = candidate.copy()
        _, encoded = cv2.imencode('.jpg', snapshot)
        payload = encoded.tobytes()
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + payload + b'\r\n')
# Flask routes
@app.route('/')
def index():
    """Serve a self-contained HTML test page.

    The page embeds both video feeds, connects to the SocketIO server via the
    CDN-hosted client, and pretty-prints each 'landmarks' payload live.
    """
    # Returned verbatim; no template rendering is involved.
    return """
<!DOCTYPE html>
<html>
<head>
<title>Pose Detection API</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
}
.container {
display: flex;
flex-wrap: wrap;
gap: 20px;
margin-bottom: 20px;
}
.video-container {
width: 480px;
}
h1, h2, h3 {
color: #333;
}
pre {
background-color: #f4f4f4;
padding: 10px;
border-radius: 5px;
overflow: auto;
max-height: 300px;
}
</style>
</head>
<body>
<h1>Pose Detection API</h1>
<div class="container">
<div class="video-container">
<h2>Raw Video Feed</h2>
<img src="/video_feed" width="100%" />
</div>
<div class="video-container">
<h2>Annotated Video Feed</h2>
<img src="/video_feed/annotated" width="100%" />
</div>
</div>
<h2>Landmark Data (Live Updates)</h2>
<pre id="landmarks-data">Waiting for data...</pre>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
<script>
// Connect to the Socket.IO server
const socket = io();
// Listen for landmark data
socket.on('landmarks', function(data) {
const landmarksData = JSON.parse(data);
document.getElementById('landmarks-data').textContent =
JSON.stringify(landmarksData, null, 2);
});
</script>
</body>
</html>
"""
@app.route('/video_feed')
def video_feed():
    """Stream the unannotated webcam feed as MJPEG."""
    stream = generate_frames(get_annotated=False)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route('/video_feed/annotated')
def video_feed_annotated():
    """Stream the pose-annotated webcam feed as MJPEG."""
    stream = generate_frames(get_annotated=True)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route('/landmarks')
def get_landmarks():
    """Return the most recent landmarks payload as JSON.

    Responds 404 with an error object until the detection thread has
    produced at least one frame of data.
    """
    # Explicit None check: a valid-but-empty payload must not be mistaken
    # for "no data yet" (a plain truthiness test would conflate the two).
    if latest_landmarks_data is not None:
        return Response(json.dumps(latest_landmarks_data),
                        mimetype='application/json')
    return Response(json.dumps({"error": "No landmarks data available yet"}),
                    status=404,
                    mimetype='application/json')
def main():
    """Launch the background pose-detection worker, then serve the API.

    The worker is a daemon thread so it dies with the process; the SocketIO
    server blocks until interrupted.
    """
    worker = threading.Thread(target=pose_detection_thread, daemon=True)
    worker.start()
    print("Starting API server at http://127.0.0.1:5000")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False, allow_unsafe_werkzeug=True)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("Shutting down...")
        # Signal the capture loop to stop so the webcam is released.
        # NOTE(review): this runs only after socketio.run() unwinds, and the
        # worker is a daemon thread that may be killed before it sees the
        # flag — confirm the camera release path on interrupt.
        processing_active = False