我的需求是:只保留含有人的图片(以及不含人的负样本),并按指定数量抽样,其他类别一律忽略。
import json
import os
import shutil
import random
from collections import defaultdict


def coco_to_yolo_bbox(image_width, image_height, bbox):
    """Convert a bounding box from COCO format to YOLO format.

    COCO format: [top_left_x, top_left_y, width, height] in pixels.
    YOLO format: (x_center, y_center, width, height), each divided by the
    corresponding image dimension.

    Args:
        image_width: Width of the image the box belongs to.
        image_height: Height of the image the box belongs to.
        bbox: Sequence of four numbers in COCO order.

    Returns:
        Tuple ``(x_center, y_center, width, height)`` of normalized values.
    """
    x_tl, y_tl, width, height = bbox
    x_center = (x_tl + width / 2.0) / image_width
    y_center = (y_tl + height / 2.0) / image_height
    width /= image_width
    height /= image_height
    return x_center, y_center, width, height


def filter_and_convert_annotations(input_annotation_file, output_annotation_dir,
                                   output_image_dir, src_image_dir,
                                   num_positive_samples=None,
                                   num_negative_samples=None):
    """Build a person-only YOLO dataset from a COCO annotation file.

    Images with at least one "person" annotation that has a ``bbox`` are
    positives; every other image listed in the file is a negative. A random
    subset of each group is selected, one YOLO label file is written per
    selected image (empty for negatives), and the selected image files are
    copied from ``src_image_dir`` to ``output_image_dir``.

    Args:
        input_annotation_file: Path to the COCO JSON annotation file.
        output_annotation_dir: Directory that receives the ``.txt`` labels.
        output_image_dir: Directory that receives the copied images.
        src_image_dir: Directory holding the original images.
        num_positive_samples: Maximum number of positive images to keep;
            ``None`` keeps all of them.
        num_negative_samples: Maximum number of negative images to keep;
            ``None`` keeps all of them.

    Raises:
        IndexError: If the annotation file has no category named "person".
    """
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_annotation_dir, exist_ok=True)

    with open(input_annotation_file, 'r', encoding='utf-8') as f:
        coco = json.load(f)

    # Look up the ID of the "person" category.
    person_ids = [cat['id'] for cat in coco['categories'] if cat['name'] == 'person']
    if not person_ids:
        raise IndexError(
            f"no 'person' category found in {input_annotation_file}"
        )
    person_category_id = person_ids[0]

    # Keep only person annotations that carry bounding-box information.
    person_annotations = [
        anno for anno in coco['annotations']
        if anno['category_id'] == person_category_id and 'bbox' in anno
    ]

    # IDs of every image that contains a person, and the matching records.
    person_image_ids = {anno['image_id'] for anno in person_annotations}
    person_images = [img for img in coco['images'] if img['id'] in person_image_ids]

    # Clamp the requested positive count to what is available, then sample.
    if num_positive_samples is None:
        num_positive_samples = len(person_images)
    else:
        num_positive_samples = min(num_positive_samples, len(person_images))
    person_images = random.sample(person_images, num_positive_samples)
    positive_image_ids = {img['id'] for img in person_images}

    # Group the annotations of the selected positives by image once, so the
    # conversion loop below does not rescan the full list for every image.
    annotations_by_image = defaultdict(list)
    for anno in person_annotations:
        if anno['image_id'] in positive_image_ids:
            annotations_by_image[anno['image_id']].append(anno)

    # IDs of images without a person. random.sample needs a sequence, hence
    # the conversion from set to list.
    all_image_ids = {img['id'] for img in coco['images']}
    nonperson_image_ids = list(all_image_ids - person_image_ids)

    # Clamp the requested negative count to what is available, then sample.
    if num_negative_samples is None:
        num_negative_samples = len(nonperson_image_ids)
    else:
        num_negative_samples = min(num_negative_samples, len(nonperson_image_ids))
    # A set keeps the membership test in the filter below constant-time.
    negative_image_ids = set(random.sample(nonperson_image_ids, num_negative_samples))
    negative_images = [img for img in coco['images'] if img['id'] in negative_image_ids]

    # Write one YOLO label file per positive image; the class index is always
    # 0 because "person" is the only class in the output dataset.
    for img in person_images:
        img_width = img['width']
        img_height = img['height']
        yolo_annotations = []
        for anno in annotations_by_image.get(img['id'], []):
            x_center, y_center, width, height = coco_to_yolo_bbox(
                img_width, img_height, anno['bbox']
            )
            yolo_annotations.append(f'0 {x_center} {y_center} {width} {height}')

        annotation_file_name = os.path.splitext(img['file_name'])[0] + '.txt'
        with open(os.path.join(output_annotation_dir, annotation_file_name), 'w') as f:
            f.write('\n'.join(yolo_annotations))

    # Negatives get an empty label file.
    for img in negative_images:
        annotation_file_name = os.path.splitext(img['file_name'])[0] + '.txt'
        with open(os.path.join(output_annotation_dir, annotation_file_name), 'w'):
            pass

    # Copy the selected images, skipping any that already exist at the
    # destination.
    selected_image_filenames = [img['file_name'] for img in person_images + negative_images]
    total_files = len(selected_image_filenames)
    for done, filename in enumerate(selected_image_filenames, start=1):
        src_path = os.path.join(src_image_dir, filename)
        dst_path = os.path.join(output_image_dir, filename)
        if not os.path.exists(dst_path):
            shutil.copyfile(src_path, dst_path)
        # Report progress every 100 images and once more at the end.
        if done % 100 == 0 or done == total_files:
            print(f"Copied {done}/{total_files} images")


def main():
    """Convert the COCO 2017 train and val splits using the configured paths."""
    filter_and_convert_annotations(
        '/path/coco_2017/annotations_trainval/instances_train2017.json',
        '/path/coco_2017_yolo/train_label',
        '/path/coco_2017_yolo/train',
        '/path/coco_2017/train2017',
        num_positive_samples=4000,
        num_negative_samples=400
    )

    filter_and_convert_annotations(
        '/path/coco_2017/annotations_trainval/instances_val2017.json',
        '/path/coco_2017_yolo/val_label',
        '/path/coco_2017_yolo/val',
        '/path/coco_2017/val2017',
        num_positive_samples=400,
        num_negative_samples=40
    )


# Run the conversion only when executed as a script, so importing this module
# does not touch the filesystem.
if __name__ == "__main__":
    main()