# JDClone/pose_detector_api.py
#
# Pose-detection API: captures webcam frames, runs MediaPipe Pose on them,
# and serves raw/annotated MJPEG streams plus live landmark JSON over
# Flask + Socket.IO.
import base64
import json
import threading
import time
from io import BytesIO
import cv2
import mediapipe as mp
import numpy as np
from flask import Flask, Response, render_template
from flask_cors import CORS
from flask_socketio import SocketIO
# Initialize MediaPipe Pose
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

# Define colors (OpenCV uses BGR channel order, so BLUE is (255, 0, 0))
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
YELLOW = (0, 255, 255)

# Initialize Flask app and SocketIO (CORS wide open so any front-end can connect)
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

# Global variables to store the latest frames and data.
# Written by the capture thread, read by the Flask routes.
latest_raw_frame = None        # last mirrored webcam frame (BGR ndarray)
latest_annotated_frame = None  # last frame with the pose overlay drawn
latest_landmarks_data = None   # last landmarks payload dict
processing_active = True       # capture loop keeps running while True
def process_landmarks(landmarks, width, height):
    """Convert MediaPipe landmarks into JSON-serializable dicts.

    x/y are scaled to pixel coordinates using *width*/*height*; z and
    visibility are passed through unchanged. Returns [] when *landmarks*
    is falsy.
    """
    if not landmarks:
        return []
    return [
        {
            'idx': i,
            'x': lm.x * width,
            'y': lm.y * height,
            'z': lm.z,
            'visibility': lm.visibility,
        }
        for i, lm in enumerate(landmarks)
    ]
def get_body_connections():
    """Return the MediaPipe pose skeleton edges as {'start', 'end'} dicts."""
    return [
        {'start': start, 'end': end}
        for start, end in mp_pose.POSE_CONNECTIONS
    ]
def pose_detection_thread():
    """Capture webcam frames, run MediaPipe Pose, and publish the results.

    Runs until the webcam fails or the module-level ``processing_active``
    flag goes False. Each iteration updates ``latest_raw_frame``,
    ``latest_annotated_frame`` and ``latest_landmarks_data`` and emits the
    landmark payload on the 'landmarks' Socket.IO event.
    """
    global latest_raw_frame, latest_annotated_frame, latest_landmarks_data

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    try:
        with mp_pose.Pose(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5) as pose:
            last_time = time.time()
            while cap.isOpened() and processing_active:
                success, frame = cap.read()
                if not success:
                    print("Error: Could not read frame")
                    break

                # Flip horizontally for a selfie-view display.
                frame = cv2.flip(frame, 1)
                # One copy is enough: nothing mutates it after publication
                # (the original made two identical copies here).
                latest_raw_frame = frame.copy()

                # Mark read-only while MediaPipe processes (performance hint).
                frame.flags.writeable = False
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(frame_rgb)
                frame.flags.writeable = True

                landmarks_data = {
                    'landmarks': [],
                    'connections': get_body_connections(),
                    'image_width': frame_width,
                    'image_height': frame_height,
                    'timestamp': time.time(),
                }

                if results.pose_landmarks:
                    landmarks_data['landmarks'] = process_landmarks(
                        results.pose_landmarks.landmark,
                        frame_width,
                        frame_height,
                    )
                    # Default MediaPipe skeleton...
                    mp_drawing.draw_landmarks(
                        frame,
                        results.pose_landmarks,
                        mp_pose.POSE_CONNECTIONS,
                        landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())
                    # ...plus our colored torso/arm/leg overlay.
                    _draw_custom_skeleton(frame, results.pose_landmarks.landmark)

                # FPS overlay computed from the wall-clock frame interval.
                current_time = time.time()
                time_diff = current_time - last_time
                fps = int(1 / time_diff) if time_diff > 0 else 0
                last_time = current_time
                cv2.putText(frame, f"FPS: {fps}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                # Publish the annotated frame and landmark payload.
                latest_annotated_frame = frame.copy()
                latest_landmarks_data = landmarks_data
                socketio.emit('landmarks', json.dumps(landmarks_data))

                # Yield briefly so the Flask/Socket.IO threads get CPU time.
                time.sleep(0.01)
    finally:
        # Release the webcam even if the loop raised.
        cap.release()


def _draw_custom_skeleton(frame, landmarks):
    """Overlay colored torso (RED), arm (BLUE) and leg (GREEN) segments,
    with a YELLOW shoulder line, on *frame* in place. Best effort: any
    drawing failure is swallowed so it never kills the capture loop.
    """
    h, w, _ = frame.shape

    def point(lm_id):
        # Landmark coordinates are normalized; scale to pixel ints.
        lm = landmarks[lm_id]
        return (int(lm.x * w), int(lm.y * h))

    try:
        # Shoulders
        left_shoulder = point(mp_pose.PoseLandmark.LEFT_SHOULDER)
        right_shoulder = point(mp_pose.PoseLandmark.RIGHT_SHOULDER)
        cv2.line(frame, left_shoulder, right_shoulder, YELLOW, 4)

        # Torso
        left_hip = point(mp_pose.PoseLandmark.LEFT_HIP)
        right_hip = point(mp_pose.PoseLandmark.RIGHT_HIP)
        cv2.line(frame, left_shoulder, left_hip, RED, 4)
        cv2.line(frame, right_shoulder, right_hip, RED, 4)
        cv2.line(frame, left_hip, right_hip, RED, 4)

        # Arms
        left_elbow = point(mp_pose.PoseLandmark.LEFT_ELBOW)
        right_elbow = point(mp_pose.PoseLandmark.RIGHT_ELBOW)
        left_wrist = point(mp_pose.PoseLandmark.LEFT_WRIST)
        right_wrist = point(mp_pose.PoseLandmark.RIGHT_WRIST)
        cv2.line(frame, left_shoulder, left_elbow, BLUE, 4)
        cv2.line(frame, left_elbow, left_wrist, BLUE, 4)
        cv2.line(frame, right_shoulder, right_elbow, BLUE, 4)
        cv2.line(frame, right_elbow, right_wrist, BLUE, 4)

        # Legs
        left_knee = point(mp_pose.PoseLandmark.LEFT_KNEE)
        right_knee = point(mp_pose.PoseLandmark.RIGHT_KNEE)
        left_ankle = point(mp_pose.PoseLandmark.LEFT_ANKLE)
        right_ankle = point(mp_pose.PoseLandmark.RIGHT_ANKLE)
        cv2.line(frame, left_hip, left_knee, GREEN, 4)
        cv2.line(frame, left_knee, left_ankle, GREEN, 4)
        cv2.line(frame, right_hip, right_knee, GREEN, 4)
        cv2.line(frame, right_knee, right_ankle, GREEN, 4)
    except Exception:
        # Deliberate best-effort drawing (the original used a bare except);
        # narrowed so KeyboardInterrupt/SystemExit still propagate.
        pass
def generate_frames(get_annotated=False):
    """Generator yielding multipart MJPEG parts for a video feed.

    Args:
        get_annotated: when True, stream the pose-annotated frame if one
            exists; otherwise (or until one exists) stream the raw frame.

    Yields:
        bytes: one ``--frame`` multipart chunk per encoded JPEG.
    """
    while True:
        if get_annotated and latest_annotated_frame is not None:
            frame = latest_annotated_frame.copy()
        elif latest_raw_frame is not None:
            frame = latest_raw_frame.copy()
        else:
            # No frames captured yet; wait instead of busy-looping.
            time.sleep(0.1)
            continue

        # Fix: check the encode flag (original discarded it) so a failed
        # encode skips the frame instead of emitting a garbage chunk.
        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            continue
        frame_bytes = buffer.tobytes()

        # Multipart format expected by multipart/x-mixed-replace responses.
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
# Flask routes
@app.route('/')
def index():
    """Serve a minimal self-test page: both video feeds plus live landmark JSON."""
    page = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Pose Detection API</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                max-width: 1000px;
                margin: 0 auto;
                padding: 20px;
            }
            .container {
                display: flex;
                flex-wrap: wrap;
                gap: 20px;
                margin-bottom: 20px;
            }
            .video-container {
                width: 480px;
            }
            h1, h2, h3 {
                color: #333;
            }
            pre {
                background-color: #f4f4f4;
                padding: 10px;
                border-radius: 5px;
                overflow: auto;
                max-height: 300px;
            }
        </style>
    </head>
    <body>
        <h1>Pose Detection API</h1>
        <div class="container">
            <div class="video-container">
                <h2>Raw Video Feed</h2>
                <img src="/video_feed" width="100%" />
            </div>
            <div class="video-container">
                <h2>Annotated Video Feed</h2>
                <img src="/video_feed/annotated" width="100%" />
            </div>
        </div>
        <h2>Landmark Data (Live Updates)</h2>
        <pre id="landmarks-data">Waiting for data...</pre>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
        <script>
            // Connect to the Socket.IO server
            const socket = io();
            // Listen for landmark data
            socket.on('landmarks', function(data) {
                const landmarksData = JSON.parse(data);
                document.getElementById('landmarks-data').textContent =
                    JSON.stringify(landmarksData, null, 2);
            });
        </script>
    </body>
    </html>
    """
    return page
@app.route('/video_feed')
def video_feed():
    """Stream the unannotated webcam feed as multipart MJPEG."""
    stream = generate_frames(get_annotated=False)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route('/video_feed/annotated')
def video_feed_annotated():
    """Stream the pose-annotated webcam feed as multipart MJPEG."""
    stream = generate_frames(get_annotated=True)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route('/landmarks')
def get_landmarks():
    """Return the most recent landmarks payload (or an error stub) as JSON."""
    if latest_landmarks_data:
        payload = latest_landmarks_data
    else:
        payload = {"error": "No landmarks data available yet"}
    return Response(json.dumps(payload), mimetype='application/json')
def main():
    """Start the background capture thread, then run the Flask/Socket.IO server."""
    # Daemon thread: dies with the main process instead of blocking exit.
    worker = threading.Thread(target=pose_detection_thread, daemon=True)
    worker.start()

    print("Starting API server at http://127.0.0.1:5000")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False,
                 allow_unsafe_werkzeug=True)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("Shutting down...")
processing_active = False