#!/usr/bin/python

import sys
import math

# Helper: idealized water depth as a function of latitude relative to the shoreline.
def getDepth(curLat, shorelineLat, depthType):
    """Return the idealized depth in meters at latitude ``curLat``.

    The cross-shore distance is approximated as
    ``(shorelineLat - curLat) * 111000`` meters, so a point with
    ``curLat < shorelineLat`` is offshore (positive distance) and a point with
    ``curLat > shorelineLat`` is inland (negative distance).

    * Inland, more than 1.5 miles from the coast: -14.0.
    * Inland, within 1.5 miles: linear from 0.0 at the coast to -14.0.
    * Exactly on the coastline: 0.0.
    * Offshore: piecewise-linear interpolation of the breakpoint profile
      selected by ``depthType``. Beyond the last breakpoint the depth is 2000.0.

    Parameters:
        curLat: latitude of the point.
        shorelineLat: latitude of the shoreline.
        depthType: one of 'exshallow', 'shallow', 'moderate', 'deep',
            'verydeep'. It is only consulted for offshore points.

    Raises:
        SystemExit: for an offshore point when ``depthType`` is not one of the
            names above. The original code referenced ``sys.exit`` without
            calling it and then crashed on an unbound local.
    """
    mileToMeter = 1609.344
    dlat = shorelineLat - curLat
    dist = dlat * 111000

    if dist <= (-1.5 * mileToMeter):     # more than 1.5 miles inland
        return -14.0
    if dist < 0:                         # land, up to 1.5 miles inland
        return dist * (-14.0 / (-1.5 * mileToMeter))
    if dist == 0:                        # coastline
        return 0.0

    # Offshore profiles: (breakpoint distances in miles, depths in meters).
    profiles = {
        'exshallow': ([0.0, 5.0, 80.0, 100.0, 150.0],
                      [0.0, 2.5, 5.0, 50.0, 2000.0]),
        'shallow':   ([0.0, 5.0, 40.0, 60.0, 80.0, 100.0, 150.0],
                      [0.0, 4.0, 20.0, 35.0, 60.0, 200.0, 2000.0]),
        'moderate':  ([0.0, 5.0, 40.0, 60.0, 80.0, 100.0, 150.0],
                      [0.0, 10.0, 35.0, 50.0, 100.0, 400.0, 2000.0]),
        'deep':      ([0.0, 5.0, 40.0, 60.0, 80.0, 100.0, 150.0],
                      [0.0, 20.0, 70.0, 600.0, 1000.0, 1400.0, 2000.0]),
        'verydeep':  ([0.0, 5.0, 40.0, 60.0, 80.0, 100.0],
                      [0.0, 80.0, 700.0, 1200.0, 1700.0, 2000.0]),
    }

    try:
        distMiles, depths = profiles[depthType]
    except KeyError:
        print('Input error! Depth type is not valid. exit!')
        sys.exit(1)

    dists = [miles * mileToMeter for miles in distMiles]

    # Find the first breakpoint beyond this point and interpolate linearly
    # between it and the previous breakpoint.
    for i in range(1, len(dists)):
        if dist < dists[i]:
            return ((depths[i] - depths[i-1]) * (dist - dists[i-1])
                    / (dists[i] - dists[i-1]) + depths[i-1])

    # Farther offshore than the last breakpoint: constant deep-water depth.
    return 2000.0

# end getDepth

# Begin the main program of get1DStormSurgeWithInput.py.
# Note: a Python list of length n is indexed from 0 to n-1,
# so if the selected point is 100, its list index is 99.

# Echo the argument count and each argument as a simple run log.
print(len(sys.argv))

if len(sys.argv) < 7:
    print('usage: python get1DStormSurgeWithInput.py inputFile numTime')
    print(' shorelineLat depthType selectedPt outputFile')
    sys.exit('Missing input arguments, program exited!')

#***************************************************************
# read input variables
#***************************************************************

for arg in sys.argv:
    print(arg)

# Positional arguments; sys.argv[0] is the script name. Extra arguments
# beyond the sixth are ignored.
inputFile = sys.argv[1]
depthType = sys.argv[4]
outputFile = sys.argv[6]

# Convert the numeric arguments up front so a typo gives a clear message
# instead of a raw traceback.
try:
    numTime = int(sys.argv[2])
    shorelineLat = float(sys.argv[3])
    selectedPt = int(sys.argv[5])
except ValueError as err:
    sys.exit('Invalid numeric input argument (%s), program exited!' % err)

# end of input argument parsing

#************************************************************************
# model constants (the units below are presumed, not stated in the code)
#************************************************************************

(g,        # gravitational acceleration, presumably m/s^2
 gamma,    # coefficient multiplying squared velocity components in sx and in the q forcing
 C_k,      # coefficient of the quadratic (q^2/d^2) damping term in the q equation
 rho_w,    # water density, presumably kg/m^3 (used in the pressure term sp)
 ) = (9.8, 3.5e-6, 0.0025, 1000.)

#***************************************************************
# open the input file and parse its one-line header:
#   ni  minLat  maxLat  dx  dt
#***************************************************************

infile = open(inputFile, 'r')
line = infile.readline()
tokens = line.split()
print(tokens)

# Number of grid points, then the four float-valued header fields in order.
ni = int(tokens[0])
minLat, maxLat, dx, dt = (float(tokens[k]) for k in range(1, 5))

print('read => ', ni, minLat, maxLat, dx, dt)
      
#***************************************************************
# allocate the per-point work arrays: fourteen independent lists,
# each holding ni zeros (current-step values and their *old copies)
#***************************************************************

(h, rlat,
 wind, windold,
 u, uold,
 v, vold,
 p, pold,
 qavg, qavgold,
 s, sold) = ([0.0] * ni for _ in range(14))

#***************************************************************
# read one "index lon lat" record per grid point and assign each point
# a depth from the idealized profile selected by depthType
#***************************************************************

for i in range(ni):
    line = infile.readline()
    tokens = line.split()
    print(tokens)
    curIndex, rlon = int(tokens[0]), float(tokens[1])
    rlat[i] = float(tokens[2])
    # Depth depends only on latitude; rlon is parsed but not used afterwards.
    h[i] = depth = getDepth(rlat[i], shorelineLat, depthType)

# end loop over grid points

#***************************************************************
# read the first time step into the *old arrays:
# a header line, then one "index wind uWind vWind p" record per point
#***************************************************************

# Header: record index, time in seconds, storm-eye longitude and latitude.
line = infile.readline()
tokens = line.split()
curIndex, curTimeInSec, loneye, lateye = (int(tokens[0]), float(tokens[1]),
                                          float(tokens[2]), float(tokens[3]))

for i in range(ni):
    line = infile.readline()
    tokens = line.split()
    curIndex = int(tokens[0])
    windold[i] = float(tokens[1])
    uWind, vWind = float(tokens[2]), float(tokens[3])
    pold[i] = float(tokens[4])
    # The model's u takes the input v-component; its v takes minus the
    # input u-component.
    uold[i], vold[i] = vWind, -uWind

# end loop over grid points

#***************************************************************
# initialize arrays qavgold and sold from the first time step
#***************************************************************

# The first time step was read into windold/uold/vold above. wind/u/v are
# still all zeros here because they are only filled inside the time loop.
# The original averaged wind/u/v, so the initial state was always zero.
# The first-step arrays are used instead.
if sold:
    sold[0] = 0.0      # boundary value at the first grid point
sxTotal = 0.0
syTotal = 0.0

for i in range(0, ni-1):

    # Integrate only over the segment seaward of the shoreline.
    if (rlat[i] >= shorelineLat or rlat[i+1] >= shorelineLat):
        break

    uavg = ( uold[i]+uold[i+1] )/2.

    vavg = ( vold[i]+vold[i+1] )/2.

    davg = ( h[i]+h[i+1] )/2.

    # Closed-form solution, after time dt starting from q = 0, of the same
    # relaxation equation the time loop integrates explicitly:
    #   dq/dt = gamma*v^2 - C_k*q^2/d^2
    # Its magnitude is taken with the sign of vavg.
    qavgold[i]=davg*math.sqrt(gamma*vavg*vavg/C_k)*math.tanh(math.sqrt(gamma*vavg*vavg)*math.sqrt(C_k)*dt/davg)

    if (vavg <= 0.0):
        qavgold[i]=-1.0*qavgold[i]

    # 2 * 7.292e-5 * mean(sin(lat)); presumably the Coriolis parameter
    # (7.292e-5 matches Earth's rotation rate in rad/s).
    favg=2*7.292e-5*((math.sin(math.radians(rlat[i])) + math.sin(math.radians(rlat[i+1])))/2.)

    sy=dx*favg*qavgold[i]/(g*davg)

    # Along-axis term, quadratic in uavg and signed like uavg.
    sx=dx*gamma*uavg*uavg/(g*davg)

    if (uavg <= 0.0):
        sx=-1.0*sx

    # Cumulative sum from the first grid point outward.
    sold[i+1] = sx+sy+sold[i]

    sxTotal = sxTotal+sx
    syTotal = syTotal+sy

# end for loop of i in range(0, ni-1):

#***************************************************************
# open the surge output file, plus one diagnostic time-series file per
# quantity for grid point number selectedPt
# (file names: outputFile + '.<quantity>.' + selectedPt)
#***************************************************************

out30 = open(outputFile, 'w')
(out40, out41, out42, out43,
 out44, out45, out46, out47) = (
    open(outputFile + suffix + str(selectedPt), 'w')
    for suffix in ('.sx.', '.sy.', '.sp.', '.stab_check.',
                   '.qavg.', '.favg.', '.sxTotal.', '.syTotal.'))

#***************************************************************
# Continue to loop over the entire input file to read in the wind, u, v, and p
# of every time step (starting 2nd time step, index 1) and calculate the surge values
#***************************************************************

# Surge values go to out30 only during the last outputWindowSteps time steps,
# and only for grid points numbered up to lastWaterPt. The original note
# reads "last 12 hours - 1 min time step = 720 steps", which assumes dt is
# one minute. It also describes lastWaterPt as hard-wired to the last water
# point.
outputWindowSteps = 720
lastWaterPt = 8792


def _readTokens(stepNumber):
    """Read and split the next line of infile.

    Exits with a clear message, instead of an IndexError further on, when
    the input runs out (or has a blank line) before numTime steps are read.
    """
    tokens = infile.readline().split()
    if not tokens:
        sys.exit('Input file ended before time step %d was complete, program exited!'
                 % stepNumber)
    return tokens


for timeStep in range(1, numTime):

    # Time-step header: record index, time in seconds, eye lon/lat.
    tokens = _readTokens(timeStep + 1)
    curIndex = int(tokens[0])
    curTimeInSec = float(tokens[1])
    loneye = float(tokens[2])
    lateye = float(tokens[3])

    curTimeInHr = curTimeInSec / 3600.0

    # One "index wind uWind vWind p" record per grid point.
    for i in range(0, ni):
        tokens = _readTokens(timeStep + 1)
        curIndex = int(tokens[0])
        wind[i] = float(tokens[1])
        uWind = float(tokens[2])
        vWind = float(tokens[3])
        p[i] = float(tokens[4])

        # Same component mapping as the first time step.
        u[i] = vWind
        v[i] = -uWind

    qavg[0] = 0.0
    s[0] = 0.0
    sxTotal = 0.0
    syTotal = 0.0

    for i in range(0, ni-1):

        # Integrate only over the segment seaward of the shoreline.
        if (rlat[i] >= shorelineLat or rlat[i+1] >= shorelineLat):
            break

        # v is averaged over both points and both time levels.
        vavg = ( v[i]+v[i+1]+vold[i]+vold[i+1] )/4.

        # Mean static depth plus mean surge from the previous step.
        davg = ( h[i]+h[i+1]+sold[i]+sold[i+1] )/2.

        # A zero depth would divide by zero below, so it stops the sweep
        # like a negative one. Points beyond i keep their previous-step s.
        if (davg <= 0.0):
            print('negative surge at timeStep=',timeStep, ' i=', i)
            print('h[i]=',h[i],' and sold[i]=',sold[i])
            print('h[i+1]=',h[i+1],' and sold[i+1]=',sold[i+1])
            break

        # Explicit step of dq/dt = gamma*v^2 - C_k*q^2/d^2 when it is stable
        # enough. Otherwise q is set to the signed equilibrium value.
        stab_check=math.sqrt(gamma*vavg*vavg)*math.sqrt(C_k)*dt/davg

        if (stab_check <= 0.95):
            qavg[i]=qavgold[i]+gamma*vavg*vavg*dt-(dt*C_k*qavgold[i]**2)/davg**2
        else:
            qavg[i]=davg*math.sqrt( gamma*vavg*vavg/C_k)

            if (vavg <= 0.0):
                qavg[i]=-1.0*qavg[i]

        # 2 * 7.292e-5 * mean(sin(lat)); presumably the Coriolis parameter.
        favg=2*7.292e-5*((math.sin(math.radians(rlat[i])) + math.sin(math.radians(rlat[i+1])))/2.)

        sy=dx*favg*qavg[i]/(g*davg)

        uavg=(u[i]+u[i+1])/2.

        # Along-axis term, quadratic in uavg and signed like uavg.
        sx=dx*gamma*uavg*uavg/(g*davg)

        if (uavg <= 0.0):
            sx=-1.0*sx

        # Pressure term relative to 101300. It is added only to the written
        # surge value, not accumulated into s.
        sp=(101300.0-p[i+1])/(rho_w*g)

        s[i+1]=sx+sy+s[i]

        sxTotal = sxTotal + sx
        syTotal = syTotal + sy

        # Diagnostic time series at grid point number selectedPt.
        if ((i+1) == selectedPt):
            out40.write('%12.4f%15.8f\n' % (curTimeInHr, sx))
            out41.write('%12.4f%15.8f\n' % (curTimeInHr, sy))
            out42.write('%12.4f%15.8f\n' % (curTimeInHr, sp))
            out43.write('%12.4f%15.8f\n' % (curTimeInHr, stab_check))
            out44.write('%12.4f%15.8f\n' % (curTimeInHr, qavg[i]))
            out45.write('%12.4f%15.8f\n' % (curTimeInHr, favg))
            out46.write('%12.4f%15.8f\n' % (curTimeInHr, sxTotal))
            out47.write('%12.4f%15.8f\n' % (curTimeInHr, syTotal))

        # Sampled surge output during the final window. Point numbers are
        # sampled every 500 in (4000, 8000), every 100 from 8000, and every
        # point from 8700 to lastWaterPt.
        pt = i + 1
        if (timeStep >= (numTime - outputWindowSteps) and pt <= lastWaterPt):
            if ((4000 < pt < 8000 and pt % 500 == 0)
                    or (pt >= 8000 and pt % 100 == 0)
                    or (8700 <= pt <= lastWaterPt)):
                out30.write('%6d%12.4f%6d%12.5f%15.8f\n' % (timeStep+1, curTimeInHr, i+1, rlat[i+1], s[i+1]+sp))

    # end for loop of i in range (0, ni-1)

    # The current step becomes the previous step (in-place copies, so every
    # list keeps its identity).
    windold[:] = wind
    vold[:] = v
    uold[:] = u
    pold[:] = p
    qavgold[:] = qavg
    sold[:] = s

# end for loop of timeStep in range(1, numTime)

# Release the input file and every output file, in the original order.
for _handle in (infile, out30, out40, out41, out42, out43,
                out44, out45, out46, out47):
    _handle.close()