
''' ---------------------------------------------------------------------------
Name: North Carolina AEP Tool

Version: 1.2

ArcGIS Pro Version: 3.4 (2025)

Author: Applied Weather Associates

Usage: Run script tool within ArcGIS Pro desktop environment

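Arguments:
    - Input basin polygon feature/shapefile
    - Durations to analyze (hours)
    - Output Folder
    - (Optional) Option to calculate the AEP of PMP and add it to the plots
    - (Optional) PMP point feature classes (output of PMP tool), one per storm type
      (Local, General, or Tropical, identified from the feature class name), used for the AEP of PMP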

Description:
    The tool calculates basin average AEP estimates, and 5%/95% confidence intervals, for each
    available return frequency and duration. Excel tables summarizing the AEP for
    each duration are created.  Chart image plots are produced as .png files for each duration.
    Optionally, the AEP of PMP is calculated and added to the plots.

---------------------------------------------------------------------------'''

###########################################################################
# import Python modules

import os
import arcpy
from arcpy import env
import math
import numpy as np
import arcpy.analysis as an
import arcpy.management as dm
import arcpy.conversion as con
import matplotlib as mpl
import matplotlib.pyplot as plt
import re
from arcpy.sa import *  # Spatial Analyst

# Set overwrite option: scratch/in-memory datasets and output files below use
# fixed names, so reruns would otherwise fail on existing outputs.
env.overwriteOutput = True
dm.Delete("memory")  ## Clear memory workspace

# Check out Spatial Analyst Extension (ExtractByMask is used for basin averages)
arcpy.CheckOutExtension("Spatial")

##############################################
# Parameter definitions
##############################################

# input parameters
basin = arcpy.GetParameter(0)               # AOI basin polygon feature class/shapefile
aepDurations = arcpy.GetParameter(1)        # durations (hours) to analyze; used as strings, e.g. "6", "24"
outLocation = arcpy.GetParameterAsText(2)   # output folder for Excel tables and PNG plots
getAEPofPMP = arcpy.GetParameter(3)         # whether to add the AEP-of-PMP table to the plots
pmpPointsList = arcpy.GetParameter(4)       # PMP point feature classes, one per storm type (optional)

# NOTE(review): os.getcwd() is the current working directory; this assumes the tool is
# launched with the script folder as cwd -- confirm, or derive it from the script's __file__.
scriptPath = os.getcwd()                                # location of script Folder (see note above)
home = os.path.dirname(scriptPath)                      # location of 'PMP' Project Folder
arcpy.AddMessage("\nPMP_Evaluation_Tool folder filepath: " + home)

## 5% and 95% error multipliers for the project.  One entry per AEP grid, indexed by the
## grid's position in the return-period-sorted grid list (at most 18 grids supported).
err5 = [0.8726, 0.9061, 0.9120, 0.9088, 0.8977, 0.8882, 0.8791, 0.8683, 0.8540, 0.8429, 0.8153, 0.8033, 0.7638, 0.7217, 0.6809, 0.6402, 0.6005, 0.5624]
err95 = [1.115, 1.110, 1.1050, 1.1000, 1.0950, 1.105, 1.115, 1.125, 1.135, 1.145, 1.1632, 1.1841, 1.2591, 1.3438, 1.4341, 1.5392, 1.6571, 1.7861]

##############################################
# Function definitions
############################################## 
def getAOIarea(aoiBasin):
    """Return the total geodesic area of *aoiBasin* in US survey square miles.

    The features are copied to the scratch geodatabase, an AREA_GEO field is added
    with AddGeometryAttributes, and the areas of all features are summed (the
    temporary copy is then deleted).  Totals below 1 sqmi are rounded to 2 decimals;
    larger totals are rounded to a whole number.  The value is reported with
    arcpy.AddMessage and returned as a float.
    """
    areaFC = os.path.join(arcpy.env.scratchGDB, "areaBasin")
    dm.CopyFeatures(aoiBasin, areaFC)
    # NOTE(review): WKID 102039 is presumably USA Contiguous Albers Equal Area (USGS) -- confirm.
    dm.AddGeometryAttributes(areaFC, "AREA_GEODESIC", "", "SQUARE_MILES_US", 102039)
    with arcpy.da.SearchCursor(areaFC, ['AREA_GEO']) as rows:
        totalArea = sum((row[0] for row in rows), 0.0)
    dm.Delete(areaFC)

    if totalArea < 1.0:
        aoiArea = round(totalArea, 2)
        arcpy.AddMessage(f"\nArea used for ARF analysis: {aoiArea:.2f} sqmi")
    else:
        aoiArea = round(totalArea, 0)
        arcpy.AddMessage(f"\nArea used for ARF analysis: {aoiArea:.0f} sqmi")
    return aoiArea

# Log-logistic ARF parameters (b, c, e) keyed by transposition zone, then by
# duration in hours (as a string).  The 24-hour parameters are identical for
# every zone; the 6-hour parameters vary by zone.
_ARF_24HR_PARAMS = (0.53348, 0.03818, 7798.4)
_ARF_PARAMS = {
    1: {'24': _ARF_24HR_PARAMS, '6': (0.65388, 0.00097, 452.18)},
    2: {'24': _ARF_24HR_PARAMS, '6': (0.65388, 0.00097, 452.18)},
    3: {'24': _ARF_24HR_PARAMS, '6': (0.71358, -0.02341, 285.784624)},
    4: {'24': _ARF_24HR_PARAMS, '6': (0.56148, 0.17642, 212.469739)},
}


def _logLogisticARF(area, b, c, e):
    """Evaluate y = c + (1 - c) / (1 + exp(b * (ln(area) - ln(e)))) (unrounded).

    An area of 0 or less returns 1.0, the limit of the curve as area -> 0, instead of
    failing in math.log (a basin area rounded to 2 decimals can be 0.0).
    """
    if area <= 0:
        return 1.0
    return c + ((1 - c) / (1 + math.exp(b * (math.log(area) - math.log(e)))))


def getARF(basinArea, basinZone, dur):
    """Return the areal reduction factor for a basin, rounded to 3 decimals.

    Parameters:
        basinArea: basin area in sqmi.
        basinZone: transposition zone (1-4).
        dur: duration in hours as a string ('6' or '24').

    Returns 1.0 (with a message) when the zone/duration combination has no
    parameters; previously an unknown zone raised UnboundLocalError.
    """
    arcpy.AddMessage("\n\tCalculating basin ARF...")

    params = _ARF_PARAMS.get(basinZone, {}).get(dur)
    if params is None:
        arcpy.AddMessage("\nInvalid duration or basin location.  Returning ARF = 1.0")
        return 1.0
    b, c, e = params

    y = round(_logLogisticARF(basinArea, b, c, e), 3)
    arcpy.AddMessage("\tlog-logistic parameters:  ")
    arcpy.AddMessage("\t\tb = " + str(b))
    arcpy.AddMessage("\t\tc = " + str(c))
    arcpy.AddMessage("\t\te = " + str(e))

    arcpy.AddMessage("\n\tBasin ARF calculation for " + str(basinArea) + " sqmi and Transposition Zone " + str(basinZone) + " = " + str(y))
    return y

def getYearFromFilename(name):
    """Return the return period (years) encoded in an AEP raster name.

    The name must contain ``pf_<years>yr``; the text between ``pf_`` and the
    first following ``yr`` is converted with int().

    Raises:
        ValueError: if *name* has no ``pf_...yr`` part or the years part is not an integer.
    """
    # Non-greedy so a later "yr" in the name does not get swallowed into the years group.
    freq = re.search(r'pf_(.*?)yr', name)  ## ensure this fits naming convention
    if freq is None:
        raise ValueError(f"Cannot determine return period from raster name: {name}")
    return int(freq.group(1))

def sortByYear(filename):
    """Return a new list of the AEP raster names in *filename*, ordered by return period.

    The return period of each name comes from getYearFromFilename(); the input is not modified.
    """
    ordered = list(filename)
    ordered.sort(key=getYearFromFilename)
    return ordered

def getFieldAve(table, field):
    """Return the numeric mean of every value in *field* of *table*.

    The field is read into a NumPy structured array with TableToNumPyArray and
    averaged with np.mean (used here for basin-average PMP depths).
    """
    values = arcpy.da.TableToNumPyArray(table, [field])[field]
    return np.mean(values)

def basinZone(basin):  ## This function returns the transposition zone of the basin centroid and centroid coordinates
    """Return ``(zone, x, y)`` for the centroid of the dissolved *basin*.

    The basin is dissolved and, if its spatial reference is not geographic,
    projected to WKID 4326.  Its centroid is written to a point feature class and
    spatially joined to ``<home>/Input/Non_Storm_Data.gdb/Vector_Grid`` to read
    the ``ZONE`` attribute.  x/y are the centroid coordinates in the geographic
    spatial reference.  The caller treats a falsy zone as "outside the project
    domain" (presumably ZONE is null when the centroid misses Vector_Grid).

    All intermediate datasets are written to the scratch geodatabase and deleted
    before returning (and before starting, so leftovers from a previous run do not
    block the tools).

    Raises:
        ValueError: if the dissolved basin has no features.
    """
    scratch = env.scratchGDB
    disBasin = os.path.join(scratch, "disBasin")
    tempBasin = os.path.join(scratch, "tempBasin")
    tempCentroid = os.path.join(scratch, "tempCentroid")
    joinOutput = os.path.join(scratch, "joinOut")
    joinFeat = os.path.join(home, "Input", "Non_Storm_Data.gdb", "Vector_Grid")
    scratchData = (disBasin, tempBasin, tempCentroid, joinOutput)

    def _deleteScratch():
        # Remove this function's intermediate datasets if they exist.
        for item in scratchData:
            if arcpy.Exists(item):
                dm.Delete(item)

    _deleteScratch()
    geographic_sr = arcpy.SpatialReference(4326)
    try:
        dm.Dissolve(basin, disBasin)
        if arcpy.Describe(disBasin).spatialReference.type != "Geographic":
            dm.Project(disBasin, tempBasin, geographic_sr)  # Apply geographic SR
            centroidSource = tempBasin
        else:
            centroidSource = disBasin

        # Temporary point feature class holding the basin centroid
        dm.CreateFeatureclass(scratch, "tempCentroid", "POINT", spatial_reference=geographic_sr)
        cent = None
        with arcpy.da.InsertCursor(tempCentroid, "SHAPE@XY") as iCur:
            with arcpy.da.SearchCursor(centroidSource, "SHAPE@") as sCur:
                for sRow in sCur:
                    cent = sRow[0].centroid             # get the centroid
                    iCur.insertRow([(cent.X, cent.Y)])  # write it to the new feature class
        if cent is None:
            raise ValueError("Basin input contains no features; cannot determine the basin centroid.")

        an.SpatialJoin(tempCentroid, joinFeat, joinOutput)
        # Close the cursor promptly so it does not hold a lock on joinOutput
        with arcpy.da.SearchCursor(joinOutput, ("ZONE",)) as zCur:
            centZone = next(zCur, (None,))[0]
    finally:
        _deleteScratch()

    return (centZone, cent.X, cent.Y)

def getRasterMean(aepGRID):
    """Return the ARF-adjusted basin precipitation for one AEP grid, rounded to 2 decimals.

    Reads module-level names set by the main script: ``basinArea``, ``basinCentX``,
    ``basinCentY``, ``basin`` and the current duration's ``ARF``.

    Basins smaller than 0.1 sqmi use the grid value at the basin centroid.  Larger
    basins use the mean of the ARF-scaled cells inside the basin, falling back to
    the centroid value when the extraction has no valid cells.
    """
    def _centroidValue():
        # GetCellValue returns the value as text, hence float()
        centroidVal = arcpy.GetCellValue_management(aepGRID, f"{basinCentX} {basinCentY}").getOutput(0)
        return round(float(centroidVal) * ARF, 2)

    if basinArea < 0.1:     # Basin below size threshold: use point extraction
        return _centroidValue()

    outRas = ExtractByMask(aepGRID, basin, 'INSIDE')
    # If the extracted grid has no valid cells, use the centroid value instead.
    # NOTE(review): maximum is assumed to be None or equal to noDataValue in that case -- confirm.
    if outRas.maximum is None or outRas.maximum == outRas.noDataValue:
        return _centroidValue()
    arfRas = outRas * ARF           ## Apply areal reduction factor to extracted precip grid
    return round(arfRas.mean, 2)

##############################################################################################
##############################################################################################

# Describe basin and get location
desBasin = arcpy.Describe(basin)
basinName = desBasin.baseName
basinArea = getAOIarea(basin)
# NOTE: this rebinds the name basinZone from the helper function to the zone value it returns.
basinZone, basinCentX, basinCentY = basinZone(basin)
arcpy.AddMessage(f"{basinName} basin location:\n\tTransposition Zone: {basinZone} \n\tCentroid: {basinCentX:.5f} {basinCentY:.5f}")
if not basinZone:
    # Report on the warning channel, then stop.  sys.exit() is the programmatic
    # form of the interactive-only site builtin exit().
    import sys
    arcpy.AddWarning("Warning! AEP not available basin location.  Check to ensure basin centroid is within project domain.  Exiting.")
    sys.exit()

# AEP frequency grids for every duration live in one workspace; ListRasters()
# and getRasterMean() resolve grid names relative to it.
pfWS = home + r"\Input\AEP.gdb"  ## Set workspace containing AEP frequency grids
env.workspace = pfWS

# Colors for the horizontal PMP lines, reused cyclically if more PMP inputs are given.
pmpColors = ['r', 'g', 'y']

# Get basin AEP for each duration
for dur in aepDurations:
    arcpy.AddMessage("\n\nEvaluating basin precipitation frequency for the " + dur + "-hour duration...")

    # Get basin areal reduction factor for basin location and duration.
    # ARF is a module-level name that getRasterMean() reads.
    ARF = getARF(basinArea, basinZone, dur)

    # AEP grids whose names contain this duration, sorted by return period
    aepGRIDs = sortByYear(arcpy.ListRasters('*{}*'.format(dur)))

    ##############################################
    # Calculate basin average precip for each frequency
    ##############################################
    yrList = []     # return periods (years)
    aepList = []    # annual exceedance probabilities (1 / years)
    precList = []   # ARF-adjusted basin average precipitation
    lwrList = []    # lower (5%) bound precipitation
    uprList = []    # upper (95%) bound precipitation

    arcpy.AddMessage("\n\tCalculating " + dur + "-hour basin average precipitation for each AEP...")
    for gridIdx, aepGRID in enumerate(aepGRIDs):
        yr = getYearFromFilename(aepGRID)
        rasMean = getRasterMean(aepGRID)    # Get basin AEP from raster
        arcpy.AddMessage("\n\t" + str(yr) + "-year basin average preciptation: " + str(rasMean) + '"')
        yrList.append(yr)
        aepList.append(1.0 / yr)
        precList.append(rasMean)
        # err5/err95 hold one multiplier per position in the sorted grid list
        lwr = round(err5[gridIdx] * rasMean, 2)
        arcpy.AddMessage("\tLower Bounds: " + str(lwr) + 'in')
        lwrList.append(lwr)
        upr = round(err95[gridIdx] * rasMean, 2)
        arcpy.AddMessage("\tUpper Bounds: " + str(upr) + 'in')
        uprList.append(upr)

    ##############################################
    # Calculate AEP of PMP
    ##############################################
    pmpList = []
    ariPMPList = []
    aepPMPList = []
    stormTypeList = []
    durName = "06" if dur == "6" else dur   ## PMP fields use "06" for the 6-hour duration
    pmpField = "PMP_" + durName
    try:
        for pmpPoints in (pmpPointsList or []):
            pmpName = arcpy.Describe(pmpPoints).baseName
            basinPMP = round(getFieldAve(pmpPoints, pmpField), 1)    # Basin average PMP depth
            ## Get storm type from filename
            if "Local" in pmpName:
                stormType = "Local"
            elif "General" in pmpName:
                stormType = "General"
            elif "Tropical" in pmpName:
                stormType = "Tropical"
            else:
                arcpy.AddMessage("\n***'PMP_Points' Storm Type could not be determined from input: " + pmpName + "***")
                stormType = ""
            arcpy.AddMessage("\n\t" + dur + "-hour " + stormType + " storm PMP: " + str(basinPMP))
            ### Calculate AEP of PMP with interp
            ariPMP = round(np.interp(basinPMP, precList, yrList, left=0.5, right=10000000000.5), 0)
            aepPMP = "{:.2e}".format(1 / ariPMP)
            arcpy.AddMessage("\tAEP of " + dur + "-hour " + stormType + " storm PMP: " + str(aepPMP))
            # Append only once every value for this input is known, so the lists stay aligned
            pmpList.append(basinPMP)
            stormTypeList.append(stormType)
            ariPMPList.append("{0:,.0f}".format(ariPMP))
            aepPMPList.append(aepPMP)
    except Exception as pmpErr:
        arcpy.AddMessage("\n***Warning! There is a problem with the PMP Points input.  AEP of PMP will not be calculated.***")
        arcpy.AddMessage(str(pmpErr))
        # Drop partial PMP results so the plot and table stay consistent with the warning
        pmpList, ariPMPList, aepPMPList, stormTypeList = [], [], [], []

    ##############################################
    # Create table of AEP values for dur
    ##############################################
    tableName = basinName + "_" + dur + "hr_AEP"
    aepTable = dm.CreateTable(r"in_memory", tableName)

    dm.AddField(aepTable, field_name="AEP", field_type="DOUBLE", field_precision="10", field_alias="Annual Exceedance Probability")
    dm.AddField(aepTable, field_name="PPT_50pct", field_type="DOUBLE", field_precision="10", field_alias=r"Precip (in) 50% conf.")
    dm.AddField(aepTable, field_name="PPT_5pct", field_type="DOUBLE", field_precision="10", field_alias=r"Precip (in) 5% conf.")
    dm.AddField(aepTable, field_name="PPT_95pct", field_type="DOUBLE", field_precision="10", field_alias=r"Precip (in) 95% conf.")
    fieldNames = ["AEP", "PPT_50pct", "PPT_5pct", "PPT_95pct"]

    ##############################################
    # Populate AEP tables and Export to Excel
    ##############################################
    with arcpy.da.InsertCursor(aepTable, fieldNames) as cursor:
        for row in zip(aepList, precList, lwrList, uprList):
            cursor.insertRow(row)

    con.TableToExcel(aepTable, outLocation + r"/" + tableName + ".xlsx", Use_field_alias_as_column_header="ALIAS")
    dm.Delete(aepTable)     # the Excel file is the deliverable; free the in_memory table

    ##############################################
    # Create AEP chart plot figures
    ##############################################
    fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True)

    # Plot precipitation vs aep on logarithmic scale
    prcpLine, = ax.loglog(aepList, precList, '-', color='b', linewidth=2, label=dur + 'hr Precipitation (in)')
    confLine1, = ax.loglog(aepList, lwrList, linestyle=':', color='b', label='95% Confidence Interval Bounds')
    confLine2, = ax.loglog(aepList, uprList, linestyle=':', color='b', label='95% Confidence Interval (in)')
    ax.invert_xaxis()

    # Add a horizontal PMP line for each storm type
    pmpLines = []
    for idx, (pmp, stormType) in enumerate(zip(pmpList, stormTypeList)):
        pmpLines.append(ax.axhline(y=pmp, linestyle='--', color=pmpColors[idx % len(pmpColors)], label=dur + "hr " + stormType + ' Storm PMP'))

    # Set labels and title
    plt.xlabel('Annual Exceedance Probability')
    plt.ylabel(dur + 'hr Precipitation (in)')
    titleName = basinName + " (" + str(basinArea) + "-sqmi) Basin Average " + dur + "-hour Annual Exceedance Probability"
    plt.title(titleName)

    # Add legends; a second ax.legend() call replaces the first, so keep it as an artist
    legend1 = ax.legend(handles=[prcpLine, confLine1], loc='upper left')
    if pmpLines:
        ax.add_artist(legend1)
        ax.legend(handles=pmpLines, loc='lower left')

    # Format Axes
    x_major = mpl.ticker.LogLocator(base=10.0)
    ax.xaxis.set_major_locator(x_major)
    ax.xaxis.set_major_formatter(mpl.ticker.LogFormatterSciNotation(base=10, minor_thresholds=(10, 10)))
    x_minor = mpl.ticker.LogLocator(base=10.0, subs=np.arange(1.0, 10.0) * 0.1, numticks=10)
    ax.xaxis.set_minor_locator(x_minor)
    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
    ax.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(0.1, 1e-10)

    # Add AEP of PMP table to plot
    if getAEPofPMP:
        if pmpList:
            col_labels = ['PMP (in)', 'AEP']
            row_labels = [stormType + " Storm" for stormType in stormTypeList]
            table_vals = np.array([pmpList, aepPMPList]).transpose()      # one row per storm type
            ax.table(cellText=table_vals,
                     colWidths=[0.1] * len(col_labels),
                     rowLabels=row_labels,
                     colLabels=col_labels,
                     cellLoc='center',
                     loc="lower right",
                     zorder=3)
        elif not pmpPointsList:
            arcpy.AddMessage("\n***Warning! You must provide a PMP point feature class to get AEP of PMP.  AEP of PMP will not be calculated.***")
        else:
            arcpy.AddMessage("\n***Warning! There is a problem with the PMP Points input.  AEP of PMP will not be calculated.***")

    # Save the plot as a figure
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(outLocation + r"/" + str(basinName) + "_" + dur + "hr" + "_AEP_Plot.png")  # Save image
    plt.close(fig)         # Close chart to remove from memory
    arcpy.AddMessage("\n" + dur + "-hour AEP Chart exported to output folder.")

dm.Delete("memory")