340 lines
12 KiB
Python
340 lines
12 KiB
Python
import base64
|
|
import json
|
|
import threading
|
|
import time
|
|
from io import BytesIO
|
|
|
|
import cv2
|
|
import mediapipe as mp
|
|
import numpy as np
|
|
from flask import Flask, Response, render_template
|
|
from flask_cors import CORS
|
|
from flask_socketio import SocketIO
|
|
|
|
# Initialize MediaPipe Pose helpers: drawing utilities, default landmark
# styles, and the pose solution module used throughout this file.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

# Define colors.
# NOTE: tuples are in OpenCV's BGR channel order, not RGB.
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
YELLOW = (0, 255, 255)

# Initialize Flask app and SocketIO.
# CORS is open and SocketIO accepts any origin so browser clients on
# other hosts can consume the streams and events.
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

# Global variables to store the latest frames and data.
# Written by the capture thread (pose_detection_thread) and read by the
# HTTP routes / frame generator:
#   latest_raw_frame       - most recent unannotated BGR frame, or None
#   latest_annotated_frame - most recent frame with pose overlay, or None
#   latest_landmarks_data  - most recent landmarks dict, or None
#   processing_active      - loop-control flag; set False to stop capture
latest_raw_frame = None
latest_annotated_frame = None
latest_landmarks_data = None
processing_active = True
|
|
|
|
def process_landmarks(landmarks, width, height):
    """Convert MediaPipe landmarks into a JSON-serializable list of dicts.

    Normalized x/y coordinates are scaled to pixel space using the given
    frame *width*/*height*; z and visibility are passed through unchanged.
    Returns an empty list when *landmarks* is falsy.
    """
    if not landmarks:
        return []
    return [
        {
            'idx': index,
            'x': point.x * width,
            'y': point.y * height,
            'z': point.z,
            'visibility': point.visibility,
        }
        for index, point in enumerate(landmarks)
    ]
|
|
|
|
def get_body_connections():
    """Return MediaPipe's pose skeleton edges as start/end index pairs.

    Each element is a dict {'start': i, 'end': j} mirroring one entry of
    mp_pose.POSE_CONNECTIONS.
    """
    return [
        {'start': start_idx, 'end': end_idx}
        for start_idx, end_idx in mp_pose.POSE_CONNECTIONS
    ]
|
|
|
|
def _landmark_px(landmarks, landmark_id, width, height):
    """Return one landmark's (x, y) in pixels, scaled from normalized coords."""
    point = landmarks[landmark_id]
    return int(point.x * width), int(point.y * height)


def _draw_custom_skeleton(frame, landmarks):
    """Draw the colored shoulder/torso/arm/leg line overlays on *frame* in place.

    Skips drawing entirely if any required landmark is unavailable.
    """
    h, w = frame.shape[:2]
    lm = mp_pose.PoseLandmark
    try:
        left_shoulder = _landmark_px(landmarks, lm.LEFT_SHOULDER, w, h)
        right_shoulder = _landmark_px(landmarks, lm.RIGHT_SHOULDER, w, h)
        left_hip = _landmark_px(landmarks, lm.LEFT_HIP, w, h)
        right_hip = _landmark_px(landmarks, lm.RIGHT_HIP, w, h)
        left_elbow = _landmark_px(landmarks, lm.LEFT_ELBOW, w, h)
        right_elbow = _landmark_px(landmarks, lm.RIGHT_ELBOW, w, h)
        left_wrist = _landmark_px(landmarks, lm.LEFT_WRIST, w, h)
        right_wrist = _landmark_px(landmarks, lm.RIGHT_WRIST, w, h)
        left_knee = _landmark_px(landmarks, lm.LEFT_KNEE, w, h)
        right_knee = _landmark_px(landmarks, lm.RIGHT_KNEE, w, h)
        left_ankle = _landmark_px(landmarks, lm.LEFT_ANKLE, w, h)
        right_ankle = _landmark_px(landmarks, lm.RIGHT_ANKLE, w, h)
    except (IndexError, AttributeError):
        # Was a bare `except: pass`, which also swallowed KeyboardInterrupt
        # and real bugs; narrowed to the lookup failures it was guarding.
        return

    # Shoulder line.
    cv2.line(frame, left_shoulder, right_shoulder, YELLOW, 4)
    # Torso.
    cv2.line(frame, left_shoulder, left_hip, RED, 4)
    cv2.line(frame, right_shoulder, right_hip, RED, 4)
    cv2.line(frame, left_hip, right_hip, RED, 4)
    # Arms.
    cv2.line(frame, left_shoulder, left_elbow, BLUE, 4)
    cv2.line(frame, left_elbow, left_wrist, BLUE, 4)
    cv2.line(frame, right_shoulder, right_elbow, BLUE, 4)
    cv2.line(frame, right_elbow, right_wrist, BLUE, 4)
    # Legs.
    cv2.line(frame, left_hip, left_knee, GREEN, 4)
    cv2.line(frame, left_knee, left_ankle, GREEN, 4)
    cv2.line(frame, right_hip, right_knee, GREEN, 4)
    cv2.line(frame, right_knee, right_ankle, GREEN, 4)


def pose_detection_thread():
    """Capture webcam frames, run MediaPipe Pose, and publish the results.

    Continuously updates the module-level latest_raw_frame,
    latest_annotated_frame, and latest_landmarks_data, and emits each
    frame's landmark payload over SocketIO (event name 'landmarks').
    Runs until the webcam fails or processing_active is set False.
    """
    global latest_raw_frame, latest_annotated_frame, latest_landmarks_data

    # Initialize webcam (device index 0).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return

    # Get webcam properties once; reported to clients with every payload.
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    try:
        with mp_pose.Pose(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5) as pose:

            # Timestamp of the previous frame, for the FPS counter.
            last_time = time.time()

            while cap.isOpened() and processing_active:
                success, frame = cap.read()
                if not success:
                    print("Error: Could not read frame")
                    break

                # Flip the image horizontally for a selfie-view display.
                frame = cv2.flip(frame, 1)

                # Publish the unannotated frame. (The original made a
                # redundant intermediate copy before copying again.)
                latest_raw_frame = frame.copy()

                # MediaPipe wants RGB; marking the buffer read-only lets
                # it avoid an internal copy.
                frame.flags.writeable = False
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(frame_rgb)
                frame.flags.writeable = True

                # Payload sent to clients for this frame.
                landmarks_data = {
                    'landmarks': [],
                    'connections': get_body_connections(),
                    'image_width': frame_width,
                    'image_height': frame_height,
                    'timestamp': time.time(),
                }

                if results.pose_landmarks:
                    landmarks_data['landmarks'] = process_landmarks(
                        results.pose_landmarks.landmark,
                        frame_width,
                        frame_height,
                    )

                    # Standard MediaPipe landmark/connection rendering.
                    mp_drawing.draw_landmarks(
                        frame,
                        results.pose_landmarks,
                        mp_pose.POSE_CONNECTIONS,
                        landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())

                    # Extra colored skeleton overlay (shoulders/torso/arms/legs).
                    _draw_custom_skeleton(frame, results.pose_landmarks.landmark)

                # Add FPS counter based on wall-clock delta between frames.
                current_time = time.time()
                time_diff = current_time - last_time
                # Avoid division by zero on extremely fast consecutive reads.
                fps = int(1 / time_diff) if time_diff > 0 else 0
                last_time = current_time
                cv2.putText(frame, f"FPS: {fps}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                # Update global variables with the latest data.
                latest_annotated_frame = frame.copy()
                latest_landmarks_data = landmarks_data

                # Emit landmark data through SocketIO.
                socketio.emit('landmarks', json.dumps(landmarks_data))

                # Sleep briefly to avoid pegging a CPU core.
                time.sleep(0.01)
    finally:
        # Always release the webcam, even if the loop raises
        # (the original leaked the capture device on an exception).
        cap.release()
|
|
|
|
def generate_frames(get_annotated=False):
    """Yield the latest frame as multipart JPEG chunks for MJPEG streaming.

    Streams annotated frames when *get_annotated* is True, raw frames
    otherwise; waits while no frame has been produced yet.
    """
    while True:
        # Snapshot the shared frame so later writes don't mutate our copy.
        if get_annotated and latest_annotated_frame is not None:
            snapshot = latest_annotated_frame.copy()
        elif latest_raw_frame is not None:
            snapshot = latest_raw_frame.copy()
        else:
            # Nothing captured yet; back off briefly and retry.
            time.sleep(0.1)
            continue

        # Encode frame as JPEG and emit it in multipart/x-mixed-replace
        # format (boundary "frame"), as expected by the Flask Response.
        _, jpeg = cv2.imencode('.jpg', snapshot)
        yield (b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
               + jpeg.tobytes() + b'\r\n')
|
|
|
|
# Flask routes
|
|
@app.route('/')
def index():
    """Serve a self-contained HTML test page.

    The page embeds both MJPEG streams (/video_feed and
    /video_feed/annotated) and subscribes to the SocketIO 'landmarks'
    event to display the live landmark JSON.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Pose Detection API</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                max-width: 1000px;
                margin: 0 auto;
                padding: 20px;
            }
            .container {
                display: flex;
                flex-wrap: wrap;
                gap: 20px;
                margin-bottom: 20px;
            }
            .video-container {
                width: 480px;
            }
            h1, h2, h3 {
                color: #333;
            }
            pre {
                background-color: #f4f4f4;
                padding: 10px;
                border-radius: 5px;
                overflow: auto;
                max-height: 300px;
            }
        </style>
    </head>
    <body>
        <h1>Pose Detection API</h1>

        <div class="container">
            <div class="video-container">
                <h2>Raw Video Feed</h2>
                <img src="/video_feed" width="100%" />
            </div>

            <div class="video-container">
                <h2>Annotated Video Feed</h2>
                <img src="/video_feed/annotated" width="100%" />
            </div>
        </div>

        <h2>Landmark Data (Live Updates)</h2>
        <pre id="landmarks-data">Waiting for data...</pre>

        <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
        <script>
            // Connect to the Socket.IO server
            const socket = io();

            // Listen for landmark data
            socket.on('landmarks', function(data) {
                const landmarksData = JSON.parse(data);
                document.getElementById('landmarks-data').textContent =
                    JSON.stringify(landmarksData, null, 2);
            });
        </script>
    </body>
    </html>
    """
|
|
|
|
@app.route('/video_feed')
def video_feed():
    """Stream the raw (unannotated) webcam feed as multipart JPEG."""
    stream = generate_frames(get_annotated=False)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
|
|
|
|
@app.route('/video_feed/annotated')
def video_feed_annotated():
    """Stream the pose-annotated webcam feed as multipart JPEG."""
    stream = generate_frames(get_annotated=True)
    return Response(stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
|
|
|
|
@app.route('/landmarks')
def get_landmarks():
    """Return the most recent landmarks payload as JSON.

    Responds with an error object (still HTTP 200, preserving the
    existing client contract) while no frame has been processed yet.
    """
    # Explicit None check: an empty-but-valid payload must not be
    # mistaken for "no data yet" (the original relied on truthiness).
    if latest_landmarks_data is not None:
        body = json.dumps(latest_landmarks_data)
    else:
        body = json.dumps({"error": "No landmarks data available yet"})
    return Response(body, mimetype='application/json')
|
|
|
|
def main():
    """Start the background capture thread, then run the SocketIO server.

    Blocks until the server exits; the capture thread is a daemon so it
    does not keep the interpreter alive on shutdown.
    """
    # Start the pose detection thread. Idiomatic daemon=True at
    # construction instead of assigning the attribute afterwards.
    detection_thread = threading.Thread(target=pose_detection_thread,
                                        daemon=True)
    detection_thread.start()

    # Start the Flask app with SocketIO, bound to all interfaces.
    print("Starting API server at http://127.0.0.1:5000")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False,
                 allow_unsafe_werkzeug=True)
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: signal the capture loop to stop; the daemon capture
        # thread is also torn down when the interpreter exits.
        print("Shutting down...")
        processing_active = False