# 1. 將 shp 的標籤數據轉成 COCO 格式 (convert shapefile label data to COCO format)
# -*- coding: utf-8 -*-
import os, json
import cv2
from osgeo import gdal
import numpy as np
from osgeo import ogr, gdal, osr
from shapely.geometry import box, shape
from shapely.geometry.polygon import Polygon
import collections
import datetime
import geopandas as gpd
import shutil


def read_img(filename):
    """Open a raster with GDAL and return its size, georeferencing and pixels.

    Args:
        filename: Path handed to ``gdal.Open``.

    Returns:
        ``(im_width, im_height, im_proj, im_geotrans, im_data, dataset)``.
        The open dataset is returned as well, so the caller decides when the
        handle is released.
    """
    dataset = gdal.Open(filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    return im_width, im_height, im_proj, im_geotrans, im_data, dataset


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF with the given projection and geotransform.

    Args:
        filename: Output path.
        im_proj: Projection passed to ``SetProjection``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_data: Array shaped ``(bands, height, width)`` or ``(height, width)``.
    """
    # Substring match on the dtype name: 'int8' covers int8 and uint8,
    # 'int16' covers int16 and uint16; every other dtype is stored as Float32.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Always write one 2-D slice per band, so a (1, H, W) input is handled the
    # same way as a plain (H, W) one instead of passing a 3-D array to a band.
    for band_index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band)
    dataset.FlushCache()


def data2YoloAndCoco(shapefile_path, tif_path):
    """Check that a shapefile/TIFF pair opens and create its YOLO label file.

    As written, this only creates (or truncates) ``<name>.txt``; no label
    lines are produced.

    Args:
        shapefile_path: Path of the shapefile to open with OGR.
        tif_path: Path of the matching raster; its base name (without the
            extension) names the label file.
    """
    name = os.path.splitext(os.path.split(tif_path)[1])[0]

    shapefile_ds = ogr.Open(shapefile_path)
    if shapefile_ds is None:
        print("無法打開Shapefile文件")
        return

    tif_ds = gdal.Open(tif_path)
    if tif_ds is None:
        print("無法打開TIFF文件")
        return

    # NOTE(review): yolo_txt_path is not defined in the visible part of this
    # file; it must exist as a module-level name before this function is called.
    yolo_label_path = os.path.join(yolo_txt_path, name + ".txt")
    with open(yolo_label_path, 'w'):
        pass


def get_bbox_points(ring, geo_transform, x_res, y_res, width, height):
    """Return a normalised (x_center, y_center, width, height) box for a ring.

    Only the ring's first three points are used: the box width comes from
    point 0 -> point 1 and the box height from point 1 -> point 2.
    NOTE(review): this presumes the ring is an axis-aligned rectangle whose
    points are stored in that order - confirm against the label data.

    Args:
        ring: Object exposing ``GetPoint(i)`` returning ``(x, y, ...)``.
        geo_transform: Raster geotransform; indices 0 and 3 are used as the
            x / y origin.
        x_res: Pixel size along x.
        y_res: Pixel size along y.
        width: Image width in pixels, used to normalise x values.
        height: Image height in pixels, used to normalise y values.

    Returns:
        ``(x_normalized, y_normalized, width_normalized, height_normalized)``,
        each passed through ``abs``.
    """
    origin_x = geo_transform[0]
    origin_y = geo_transform[3]

    def to_pixel(point):
        return (int((point[0] - origin_x) / x_res),
                int((point[1] - origin_y) / y_res))

    x1, _ = to_pixel(ring.GetPoint(0))
    x2, y2 = to_pixel(ring.GetPoint(1))
    _, y3 = to_pixel(ring.GetPoint(2))

    w = x2 - x1
    h = y3 - y2
    x_center = x1 + w / 2.0
    y_center = y2 + h / 2.0

    return (abs(x_center / width), abs(y_center / height),
            abs(w / width), abs(h / height))


def get_boundary_points(geom, geo_transform, x_res, y_res):
    """Convert every vertex of a geometry to integer pixel coordinates.

    Args:
        geom: Object exposing ``GetPointCount()``, ``GetX(i)`` and ``GetY(i)``.
        geo_transform: Raster geotransform; indices 0 and 3 are used as the
            x / y origin.
        x_res: Pixel size along x.
        y_res: Pixel size along y.

    Returns:
        ``(x_pixels, y_pixels, pixels)``: the x values, the y values, and the
        ``[x, y]`` pairs, all in vertex order. ``int()`` truncates toward zero.
    """
    origin_x = geo_transform[0]
    origin_y = geo_transform[3]
    x_pixels = []
    y_pixels = []
    pixels = []
    for j in range(geom.GetPointCount()):
        pixel_x = int((geom.GetX(j) - origin_x) / x_res)
        pixel_y = int((geom.GetY(j) - origin_y) / y_res)
        x_pixels.append(pixel_x)
        y_pixels.append(pixel_y)
        pixels.append([pixel_x, pixel_y])
    return x_pixels, y_pixels, pixels


def getsegmenation(x_pixels, y_pixels):
    """Build the bounding box of a point set plus a rectangle polygon for it.

    Args:
        x_pixels: Non-empty sequence of x pixel coordinates.
        y_pixels: Non-empty sequence of y pixel coordinates.

    Returns:
        ``(box_info, bounding_box_area, segmentation)`` where ``box_info`` is
        ``[minx, miny, box_w, box_h]``, the area is ``box_w * box_h`` and
        ``segmentation`` is ``[[x1, y1, x2, y2, x3, y3, x4, y4]]``.

    Raises:
        ValueError: If either sequence is empty (raised by ``min``/``max``).
    """
    minx = min(x_pixels)
    maxx = max(x_pixels)
    miny = min(y_pixels)
    maxy = max(y_pixels)
    box_w = maxx - minx
    box_h = maxy - miny
    box_info = [minx, miny, box_w, box_h]
    bounding_box_area = box_w * box_h
    # Corners are listed walking around the rectangle's perimeter; listing
    # them as (min,min), (max,min), (min,max), (max,max) would describe a
    # self-intersecting "bow-tie" polygon instead of the box.
    segmentation = [minx, miny, maxx, miny, maxx, maxy, minx, maxy]
    return box_info, bounding_box_area, [segmentation]


def _build_coco_skeleton(cls_dict):
    """Return an empty COCO dict whose categories are taken from ``cls_dict``."""
    now = datetime.datetime.now()
    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime('%Y-%m-%d %H:%M:%S.%f'),
        ),
        licenses=[dict(url=None, id=0, name=None)],
        images=[],       # license, url, file_name, height, width, date_captured, id
        type='instances',
        annotations=[],  # segmentation, area, iscrowd, image_id, bbox, category_id, id
        categories=[],   # supercategory, id, name
    )
    # Category ids are the integer form of the class ids stored in the
    # shapefile, so annotations' category_id values always match a category.
    for class_id, class_name in cls_dict.items():
        data["categories"].append({
            "id": int(class_id),
            "name": class_name,
            "supercategory": "",
        })
    return data


def _append_tile(data, cls_dict, shpfile_path, tiffile_path, tiffile_name,
                 tiff_copy_folder, image_id):
    """Add one image and its annotations to ``data``; return True on success.

    Nothing is recorded (and False is returned) when the shapefile or the
    raster cannot be opened, so the caller can keep image ids contiguous.
    """
    shapefile_ds = ogr.Open(shpfile_path)
    if shapefile_ds is None:
        print("無法打開Shapefile文件")
        return False

    dataset = gdal.Open(tiffile_path)
    if dataset is None:
        print("無法打開TIFF文件")
        return False

    # Keep a flat copy of every tile that ends up in the JSON.
    shutil.copy(tiffile_path, os.path.join(tiff_copy_folder, tiffile_name))

    data['images'].append(dict(
        license=0,
        url=None,
        file_name=tiffile_name,
        height=dataset.RasterYSize,
        width=dataset.RasterXSize,
        date_captured=None,
        id=image_id,
    ))

    geo_transform = dataset.GetGeoTransform()
    x_res = geo_transform[1]
    y_res = geo_transform[5]

    shapefile_layer = shapefile_ds.GetLayer()
    for feature in shapefile_layer:
        # The "Class" field stores the class id.
        class_id = str(feature.GetField("Class"))
        if class_id not in cls_dict:
            print('class_id is empty!')
            continue

        geometry = feature.GetGeometryRef()
        ring = geometry.GetGeometryRef(0) if geometry is not None else None
        if ring is None:
            continue

        x_pixels, y_pixels, _ = get_boundary_points(
            ring, geo_transform, x_res, y_res)
        if not x_pixels or not y_pixels:
            continue

        bbox, area, bbox_points = getsegmenation(x_pixels, y_pixels)
        data["annotations"].append({
            "id": len(data['annotations']),
            "image_id": image_id,
            "category_id": int(class_id),
            "segmentation": bbox_points,
            "area": area,
            "bbox": bbox,
            "iscrowd": 0,
        })
    return True


def main():
    """Walk the site/region label folders and write one COCO JSON file."""
    cls_dict = {'1': 'pine', '2': 'spruce', '3': 'birch', '4': 'populus'}
    data = _build_coco_skeleton(cls_dict)

    root_tiff_folder = './data4train/train_image_128/'
    root_shpf_folder = './data4train/train_label_128/'
    tiff_copy_folder = './data4train/TIFFImage-train-128/'
    out_json_file = './STDtrain128.json'

    os.makedirs(tiff_copy_folder, exist_ok=True)

    image_id = 0
    for sitname in os.listdir(root_shpf_folder):
        site_folder = os.path.join(root_shpf_folder, sitname)
        if not os.path.isdir(site_folder):
            continue
        for regionn in os.listdir(site_folder):
            shpf_folder = os.path.join(site_folder, regionn)
            if not os.path.isdir(shpf_folder):
                continue
            tiff_folder = os.path.join(root_tiff_folder, sitname, regionn)

            for shpfile in os.listdir(shpf_folder):
                if not shpfile.endswith(".shp"):
                    continue
                print('Processing shpfile:', shpfile)
                shpfile_path = os.path.join(shpf_folder, shpfile)
                # Each shapefile is paired with a .tif of the same base name.
                tiffile_name = os.path.splitext(shpfile)[0] + '.tif'
                tiffile_path = os.path.join(tiff_folder, tiffile_name)

                if _append_tile(data, cls_dict, shpfile_path, tiffile_path,
                                tiffile_name, tiff_copy_folder, image_id):
                    image_id += 1

    with open(out_json_file, 'w') as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()