2013-01-16 22:21:47 +08:00
#!/usr/bin/env python
import argparse
import sft
import sys , os , os . path , glob , math , cv2
from datetime import datetime
import numpy
2013-01-28 19:34:53 +08:00
# Colour codes handed to sft.plotLogLog, one per tested cascade; the n-th
# cascade on the command line is drawn with the n-th entry. (These look like
# matplotlib single-letter format characters — presumably what sft expects.)
plot_colors = list("bcrgm")

# "key" : (b, g, r)
# NOTE(review): not referenced in the visible part of this script.
bgr = dict(
    red=(0, 0, 255),
    green=(0, 255, 0),
    blue=(255, 0, 0),
)
2013-01-24 18:45:11 +08:00
def range(s):
    """Parse a ``"lb,rb"`` command-line value into a pair of ints.

    Used as the argparse ``type=`` callable for ``--scale-range``.
    NOTE: this module-level name shadows the builtin ``range`` for the rest
    of the module; it is kept as-is because the argument parser refers to it.

    :param s: string of the form ``"lb,rb"``; whitespace around each number
        is accepted (``int()`` strips it).
    :returns: tuple ``(lb, rb)`` of ints.
    :raises argparse.ArgumentTypeError: if *s* is not exactly two
        comma-separated integers.
    """
    try:
        lb, rb = map(int, s.split(','))
    except ValueError:
        # ValueError covers both a non-integer part (from int()) and the
        # wrong number of comma-separated parts (from tuple unpacking).
        # Catching only this, instead of a bare except, lets
        # KeyboardInterrupt/SystemExit propagate.
        raise argparse.ArgumentTypeError("Must be lb, rb")
    return lb, rb
2013-01-16 22:21:47 +08:00
def call_parser(f, a):
    """Dispatch to the annotation parser ``sft.parse_<f>`` and return its result.

    :param f: annotation format name (the command line restricts it to
        'inria', 'caltech' or 'idl' via argparse ``choices``).
    :param a: annotation path, passed through to the parser unchanged.
    :returns: whatever ``sft.parse_<f>(a)`` returns.
    :raises AttributeError: if ``sft`` has no ``parse_<f>`` function.
    """
    # Look the parser up by name instead of eval()-ing a built source string.
    # The string-built call broke on paths containing a quote, reinterpreted
    # backslash escapes in paths (e.g. "\n"), and would execute any code
    # embedded in the arguments.
    parser_fn = getattr(sft, "parse_" + f)
    return parser_fn(a)
def _build_parser():
    """Return the command-line parser for this ROC plotting script."""
    parser = argparse.ArgumentParser(description='Plot ROC curve using Caltech method of per image detection performance estimation.')

    # positional
    parser.add_argument("cascade", help="Path to the tested detector.", nargs='+')
    parser.add_argument("input", help="Image sequence pattern.")
    parser.add_argument("annotations", help="Path to the annotations.")

    # optional
    parser.add_argument("-m", "--min_scale", dest="min_scale", type=float, metavar="fl", help="Minimum scale to be tested.", default=0.4)
    parser.add_argument("-M", "--max_scale", dest="max_scale", type=float, metavar="fl", help="Maximum scale to be tested.", default=5.0)
    parser.add_argument("-o", "--output", dest="output", type=str, metavar="path", help="Path to store resulting image.", default="./roc.png")
    parser.add_argument("-n", "--nscales", dest="nscales", type=int, metavar="n", help="Preferred count of scales from min to max.", default=55)
    # `range` here is this module's "lb,rb" parser, not the builtin.
    parser.add_argument("-r", "--scale-range", dest="scale_range", type=range, default=(128 * 0.4, 128 * 2.4))
    parser.add_argument("-e", "--extended-range-ratio", dest="ext_ratio", type=float, default=1.25)
    parser.add_argument("-t", "--title", dest="title", type=str, default="ROC curve Bahnhof")

    # required
    parser.add_argument("-f", "--anttn-format", dest="anttn_format", choices=['inria', 'caltech', "idl"], help="Annotation file for test sequence.", required=True)
    parser.add_argument("-l", "--labels", dest="labels", required=True, help="Plot labels for legend.", nargs='+')
    return parser


def _evaluate_cascade(cascade_path, args, samples):
    """Run one detector over the whole input sequence and return its ROC points.

    :param cascade_path: path of the detector to load via ``sft.cascade``.
    :param args: parsed command-line namespace; reads ``min_scale``,
        ``max_scale``, ``nscales``, ``input``, ``scale_range`` and ``ext_ratio``.
    :param samples: annotations indexed by frame file name (``samples[tail]``).
        A frame without an entry fails the lookup (KeyError if it is a dict).
    :returns: ``(fppi, miss_rate)`` as returned by ``sft.computeROC``.
    """
    cascade = sft.cascade(args.min_scale, args.max_scale, args.nscales, cascade_path)
    pattern = args.input
    camera = cv2.VideoCapture(pattern)

    # Accumulated over the whole dataset for the ROC computation.
    nannotated = 0
    nframes = 0
    confidences = []
    tp = []
    ignored = []

    try:
        while True:
            ret, img = camera.read()
            if not ret:
                break

            # The annotation key is the file name of the current frame,
            # produced by formatting the frame index into the input pattern.
            name = pattern % (nframes,)
            _, tail = os.path.split(name)

            boxes = sft.filter_for_range(samples[tail], args.scale_range, args.ext_ratio)

            nannotated += len(boxes)
            nframes += 1
            rects, confs = cascade.detect(img, rois=None)

            # No detections in this frame: its ground truth has already been
            # counted in nannotated / nframes above.
            if confs is None:
                continue

            dts = sft.convert2detections(rects, confs)

            # Per-frame confidences in descending order (equivalent to the old
            # Python-2-only cmp-based sort). confs.tolist()[0] assumes a single
            # row of scores — TODO confirm against cascade.detect.
            # NOTE(review): tp below comes from matching dts in their own order;
            # this relies on dts being ordered consistently with the sorted
            # confidences — confirm in sft.convert2detections / sft.match.
            confidences.extend(sorted(confs.tolist()[0], reverse=True))

            matched, skip_list = sft.match(boxes, dts)
            tp.extend(matched)
            ignored.extend(skip_list)
    finally:
        # Release the capture even if evaluation fails part-way through.
        camera.release()

    # Same output as the Python 2 statement `print nframes, nannotated`.
    print("%s %s" % (nframes, nannotated))

    return sft.computeROC(confidences, tp, nannotated, nframes, ignored)


def main():
    """Parse the command line, evaluate every cascade and save the ROC plot."""
    args = _build_parser().parse_args()

    print(args.scale_range)
    print(args.cascade)

    sft.initPlot(args.title)

    # parse annotations
    samples = call_parser(args.anttn_format, args.annotations)

    for idx, each in enumerate(args.cascade):
        print(each)
        fppi, miss_rate = _evaluate_cascade(each, args, samples)
        # Cycle the palette so more cascades than colours do not raise IndexError.
        sft.plotLogLog(fppi, miss_rate, plot_colors[idx % len(plot_colors)])

    sft.showPlot(args.output, args.labels)


if __name__ == "__main__":
    main()