#include "flatten.h"

namespace BandRemoval {

Filter::Filter(PClip child, int nLowThreshold, int nHighThreshold, int nRadius, IScriptEnvironment *env) :
GenericVideoFilter(child), nLowThreshold(nLowThreshold), nHighThreshold(nHighThreshold), nRadius(nRadius)
{
   nWidth = vi.width;
   nHeight = vi.height;
}

Filter::~Filter()
{
}

static void ProcessPlane(Byte *pDst, int nDstPitch, const Byte *pSrc, int nSrcPitch,
                         int nWidth, int nHeight, int nLowThreshold, int nHighThreshold, int nRadius)
{
   const int nSquare = (2 * nRadius + 1) * (2 * nRadius + 1);
   for ( int y = 0; y < nRadius; y++, pSrc += nSrcPitch, pDst += nDstPitch )
      for ( int x = 0; x < nWidth; x++ )
         pDst[x] = pSrc[x];

   for ( int y = nRadius; y < nHeight - nRadius; y++, pSrc += nSrcPitch, pDst += nDstPitch )
   {
      for ( int x = 0; x < nRadius; x++ )
         pDst[x] = pSrc[x];
      for ( int x = nRadius; x < nWidth - nRadius; x++ )
      {
         int sum = 0;
         const Byte *p = &pSrc[x];
         for ( int j = -nRadius; j < nRadius + 1; j++ )
            for ( int i = -nRadius; i < nRadius + 1; i++ )
               sum += p[i+j*nSrcPitch];

         if ( sum >= pDst[x] * nSquare - nLowThreshold && sum <= pDst[x] * nSquare + nHighThreshold )
            pDst[x] = (sum * 2 + nSquare) / (nSquare * 2);
         else
            pDst[x] = pSrc[x];
      }
      for ( int x = nWidth - nRadius; x < nWidth; x++ )
         pDst[x] = pSrc[x];
   }

   for ( int y = 0; y < nRadius; y++, pSrc += nSrcPitch, pDst += nDstPitch )
      for ( int x = 0; x < nWidth; x++ )
         pDst[x] = pSrc[x];
}

PVideoFrame __stdcall Filter::GetFrame(int n, IScriptEnvironment* env)
{
   PVideoFrame dst = env->NewVideoFrame(vi);
   PVideoFrame src = child->GetFrame(n, env);

   ProcessPlane(dst->GetWritePtr(PLANAR_Y), dst->GetPitch(PLANAR_Y), src->GetReadPtr(PLANAR_Y),
      src->GetPitch(PLANAR_Y), nWidth, nHeight, nLowThreshold, nHighThreshold, nRadius);
   ProcessPlane(dst->GetWritePtr(PLANAR_U), dst->GetPitch(PLANAR_U), src->GetReadPtr(PLANAR_U),
      src->GetPitch(PLANAR_U), nWidth >> 1, nHeight >> 1, nLowThreshold, nHighThreshold, nRadius);
   ProcessPlane(dst->GetWritePtr(PLANAR_V), dst->GetPitch(PLANAR_V), src->GetReadPtr(PLANAR_V),
      src->GetPitch(PLANAR_V), nWidth >> 1, nHeight >> 1, nLowThreshold, nHighThreshold, nRadius);

   return dst;
}

}

AVSValue __cdecl CreateBandRemoval(AVSValue args, void *user_data, IScriptEnvironment* env)
{
   // Script-level factory. Arguments are read positionally as
   // (clip, low threshold, high threshold, radius); AsInt(1) yields 1 for any
   // integer argument that is undefined. user_data is not used.
   PClip source = args[0].AsClip();
   const int lowThreshold = args[1].AsInt(1);
   const int highThreshold = args[2].AsInt(1);
   const int radius = args[3].AsInt(1);

   BandRemoval::Filter *filter = new BandRemoval::Filter(source, lowThreshold, highThreshold, radius, env);
   return filter;
}

extern "C" __declspec(dllexport) const char* __stdcall AvisynthPluginInit2(IScriptEnvironment* env)
{
   // Registers the "BandRemoval" script function. The integer arguments are
   // optional and named, so scripts may omit them. CreateBandRemoval reads
   // each with AsInt(1), so a missing value becomes 1. Positional calls with
   // all four arguments behave as before.
   env->AddFunction("BandRemoval", "c[lowThreshold]i[highThreshold]i[radius]i", CreateBandRemoval, NULL);

   return("BandRemoval");
}
