#include "IntegrateDataWriter.h"

#include <AMReX.H>
#include <AMReX_MultiFabUtil.H>
#include <AMReX_ParReduce.H>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <limits>
#include <string>



using namespace amrex;

namespace simflowny {

   // Selection flags for the optional reductions.
   // WriteMultiLevelIntegralPlot assigns each flag from its `calculations`
   // argument ("INTEGRAL", "L2NORM", "MIN", "MAX", "ABSMIN", "ABSMAX"), and
   // calculateReductions reads them to decide which results it accumulates
   // into the IntegralData it is given.
   // NOTE(review): these are mutable namespace-scope variables with external
   // linkage, so the two functions communicate through shared state; confirm
   // whether other translation units rely on them before changing linkage.
   bool calc_int;
   bool calc_l2n;
   bool calc_min;
   bool calc_max;
   bool calc_amin;
   bool calc_amax;

   // Scalar slots named after the six reductions.
   // NOTE(review): none of these is read or written in the visible part of
   // this file; presumably left over from an earlier implementation or used
   // from elsewhere -- confirm before removing.
   double gpu_integral;
   double gpu_l2norm;
   double gpu_min;
   double gpu_max;
   double gpu_absmin;
   double gpu_absmax;

   void
   WriteMultiLevelIntegralPlot (const std::string& plotdirname, 
                            int nlevels,
                            const Vector<const MultiFab*>& mf,
                            const Vector<std::string>& varnames,
                            const Vector<Geometry>& geom, 
                            Real time,
                            const Vector<IntVect>& ref_ratios, 
                            const Vector<std::string> &calculations)
   {
      BL_PROFILE("WriteMultiLevelIntegralPlot()");

      BL_ASSERT(nlevels <= mf.size());
      BL_ASSERT(nlevels <= geom.size());
      BL_ASSERT(mf[0]->nComp() == varnames.size());

      int finest_level = nlevels-1;
      int ncomp = mf[0]->nComp();

      IntegralData result_data;
      result_data.integral.resize(ncomp);
      result_data.l2norm.resize(ncomp);
      result_data.min.resize(ncomp);
      result_data.max.resize(ncomp);
      result_data.absmin.resize(ncomp);
      result_data.absmax.resize(ncomp);


      calc_int = std::find(calculations.begin(), calculations.end(), "INTEGRAL") != calculations.end();
      calc_l2n = std::find(calculations.begin(), calculations.end(), "L2NORM") != calculations.end();
      calc_min = std::find(calculations.begin(), calculations.end(), "MIN") != calculations.end();
      calc_max = std::find(calculations.begin(), calculations.end(), "MAX") != calculations.end();
      calc_amin = std::find(calculations.begin(), calculations.end(), "ABSMIN") != calculations.end();
      calc_amax = std::find(calculations.begin(), calculations.end(), "ABSMAX") != calculations.end();
      
      for (int n = 0; n < ncomp; n++) {
         result_data.integral[n] = 0.;
         result_data.l2norm[n] = 0.;
         result_data.min[n] = std::numeric_limits<amrex::Real>::max();
         result_data.max[n] = std::numeric_limits<amrex::Real>::min();
         result_data.absmin[n] = std::numeric_limits<amrex::Real>::max();
         result_data.absmax[n] = 0.;
      }


      calculateReductions(mf, ref_ratios, result_data, geom, finest_level);
      BL_PROFILE_VAR("WriteMultiLevelIntegralPlot(): MPI reduce", writeintegral1);
      if(std::find(calculations.begin(), calculations.end(), "INTEGRAL") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealSum(result_data.integral.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      if(std::find(calculations.begin(), calculations.end(), "L2NORM") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealSum(result_data.l2norm.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      if(std::find(calculations.begin(), calculations.end(), "MAX") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealMax(result_data.max.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      if(std::find(calculations.begin(), calculations.end(), "MIN") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealMin(result_data.min.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      if(std::find(calculations.begin(), calculations.end(), "ABSMAX") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealMax(result_data.absmax.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      if(std::find(calculations.begin(), calculations.end(), "ABSMIN") != calculations.end()) {
         //Calculate local integral
         ParallelDescriptor::ReduceRealMin(result_data.absmin.data(), ncomp, ParallelContext::IOProcessorNumberSub());
      }
      BL_PROFILE_VAR_STOP(writeintegral1);
      BL_PROFILE_VAR("WriteMultiLevelIntegralPlot(): I/O", writeintegral2);


      for (int n = 0; n < ncomp; n++) {
         std::string varname = plotdirname + "/" + varnames[n];
         if(std::find(calculations.begin(), calculations.end(), "INTEGRAL") != calculations.end()) {
            //Calculate local integral
            //Real integral_value = result_data[n]->integral;
            //ParallelDescriptor::ReduceRealSum(integral_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_INTEGRAL";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" << result_data.integral[n] << std::endl;
               outputfile.close();
            }
         }
         if(std::find(calculations.begin(), calculations.end(), "L2NORM") != calculations.end()) {
            //Calculate local integral
            //Real integral_value = result_data[n]->l2norm;
            //ParallelDescriptor::ReduceRealSum(integral_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_L2NORM";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" <<result_data.l2norm[n]<< std::endl;
               outputfile.close();
            }
         }
         if(std::find(calculations.begin(), calculations.end(), "MAX") != calculations.end()) {
            //Calculate local integral
            //Real max_value = result_data[n]->max;
            //ParallelDescriptor::ReduceRealMax(max_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_MAX";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" << result_data.max[n] << std::endl;
               outputfile.close();
            }
         }
         if(std::find(calculations.begin(), calculations.end(), "MIN") != calculations.end()) {
            //Calculate local integral
            //Real min_value = result_data[n]->min;
            //ParallelDescriptor::ReduceRealMin(min_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_MIN";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" << result_data.min[n] << std::endl;
               outputfile.close();
            }
         }
         if(std::find(calculations.begin(), calculations.end(), "ABSMAX") != calculations.end()) {
            //Calculate local integral
            //Real absmax_value = result_data[n]->absmax;
            //ParallelDescriptor::ReduceRealMax(absmax_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_ABSMAX";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" << result_data.absmax[n] << std::endl;
               outputfile.close();
            }
         }
         if(std::find(calculations.begin(), calculations.end(), "ABSMIN") != calculations.end()) {
            //Calculate local integral
            //Real absmin_value = result_data[n]->absmin;
            //ParallelDescriptor::ReduceRealMin(absmin_value);
            // Write variable files
            if (ParallelContext::IOProcessorSub()) {
               std::ofstream outputfile;
               char name[1024];
               std::string varname_int = varname + "_ABSMIN";
               strcpy(name, varname_int.c_str());
               // Create file if does not exist, otherwise open at the end of the file
               if (time == 0) {
                  outputfile.open (name, std::ios::out);
               }
               else {
                  outputfile.open (name, std::ios::app);
               }
               outputfile << time << "\t" << result_data.absmin[n] << std::endl;
               outputfile.close();
            }
         }
      }
      BL_PROFILE_VAR_STOP(writeintegral2);
   }

   /*
    *************************************************************************
    *
    * Private function to perform the calculations
    *
    *************************************************************************
    */
    void

   calculateReductions(const Vector<const MultiFab*> &mf_nd_v,
                       const Vector<IntVect> &ref_ratios,
                       IntegralData& result_data,
                       const Vector<Geometry> geom,
                       int finest_level)
   {
      BL_PROFILE("calculateReductions()");

      // first create a vector of cell-centered multifabs with no components
      // we need it for the ParReduce loop

      amrex::Vector<amrex::MultiFab> mf_cc_v(finest_level + 1);

      for (int ilev = 0; ilev <= finest_level; ilev++) {
         const auto& mf_nd = *mf_nd_v[ilev];
         auto &mf_cc = mf_cc_v[ilev];
         int ncomp = 0;
         int ngrow = 0;

         mf_cc.define(
            amrex::convert(mf_nd.boxArray(), IntVect::TheCellVector()),
            mf_nd.DistributionMap(), ncomp, ngrow);
      }

      int ncomp = mf_nd_v[0]->nComp();
      amrex::Vector<IntegralDataTuple> level_integral_data(ncomp);
      amrex::Vector<MinMaxDataTuple> level_minmax_data(ncomp);

      // we need masks for all except the finest level so do those first

      for (int ilev = 0; ilev < finest_level; ++ilev) {
         int coarse_val = 0;
         int fine_val = 1;
         amrex::iMultiFab mask_cc = amrex::makeFineMask(
               mf_cc_v[ilev], mf_cc_v[ilev+1], amrex::IntVect::TheZeroVector(),
               ref_ratios[ilev], amrex::Periodicity::NonPeriodic(), coarse_val,
               fine_val);

         amrex::iMultiFab mask_nd = amrex::makeFineMask(
            *mf_nd_v[ilev], *mf_nd_v[ilev+1], amrex::IntVect::TheZeroVector(),
            ref_ratios[ilev], amrex::Periodicity::NonPeriodic(), coarse_val,
            fine_val);

         const auto& mask_cc_arrs = mask_cc.const_arrays();
         const auto& mask_nd_arrs = mask_nd.const_arrays();
         const auto& nd_arrs = mf_nd_v[ilev]->const_arrays();
         const auto dx = geom[ilev].CellSizeArray();
         Real dv = AMREX_D_TERM(dx[0], *dx[1], *dx[2]);

         for (int icomp = 0; icomp < ncomp; ++icomp) {
           // first do the integration over the cell-centered values
            level_integral_data[icomp] = amrex::ParReduce(
               amrex::TypeList<amrex::ReduceOpSum, amrex::ReduceOpSum>{},
               amrex::TypeList<amrex::Real, amrex::Real>{},
               mf_cc_v[ilev], amrex::IntVect::TheZeroVector(),
               [=] AMREX_GPU_DEVICE(int box_no, int i, int j, int k)
                  noexcept -> IntegralDataTuple
               {
                  if (mask_cc_arrs[box_no](i, j, k) == 0) {
                     const auto& nd_arr = nd_arrs[box_no];
                     amrex::Real cc_val
				#if (AMREX_SPACEDIM == 2)
                        = 0.25 * ( nd_arr(i,j  ,0,icomp) + nd_arr(i+1,j  ,0,icomp)
                                 + nd_arr(i,j+1,0,icomp) + nd_arr(i+1,j+1,0,icomp));
				#elif (AMREX_SPACEDIM == 3)
                        = 0.125 * ( nd_arr(i,j  ,k  ,icomp) + nd_arr(i+1,j  ,k  ,icomp)
                                  + nd_arr(i,j+1,k  ,icomp) + nd_arr(i+1,j+1,k  ,icomp)
                                  + nd_arr(i,j  ,k+1,icomp) + nd_arr(i+1,j  ,k+1,icomp)
                                  + nd_arr(i,j+1,k+1,icomp) + nd_arr(i+1,j+1,k+1,icomp));
					#endif
                     return { cc_val*dv, cc_val*cc_val*dv };
                  }
                  else {
                     return { 0.0, 0.0 };
                  }
                  });
				// now do the min/max over the node-centered values

            	level_minmax_data[icomp] = amrex::ParReduce(
               amrex::TypeList<amrex::ReduceOpMin, amrex::ReduceOpMax, amrex::ReduceOpMin, amrex::ReduceOpMax>{},
               amrex::TypeList<amrex::Real, amrex::Real, amrex::Real, amrex::Real>{},
               *mf_nd_v[ilev], amrex::IntVect::TheZeroVector(),
               [=] AMREX_GPU_DEVICE(int box_no, int i, int j, int k)
                  noexcept -> MinMaxDataTuple
               {
                  if (mask_nd_arrs[box_no](i, j, k) == 0) {
                     amrex::Real nd_val = nd_arrs[box_no](i, j, k, icomp);
                     amrex::Real abs_nd_val = std::abs(nd_val);
                     return { nd_val, nd_val, abs_nd_val, abs_nd_val };
                  }
                  else {
                     return {std::numeric_limits<amrex::Real>::max(),
                             std::numeric_limits<amrex::Real>::min(),
                             std::numeric_limits<amrex::Real>::max(), 0.0 };
                  }
			 });
         }
         amrex::Gpu::streamSynchronize();

        for (int icomp = 0; icomp < ncomp; ++icomp) {
            if(calc_int) {
               result_data.integral[icomp] = result_data.integral[icomp] + amrex::get<0>(level_integral_data[icomp]);
            }
            if(calc_l2n) {
               result_data.l2norm[icomp] = result_data.l2norm[icomp] + amrex::get<1>(level_integral_data[icomp]);
            }
            if(calc_min) {
               result_data.min[icomp] = std::min(result_data.min[icomp], amrex::get<0>(level_minmax_data[icomp]));
            }
            if(calc_max) {
               result_data.max[icomp] = std::max(result_data.max[icomp], amrex::get<1>(level_minmax_data[icomp]));
            }
            if(calc_amin) {
               result_data.absmin[icomp] = std::min(result_data.absmin[icomp], amrex::get<2>(level_minmax_data[icomp]));
            }
            if(calc_amax) {
               result_data.absmax[icomp] = std::max(result_data.absmax[icomp], amrex::get<3>(level_minmax_data[icomp]));
            }
         }
	  }
      // now do finest level with no mask
      const auto& nd_arrs = mf_nd_v[finest_level]->const_arrays();
      const auto dx = geom[finest_level].CellSizeArray();
      Real dv = AMREX_D_TERM(dx[0], *dx[1], *dx[2]);

      for (int icomp = 0; icomp < ncomp; ++icomp) {
         // first do the integration over the cell-centered values
         level_integral_data[icomp] = amrex::ParReduce(
            amrex::TypeList<amrex::ReduceOpSum, amrex::ReduceOpSum>{},
            amrex::TypeList<amrex::Real, amrex::Real>{},
            mf_cc_v[finest_level], amrex::IntVect::TheZeroVector(),
            [=] AMREX_GPU_DEVICE(int box_no, int i, int j, int k)
               noexcept -> IntegralDataTuple
            {
               const auto& nd_arr = nd_arrs[box_no];
               amrex::Real cc_val
#if (AMREX_SPACEDIM == 2)
                  = 0.25 * ( nd_arr(i,j  ,0,icomp) + nd_arr(i+1,j  ,0,icomp)
                           + nd_arr(i,j+1,0,icomp) + nd_arr(i+1,j+1,0,icomp));
#elif (AMREX_SPACEDIM == 3)
                  = 0.125 * ( nd_arr(i,j  ,k  ,icomp) + nd_arr(i+1,j  ,k  ,icomp)
                            + nd_arr(i,j+1,k  ,icomp) + nd_arr(i+1,j+1,k  ,icomp)
                            + nd_arr(i,j  ,k+1,icomp) + nd_arr(i+1,j  ,k+1,icomp)
                            + nd_arr(i,j+1,k+1,icomp) + nd_arr(i+1,j+1,k+1,icomp));
#endif
               return { cc_val*dv, cc_val*cc_val*dv };
            });

        // now do the min/maxes over the node-centered values
         level_minmax_data[icomp] = amrex::ParReduce(
            amrex::TypeList<amrex::ReduceOpMin, amrex::ReduceOpMax, amrex::ReduceOpMin, amrex::ReduceOpMax>{},
            amrex::TypeList<amrex::Real, amrex::Real, amrex::Real, amrex::Real>{},
            *mf_nd_v[finest_level], amrex::IntVect::TheZeroVector(),
            [=] AMREX_GPU_DEVICE(int box_no, int i, int j, int k)
               noexcept -> MinMaxDataTuple
            {
               amrex::Real nd_val = nd_arrs[box_no](i, j, k, icomp);
               amrex::Real abs_nd_val = std::abs(nd_val);
               return { nd_val, nd_val, abs_nd_val, abs_nd_val };
            });
      }

      amrex::Gpu::streamSynchronize();

      for (int icomp = 0; icomp < ncomp; ++icomp) {
         if(calc_int) {
            result_data.integral[icomp] = result_data.integral[icomp] + amrex::get<0>(level_integral_data[icomp]);
         }
         if(calc_l2n) {
            result_data.l2norm[icomp] = result_data.l2norm[icomp] + amrex::get<1>(level_integral_data[icomp]);
         }
         if(calc_min) {
            result_data.min[icomp] = std::min(result_data.min[icomp], amrex::get<0>(level_minmax_data[icomp]));
         }
         if(calc_max) {
            result_data.max[icomp] = std::max(result_data.max[icomp], amrex::get<1>(level_minmax_data[icomp]));
         }
         if(calc_amin) {
            result_data.absmin[icomp] = std::min(result_data.absmin[icomp], amrex::get<2>(level_minmax_data[icomp]));
         }
         if(calc_amax) {
            result_data.absmax[icomp] = std::max(result_data.absmax[icomp], amrex::get<3>(level_minmax_data[icomp]));
         }
      }
   }
}

