import cv2
import numpy as np
import onnxruntime as ort
import torch
from ultralytics import YOLO
from torchvision import transforms
import matplotlib.pyplot as plt

# Detection model: ONNX export used to find people in each frame
yolo_model_path = 'yolo11m-2_uniform.onnx'
ort_session = ort.InferenceSession(yolo_model_path)
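
# Optional sanity check (added here as a sketch): print the ONNX graph's
# declared input/output names and shapes, so the hard-coded 'images' input
# key used further below can be verified against this particular export.
for onnx_input in ort_session.get_inputs():
    print("ONNX input:", onnx_input.name, onnx_input.shape)
for onnx_output in ort_session.get_outputs():
    print("ONNX output:", onnx_output.name, onnx_output.shape)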

# Pose model: Ultralytics weights used for keypoint estimation on each person crop
pose_model_path = 'yolo11s-pose.pt'
pose_model = YOLO(pose_model_path)

def extract_person_boxes(yolo_outputs):
    """Extract person bounding boxes from the YOLO ONNX output.

    Assumes the export emits detection rows of [x1, y1, x2, y2, conf, ...]
    in the 640x640 input space; if it emits the raw (1, 4 + num_classes, 8400)
    head layout instead, the output has to be decoded first (see the sketch below).
    """
    boxes = np.squeeze(yolo_outputs[0])  # drop the batch dimension
    conf_threshold = 0.5

    # Filter based on confidence score only (no class filtering yet)
    selected_indices = np.where(boxes[:, 4] > conf_threshold)[0]

    return boxes[selected_indices, :4].astype(int)
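
# Added sketch (not called by default): if the ONNX export produces the raw
# YOLO detection head, shape (1, 4 + num_classes, 8400) with rows cx, cy, w, h
# followed by per-class scores, a minimal decoder could look like this. The
# person class index (assumed 0) and the omission of NMS are simplifications
# that depend on how this custom model was trained and exported.
def decode_raw_head(raw_output, conf_threshold=0.5, person_class=0):
    preds = np.squeeze(raw_output).T            # (8400, 4 + num_classes)
    scores = preds[:, 4 + person_class]         # person-class confidence
    keep = preds[scores > conf_threshold]
    cx, cy, w, h = keep[:, 0], keep[:, 1], keep[:, 2], keep[:, 3]
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    return xyxy.astype(int)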

def evaluate_sop(pose_output):
    """Basic SOP compliance evaluation (example implementation):
    - checks whether both arms are visible (keypoint confidence > 0.6)
    - the torso vertical angle check is still a placeholder (see the sketch below)
    """
    if len(pose_output) == 0 or pose_output[0].keypoints.shape[0] == 0:
        return False  # no keypoints detected, assume non-compliant

    # Keypoints.data has shape (num_persons, num_keypoints, 3); take the first
    # detected person so each row is (x, y, confidence) for one keypoint.
    keypoints = pose_output[0].keypoints.data[0]

    # Example conditions (COCO keypoint order: 5/6 shoulders, 7/8 elbows)
    left_shoulder_conf = float(keypoints[5, 2]) if keypoints.shape[0] > 5 else 0
    right_shoulder_conf = float(keypoints[6, 2]) if keypoints.shape[0] > 6 else 0
    left_elbow_conf = float(keypoints[7, 2]) if keypoints.shape[0] > 7 else 0
    right_elbow_conf = float(keypoints[8, 2]) if keypoints.shape[0] > 8 else 0

    # Simple compliance criteria
    arms_visible = (
        left_shoulder_conf > 0.6 and
        right_shoulder_conf > 0.6 and
        left_elbow_conf > 0.6 and
        right_elbow_conf > 0.6
    )

    # Add more conditions based on actual SOP requirements
    return arms_visible  # temporary compliance criteria
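
# Added sketch for the torso-angle placeholder mentioned in the docstring
# above (not wired into evaluate_sop yet). It assumes the COCO 17-keypoint
# layout used by Ultralytics pose models: shoulders at indices 5/6 and hips
# at 11/12, with keypoints given as rows of (x, y, confidence).
def torso_vertical_angle(keypoints):
    """Return the torso's deviation from the image's vertical axis, in degrees."""
    shoulder_mid = (keypoints[5, :2] + keypoints[6, :2]) / 2
    hip_mid = (keypoints[11, :2] + keypoints[12, :2]) / 2
    dx = float(hip_mid[0] - shoulder_mid[0])
    dy = float(hip_mid[1] - shoulder_mid[1])
    # angle between the shoulder-to-hip vector and straight down (+y in image coords)
    return abs(np.degrees(np.arctan2(dx, dy)))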

# Preprocessing for the ROI crops passed to the pose model as tensors.
# Ultralytics expects tensor inputs as float BCHW images scaled to 0-1 and does
# not apply ImageNet mean/std normalization itself, so ToTensor alone is used.
transform = transforms.Compose([
    transforms.ToTensor(),
])

video_path = 'PF-071124-2.mp4'
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print("Error opening video file")
    exit()

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # Resize for the detector and convert BGR (OpenCV) to RGB (the channel
    # order YOLO models are trained on), then scale to 0-1 and reshape to BCHW.
    input_frame = cv2.resize(frame, (640, 640))
    input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
    input_frame = np.transpose(input_frame, (2, 0, 1)).astype(np.float32) / 255.0
    input_tensor = np.expand_dims(input_frame, 0)

    yolo_outputs = ort_session.run(None, {'images': input_tensor})
    print("YOLO Outputs Shape:", [output.shape for output in yolo_outputs])  # debug
    print("YOLO Outputs:", yolo_outputs)  # debug (very verbose)

    persons = extract_person_boxes(yolo_outputs)

    for bbox in persons:
        # Boxes are in the 640x640 detector space; scale them back to frame size.
        x1, y1, x2, y2 = map(int, bbox)
        x1, x2 = int(x1 * frame.shape[1] / 640), int(x2 * frame.shape[1] / 640)
        y1, y2 = int(y1 * frame.shape[0] / 640), int(y2 * frame.shape[0] / 640)
        roi = frame[max(y1, 0):y2, max(x1, 0):x2]
        if roi.size == 0:
            continue  # skip degenerate boxes

        # Pose model input: RGB, 0-1 float, BCHW (handled by `transform`)
        roi_resized = cv2.resize(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB), (256, 256))
        roi_tensor = transform(roi_resized).unsqueeze(0)

        with torch.no_grad():
            pose_output = pose_model(roi_tensor)
print("Pose Output Type:", type(pose_output))
|
|
print("Pose Output Keys:", pose_output.keys()) if hasattr(pose_output, 'keys') else print("Pose Output:", pose_output)
|
|
|
|
if pose_output[0].keypoints.shape[0] > 0:
|
|
keypoints = pose_output[0].keypoints
|
|
if keypoints.shape[1] > 5: # Ensure there are at least 6 keypoints
|
|
compliant = evaluate_sop(pose_output)
|
|
else:
|
|
compliant = False # Not enough keypoints detected, assume non-compliant
|
|
else:
|
|
compliant = False # No keypoints detected, assume non-compliant
|
|
|
|

        color = (0, 255, 0) if compliant else (0, 0, 255)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        label = 'Compliant' if compliant else 'Non-compliant'
        cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    # Display the annotated frame using matplotlib
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    plt.axis('off')  # turn off axis labels
    plt.show(block=False)
    plt.pause(0.01)
    plt.clf()  # clear the current figure

    # Note: cv2.waitKey only sees key presses made on an OpenCV window; with the
    # matplotlib display above it mainly acts as a short per-frame delay.
    if cv2.waitKey(25) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()