#include "PointDataWriter.h"

#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>

using namespace amrex;

namespace simflowny {

// Highest AMR level on which this rank located the requested point during the
// most recent WriteMultiLevelPointPlot call; it is reset to -1 before each
// search, so -1 means no locally owned box on any level contained the point.
// It is written from inside the device lambda in calculatePoint and read on the
// host (AllGather), hence AMREX_GPU_MANAGED (presumably host/device-accessible
// managed memory on GPU builds — confirm against the AMReX configuration used).
AMREX_GPU_MANAGED int localMaxLevel;

   void
   WriteMultiLevelPointPlot(const std::string &plotdirname,
                int nlevels,
                const Vector<const MultiFab*> &mf,
                const Vector<std::string> &varnames,
                const Vector<Geometry> &geom,
                Real time,
                const Vector<Real> &coordinates)
   {
      BL_PROFILE("WriteMultiLevelPointPlot()");

      BL_ASSERT(nlevels <= mf.size());
      BL_ASSERT(nlevels <= geom.size());
      BL_ASSERT(mf[0]->nComp() == varnames.size());
      #if (AMREX_SPACEDIM == 2)
         BL_ASSERT(coordinates.size() >= 2);
      #endif
      #if (AMREX_SPACEDIM == 3)
         BL_ASSERT(coordinates.size() == 3);
      #endif

      int finest_level = nlevels-1;

      Gpu::ManagedVector<Real> result_data;
      result_data.resize(mf[0]->nComp(), 0);
      localMaxLevel = -1;
      //Calculate the local point (if possible) in the finest level
      calculatePoint(mf, coordinates, &result_data, &localMaxLevel, geom, finest_level);

      //Communicate maximum level in all processes
      int myProc = ParallelDescriptor::MyProc();
      int nProcs = ParallelDescriptor::NProcs();
      Vector<int> maxLevels(nProcs, 0);
      ParallelAllGather::AllGather(localMaxLevel, maxLevels.data(), ParallelContext::CommunicatorAll());

      //Get the process number that have the maximum level point
      int maxLevel = -1;
      int procIndex = -1;
      for (int i = 0; i < nProcs; i++) {
         if (maxLevel < maxLevels[i]) {
            maxLevel = maxLevels[i];
            procIndex = i;
         }
      }

      //Only the process with the higher level writes the plot
      if (myProc == procIndex) {
         for (int n = 0; n < mf[0]->nComp(); n++) {
            std::string varname = plotdirname + "/" + varnames[n];
            std::ofstream outputfile;
            char name[1024];
            std::string varname_int = varname + "_POINT";
            strcpy(name, varname_int.c_str());
            // Create file if does not exist, otherwise open at the end of the file
            if (time == 0) {
               outputfile.open (name, std::ios::out);
            }
            else {
               outputfile.open (name, std::ios::app);
            }
            outputfile << time << "\t" << result_data[n]<< std::endl;
            outputfile.close();
         }
      }
   }

   /*
    *************************************************************************
    *
    * Private helper: locates the requested point on each AMR level and
    * extracts its component values
    *
    *************************************************************************
    */
   void 
   calculatePoint(const Vector<const MultiFab*> &mf_v, 
                  const Vector<Real> coordinates, 
                  Gpu::ManagedVector<Real>* result_data, 
                  int* maxLevel, 
                  const Vector<Geometry> geom,
                  int finest_level)
   {

      for (int lev = 0; lev <= finest_level; lev++) {
         const MultiFab* data_mf = mf_v[lev];

         const auto dx     = geom[lev].CellSizeArray();

         const auto problo = geom[lev].ProbLoArray();

         //Calculate point index
         int i = round((coordinates[0] - problo[0])/dx[0]);
         int j = round((coordinates[1] - problo[1])/dx[1]);
         int k = 0;
      #if (AMREX_SPACEDIM == 3)
         k = round((coordinates[2] - problo[2])/dx[2]);
      #endif

   #ifdef AMREX_USE_OMP
   #pragma omp parallel if (Gpu::notInLaunchRegion())
   #endif
         {
           for (MFIter mfi(*data_mf, TilingIfNotGPU()); mfi.isValid(); ++mfi)
           {
               // Set up tileboxes and nodal tileboxes
               const Box& bx = mfi.tilebox();

               const auto loVec = bx.loVect();
               const auto hiVec = bx.hiVect();

               //Check if index is inside the current box
               if (
                  loVec[0] <= i && hiVec[0] >= i
                  && loVec[1] <= j && hiVec[1] >= j
               #if (AMREX_SPACEDIM == 3)
                  && loVec[2] <= k && hiVec[2] >= k
               #endif
                  ) {

                  // Grab fab pointers from state multifabs
                  Array4<const Real> data  = data_mf->array(mfi);

               Real * result_data_ptr = (*result_data).dataPtr();

                  amrex::ParallelFor(data_mf->nComp(),
                  [=] AMREX_GPU_DEVICE (int n) noexcept
                  {
                     result_data_ptr[n] = data(i, j, k, n);
                     *maxLevel = lev;
                  });
               }
            }
         }
      }


   }

}
