module nonlinear_terms
!
! Known limitations:
!   - not correct for slowing-down species;
!     terms like df/dU should involve anon, etc.
!   - open question: are factors of 1/B missing in the A_perp terms?
!
  implicit none

  ! Public entry points and flags; everything else in the module is private.
  public :: init_nonlinear_terms
  public :: add_nonlinear_terms, finish_nl_terms
  public :: finish_init, reset_init, algorithm, nonlin

  private

  ! knobs
  ! Integer codes obtained in read_parameters from the character options
  ! nonlinear_mode and flow_mode of the nonlinear_terms_knobs namelist.
  ! flow_mode_switch is currently always forced back to flow_mode_off there.
  integer :: nonlinear_mode_switch
  integer :: flow_mode_switch

  integer, parameter :: nonlinear_mode_none = 1, nonlinear_mode_on = 2
  integer, parameter :: flow_mode_off = 1, flow_mode_on = 2

  !complex, dimension(:,:), allocatable :: phi_avg, apar_avg, aperp_avg

  ! Work arrays used by add_nl when accelerated is .false.:
  ! ba and gb receive the two transform2 results, bracket holds their
  ! combined products and is handed to inverse2.
  real, dimension (:,:), allocatable :: ba, gb, bracket
  ! yxf_lo%ny, yxf_lo%llim_proc:yxf_lo%ulim_alloc

  ! Counterparts of ba, gb, bracket used by add_nl when accelerated is .true.
  real, dimension (:,:,:), allocatable :: aba, agb, abracket
  ! 2*ntgrid+1, 2, accelx_lo%llim_proc:accelx_lo%ulim_alloc

  !complex, dimension (:,:), allocatable :: xax, xbx, g_xf
  ! xxf_lo%nx, xxf_lo%llim_proc:xxf_lo%ulim_alloc

! CFL coefficients
! cfl is a namelist input (default 0.1); cflx and cfly are derived from it
! in init_nonlinear_terms and scale the velocity maxima in add_nl.
  real :: cfl, cflx, cfly

! hyperviscosity coefficients
! Read from the namelist and broadcast; not otherwise referenced in the
! procedures visible in this module.
  real :: C_par, C_perp, p_x, p_y, p_z

  ! algorithm is (re)set to 1 and nonlin to .true. by init_nonlinear_terms
  ! when the nonlinear mode is on.
  integer :: algorithm = 1
  logical :: nonlin = .false.
  ! initialized: init_nonlinear_terms has already run (cleared by reset_init).
  logical :: initialized = .false.
  ! initializing: while .true., add_nl returns at once with dt_cfl = 1.e4;
  ! cleared by finish_init, set again by reset_init.
  logical :: initializing = .true.
  ! alloc: .true. until init_nonlinear_terms has allocated the work arrays.
  logical :: alloc = .true.
  ! zip: namelist switch; when .true., add_nl zeroes g and g1 wherever ik /= 1.
  logical :: zip = .false.
  ! accelerated: passed to init_transforms (presumably set there -- confirm);
  ! selects the aba/agb/abracket arrays instead of ba/gb/bracket.
  logical :: accelerated = .false.
  
contains
  
  subroutine init_nonlinear_terms
    ! One-time set-up of the nonlinear term.
    !
    ! Initialises the modules this one depends on (layouts, theta grid,
    ! kt grids, le grids, species, distribution-function layouts), reads
    ! the nonlinear_terms_knobs namelist and, unless the nonlinear mode is
    ! "none", initialises the transforms, allocates and zeroes the work
    ! arrays and computes the CFL scale factors cflx and cfly.
    !
    ! Repeated calls return immediately until reset_init clears the
    ! "initialized" flag.
    use theta_grid, only: init_theta_grid, ntgrid
    use kt_grids, only: init_kt_grids, naky, ntheta0, nx, ny, akx, aky
    use le_grids, only: init_le_grids, nlambda, negrid
    use species, only: init_species, nspec
    use run_parameters, only: tnorm
    use gs2_layouts, only: init_dist_fn_layouts, yxf_lo, accelx_lo
    use gs2_layouts, only: init_gs2_layouts
    use gs2_transforms, only: init_transforms
    implicit none
    logical :: dum1, dum2   ! outputs of init_le_grids that are not needed here

    if (initialized) return
    initialized = .true.

    call init_gs2_layouts
    call init_theta_grid
    call init_kt_grids
    call init_le_grids (dum1, dum2)
    call init_species
    call init_dist_fn_layouts (ntgrid, naky, ntheta0, nlambda, negrid, nspec)

    call read_parameters

    if (nonlinear_mode_switch == nonlinear_mode_on)  then
       algorithm = 1
       nonlin = .true.
    end if

    if (nonlinear_mode_switch /= nonlinear_mode_none) then
       call init_transforms (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny, accelerated)

       ! Each set of work arrays is tested with allocated() on its own.
       ! The module-level "alloc" flag alone cannot tell which of the two
       ! sets exists, so a change of "accelerated" between two passes
       ! through this routine would otherwise touch unallocated arrays.
       if (accelerated) then
          if (.not. allocated(aba)) &
               allocate (     aba(2*ntgrid+1, 2, accelx_lo%llim_proc:accelx_lo%ulim_alloc))
          if (.not. allocated(agb)) &
               allocate (     agb(2*ntgrid+1, 2, accelx_lo%llim_proc:accelx_lo%ulim_alloc))
          if (.not. allocated(abracket)) &
               allocate (abracket(2*ntgrid+1, 2, accelx_lo%llim_proc:accelx_lo%ulim_alloc))
          alloc = .false.
          aba = 0. ; agb = 0. ; abracket = 0.
       else
          if (.not. allocated(ba)) &
               allocate (     ba(yxf_lo%ny,yxf_lo%llim_proc:yxf_lo%ulim_alloc))
          if (.not. allocated(gb)) &
               allocate (     gb(yxf_lo%ny,yxf_lo%llim_proc:yxf_lo%ulim_alloc))
          if (.not. allocated(bracket)) &
               allocate (bracket(yxf_lo%ny,yxf_lo%llim_proc:yxf_lo%ulim_alloc))
          alloc = .false.
          ba = 0. ; gb = 0. ; bracket = 0.
       end if

       ! CFL scale factors used by add_nl: wavenumber taken from the grids,
       ! divided by the user-supplied cfl number and by tnorm.
       ! NOTE(review): cfl = 0 in the input would divide by zero here.
       cfly = aky(naky)/cfl*0.5/tnorm
       cflx = akx((ntheta0+1)/2)/cfl*0.5/tnorm
    end if

  end subroutine init_nonlinear_terms

  subroutine read_parameters
    ! Read the nonlinear_terms_knobs namelist on proc0 and broadcast the
    ! resulting settings.  The namelist is read only on the first call;
    ! later calls return immediately.
    use file_utils, only: input_unit, input_unit_exist, error_unit
    use text_options, only: text_option, get_option_value
    use mp, only: proc0, broadcast
    implicit none

    ! Accepted spellings of the two character options and the integer
    ! codes they map to.
    type (text_option), dimension (4), parameter :: nonlinearopts = &
         (/ text_option('default', nonlinear_mode_none), &
            text_option('none', nonlinear_mode_none), &
            text_option('off', nonlinear_mode_none), &
            text_option('on', nonlinear_mode_on) /)
    type (text_option), dimension (3), parameter :: flowopts = &
         (/ text_option('default', flow_mode_off), &
            text_option('off', flow_mode_off), &
            text_option('on', flow_mode_on) /)

    character(20) :: nonlinear_mode
    character(20) :: flow_mode
    namelist /nonlinear_terms_knobs/ nonlinear_mode, flow_mode, cfl, &
         C_par, C_perp, p_x, p_y, p_z, zip

    integer :: err_unit, nml_unit
    logical :: found
    logical :: already_read = .false.   ! saved between calls

    if (already_read) return
    already_read = .true.

    if (proc0) then
       ! Defaults, used when the namelist is absent or leaves a value out.
       cfl = 0.1
       C_par = 0.1
       C_perp = 0.1
       p_x = 6.0
       p_y = 6.0
       p_z = 6.0
       nonlinear_mode = 'default'
       flow_mode = 'default'

       nml_unit = input_unit_exist("nonlinear_terms_knobs", found)
       if (found) then
          read (unit=nml_unit, nml=nonlinear_terms_knobs)
       end if

       ! Translate the character options into their integer switches.
       err_unit = error_unit()
       call get_option_value &
            (nonlinear_mode, nonlinearopts, nonlinear_mode_switch, &
            err_unit, "nonlinear_mode in nonlinear_terms_knobs")
       call get_option_value &
            (flow_mode, flowopts, flow_mode_switch, &
            err_unit, "flow_mode in nonlinear_terms_knobs")
    end if

    ! Distribute the settings from proc0 to every processor.
    call broadcast (nonlinear_mode_switch)
    call broadcast (flow_mode_switch)
    call broadcast (cfl)
    call broadcast (C_par)
    call broadcast (C_perp)
    call broadcast (p_x)
    call broadcast (p_y)
    call broadcast (p_z)
    call broadcast (zip)

    ! The flow option is overridden: it is always switched off.
    if (flow_mode_switch == flow_mode_on) then
       if (proc0) write(*,*) 'Forcing flow_mode = off'
       flow_mode_switch = flow_mode_off
    end if

  end subroutine read_parameters

  subroutine add_nonlinear_terms (g0, g1, g2, phi, apar, aperp, istep, dt_cfl, bd, fexp)
    ! Public driver for the nonlinear term.
    !
    ! With the nonlinear mode off, only dt_cfl is set (to 1.e4) and the
    ! distribution-function arrays are left alone.  With it on, the work
    ! is delegated to add_nl, except at istep == 0 where nothing is done
    ! and dt_cfl keeps the value it had on entry.
    !
    ! Arguments:
    !   g0, g1, g2       distribution-function sized arrays, passed to add_nl
    !   phi, apar, aperp field arrays, passed to add_nl
    !   istep            time-step counter
    !   dt_cfl           CFL time-step estimate (updated, see above)
    !   bd, fexp         passed through to add_nl
    use theta_grid, only: ntgrid
    use gs2_layouts, only: g_lo
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g0, g1, g2
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi,    apar,    aperp
    integer, intent (in) :: istep
    ! in out rather than out: not assigned on every path (istep == 0 with
    ! the nonlinear mode on).
    real, intent (in out) :: dt_cfl
    real, intent (in) :: bd
    complex, intent (in) :: fexp

    select case (nonlinear_mode_switch)
    case (nonlinear_mode_none)
       dt_cfl = 1.e4
       ! nothing
    case (nonlinear_mode_on)
       if (istep /= 0) call add_nl (g0, g1, g2, phi, apar, aperp, istep, dt_cfl, bd, fexp)
    end select
  end subroutine add_nonlinear_terms

  subroutine add_nl (g0, g1, g2, phi, apar, aperp, istep, dt_cfl, bd, fexp)
    ! Evaluate the nonlinear term into g1 and update the CFL time-step
    ! estimate dt_cfl.
    !
    ! While the module flag "initializing" is set, only dt_cfl = 1.e4 is
    ! returned.  Otherwise, the first time a given istep is seen:
    !   1. the previous g1 is copied into g2;
    !   2. two products of transformed quantities are formed,
    !        (kx-loaded fields) * (ky-loaded fields*zt + i*ky*g)
    !      - (ky-loaded fields) * (kx-loaded fields*zt + i*kx*g),
    !      each scaled by kxfac, and accumulated in bracket / abracket;
    !   3. dt_cfl = 1/max_vel, with max_vel the maximum over processors of
    !      |first factor| * cfly and |first factor| * cflx;
    !   4. the result is transformed back into g1, averaged between
    !      neighbouring theta points with weights (1 +/- bd), and zeroed
    !      where forbid is set.
    ! On every call (same istep or not) the optional "zip" filter is
    ! applied and istep_last is updated.
    !
    ! g0 and fexp are accepted for interface compatibility; they are not
    ! referenced in this routine.
    use mp, only: max_allreduce
    use theta_grid, only: ntgrid, kxfac
    use gs2_layouts, only: g_lo, ik_idx, it_idx, il_idx, is_idx
    use gs2_layouts, only: accelx_lo, yxf_lo
    use dist_fn_arrays, only: g, ittp
    use species, only: spec
    use gs2_transforms, only: transform2, inverse2
    use run_parameters, only: delt, fapar, faperp, fphi
    use kt_grids, only: aky, akx, ntheta0
    use le_grids, only: forbid
    use gs2_time, only: save_dt
    use constants, only: zi
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g0, g1, g2
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi, apar, aperp
    integer, intent (in) :: istep
    ! in out rather than out: left untouched when istep == istep_last.
    real, intent (in out) :: dt_cfl
    real, intent (in) :: bd
    complex, intent (in) :: fexp
    ! Saved between calls (initialised in its declaration): the step for
    ! which the nonlinear term was last evaluated.
    integer :: istep_last = 0
    integer :: i, j, k
    real :: max_vel, zero

    ! Index variables; they are also used, by host association, as loop
    ! indices inside the internal load_* routines below.
    integer :: iglo, ik, it, is, ig, il, ia, isgn

    if (initializing) then
       dt_cfl = 1.e4
       return
    endif

    if (istep /= istep_last) then

       zero = epsilon(0.0)
       ! keep the previous nonlinear term
       g2 = g1

       ! ---- first product: kx-loaded fields -> ba / aba ----
       if (fphi > zero) then
          call load_kx_phi
       else
          g1 = 0.
       end if

       if (faperp > zero) call load_kx_aperp
       if (fapar  > zero) call load_kx_apar

       if (accelerated) then
          call transform2 (g1, aba, ia)
       else
          call transform2 (g1, ba)
       end if

       ! ky-loaded fields (phi and aperp only) combined with i*ky*g -> gb / agb
       if (fphi > zero) then
          call load_ky_phi
       else
          g1 = 0.
       end if
       if (faperp > zero) call load_ky_aperp

! more generally, there should probably be a factor of anon...

       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             do ig = -ntgrid, ntgrid
                g1(ig,isgn,iglo) = g1(ig,isgn,iglo)*spec(is)%zt + zi*aky(ik)*g(ig,isgn,iglo)
             end do
          end do
       end do

       if (accelerated) then
          call transform2 (g1, agb, ia)
       else
          call transform2 (g1, gb)
       end if

       ! Explicit loops up to ulim_proc are kept on purpose: the arrays are
       ! allocated up to ulim_alloc, so whole-array expressions could cover
       ! a different range.
       if (accelerated) then
          max_vel = 0.
          do k = accelx_lo%llim_proc, accelx_lo%ulim_proc
             do j = 1, 2
                do i = 1, 2*ntgrid+1
                   abracket(i,j,k) = aba(i,j,k)*agb(i,j,k)*kxfac
                   max_vel = max(max_vel, abs(aba(i,j,k)))
                end do
             end do
          end do
          max_vel = max_vel * cfly
!          max_vel = maxval(abs(aba)*cfly)
       else
          max_vel = 0.
          do j = yxf_lo%llim_proc, yxf_lo%ulim_proc
             do i = 1, yxf_lo%ny
                bracket(i,j) = ba(i,j)*gb(i,j)*kxfac
                max_vel = max(max_vel,abs(ba(i,j)))
             end do
          end do
          max_vel = max_vel*cfly
!          max_vel = maxval(abs(ba)*cfly)
       endif

       ! ---- second product: ky-loaded fields -> ba / aba ----
       if (fphi > zero) then
          call load_ky_phi
       else
          g1 = 0.
       end if

       if (faperp > zero) call load_ky_aperp
       if (fapar  > zero) call load_ky_apar

       if (accelerated) then
          call transform2 (g1, aba, ia)
       else
          call transform2 (g1, ba)
       end if

       ! kx-loaded fields (phi and aperp only) combined with i*kx*g -> gb / agb
       if (fphi > zero) then
          call load_kx_phi
       else
          g1 = 0.
       end if

       if (faperp > zero) call load_kx_aperp

! more generally, there should probably be a factor of anon...

       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          it = it_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             do ig = -ntgrid, ntgrid
                g1(ig,isgn,iglo) = g1(ig,isgn,iglo)*spec(is)%zt + zi*akx(it)*g(ig,isgn,iglo)
             end do
          end do
       end do

       if (accelerated) then
          call transform2 (g1, agb, ia)
       else
          call transform2 (g1, gb)
       end if

       ! subtract the second product and fold its factor into max_vel
       if (accelerated) then
          do k = accelx_lo%llim_proc, accelx_lo%ulim_proc
             do j = 1, 2
                do i = 1, 2*ntgrid+1
                   abracket(i,j,k) = abracket(i,j,k) - aba(i,j,k)*agb(i,j,k)*kxfac
                   max_vel = max(max_vel, abs(aba(i,j,k))*cflx)
                end do
             end do
          end do
       else
          do j = yxf_lo%llim_proc, yxf_lo%ulim_proc
             do i = 1, yxf_lo%ny
                bracket(i,j) = bracket(i,j) - ba(i,j)*gb(i,j)*kxfac
                max_vel = max(max_vel,abs(ba(i,j))*cflx)
             end do
          end do
       end if

       call max_allreduce(max_vel)

       ! max_vel is exactly zero when every transformed field value is
       ! zero.  Clamp the denominator so that dt_cfl becomes a very large
       ! finite number instead of a division by zero; any normal non-zero
       ! max_vel is unaffected.
       dt_cfl = 1./max(max_vel, tiny(max_vel))
       call save_dt (delt, dt_cfl)

       if (accelerated) then
          call inverse2 (abracket, g1, ia)
       else
          call inverse2 (bracket, g1)
       end if

! factor of one-half appears elsewhere
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          il = il_idx(g_lo, iglo)
! Totally trapped particles get no bakdif
          do ig = -ntgrid, ntgrid-1
             if (il == ittp(ig)) cycle
             ! weighted sum of neighbouring theta points; the weights are
             ! swapped between the two isgn values
             g1(ig,1,iglo) = (1.+bd)*g1(ig+1,1,iglo) + (1.-bd)*g1(ig,1,iglo)
             g1(ig,2,iglo) = (1.-bd)*g1(ig+1,2,iglo) + (1.+bd)*g1(ig,2,iglo)
          end do
! zero out spurious g1 outside trapped boundary
          where (forbid(:,il))
             g1(:,1,iglo) = 0.0
             g1(:,2,iglo) = 0.0
          end where
       end do

    endif

    ! Optional filter: discard g and the nonlinear term for every ik /= 1.
    if (zip) then
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          it = it_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
! alternative selections that have been tried:
!          if (it == 3 .or. it == ntheta0-1) then
!          if (it /= 1) then
          if (ik /= 1) then
!          if (ik == 2 .and. it == 1) then
             g (:,1,iglo) = 0.
             g (:,2,iglo) = 0.
             g1(:,1,iglo) = 0.
             g1(:,2,iglo) = 0.
          end if
       end do
    end if

    istep_last = istep

  contains

    ! Overwrite g1 (both isgn) with i*kx*aj0*phi*fphi.
    subroutine load_kx_phi

      use dist_fn_arrays, only: aj0
      complex :: fac

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            fac = zi*akx(it)*aj0(ig,iglo)*phi(ig,it,ik)*fphi
            g1(ig,1,iglo) = fac
            g1(ig,2,iglo) = fac
         end do
      end do

    end subroutine load_kx_phi

    ! Overwrite g1 (both isgn) with i*ky*aj0*phi*fphi.
    subroutine load_ky_phi

      use dist_fn_arrays, only: aj0
      complex :: fac

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            fac = zi*aky(ik)*aj0(ig,iglo)*phi(ig,it,ik)*fphi
            g1(ig,1,iglo) = fac
            g1(ig,2,iglo) = fac
         end do
      end do

    end subroutine load_ky_phi

! should I use vpa or vpac in next two routines??

    ! Subtract i*kx*aj0*stm*vpa*apar*fapar from g1, separately for each
    ! isgn because vpa depends on isgn.
    subroutine load_kx_apar

      use dist_fn_arrays, only: vpa, aj0
      use gs2_layouts, only: is_idx

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         is = is_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            g1(ig,1,iglo) = g1(ig,1,iglo) - zi*akx(it)*aj0(ig,iglo)*spec(is)%stm &
                 *vpa(ig,1,iglo)*apar(ig,it,ik)*fapar
         end do
         do ig = -ntgrid, ntgrid
            g1(ig,2,iglo) = g1(ig,2,iglo) - zi*akx(it)*aj0(ig,iglo)*spec(is)%stm &
                 *vpa(ig,2,iglo)*apar(ig,it,ik)*fapar
         end do
      end do

    end subroutine load_kx_apar

    ! Subtract i*ky*aj0*stm*vpa*apar*fapar from g1, separately for each
    ! isgn because vpa depends on isgn.
    subroutine load_ky_apar

      use dist_fn_arrays, only: vpa, aj0
      use gs2_layouts, only: is_idx

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         is = is_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            g1(ig,1,iglo) = g1(ig,1,iglo) - zi*aky(ik)*aj0(ig,iglo)*spec(is)%stm &
                 *vpa(ig,1,iglo)*apar(ig,it,ik)*fapar
         end do
         do ig = -ntgrid, ntgrid
            g1(ig,2,iglo) = g1(ig,2,iglo) - zi*aky(ik)*aj0(ig,iglo)*spec(is)%stm &
                 *vpa(ig,2,iglo)*apar(ig,it,ik)*fapar
         end do
      end do

    end subroutine load_ky_apar

    ! Add i*kx*aj1*2*vperp2*tz*aperp*faperp to g1.  The sum is built from
    ! the isgn = 1 entry and stored in both isgn entries.
    subroutine load_kx_aperp

      use dist_fn_arrays, only: vperp2, aj1
      use gs2_layouts, only: is_idx
      complex :: fac

! Is this factor of two from the old normalization?

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         is = is_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            fac = g1(ig,1,iglo) + zi*akx(it)*aj1(ig,iglo) &
                 *2.0*vperp2(ig,iglo)*spec(is)%tz*aperp(ig,it,ik)*faperp
            g1(ig,1,iglo) = fac
            g1(ig,2,iglo) = fac
         end do
      end do

    end subroutine load_kx_aperp

    ! Add i*ky*aj1*2*vperp2*tz*aperp*faperp to g1.  The sum is built from
    ! the isgn = 1 entry and stored in both isgn entries.
    subroutine load_ky_aperp

      use dist_fn_arrays, only: vperp2, aj1
      use gs2_layouts, only: is_idx
      complex :: fac

! Is this factor of two from the old normalization?

      do iglo = g_lo%llim_proc, g_lo%ulim_proc
         it = it_idx(g_lo,iglo)
         ik = ik_idx(g_lo,iglo)
         is = is_idx(g_lo,iglo)
         do ig = -ntgrid, ntgrid
            fac = g1(ig,1,iglo) + zi*aky(ik)*aj1(ig,iglo) &
                 *2.0*vperp2(ig,iglo)*spec(is)%tz*aperp(ig,it,ik)*faperp
            g1(ig,1,iglo) = fac
            g1(ig,2,iglo) = fac
         end do
      end do

    end subroutine load_ky_aperp

  end subroutine add_nl

  subroutine finish_nl_terms
    ! Tear-down hook for the nonlinear term.  At present it releases
    ! nothing: the work arrays allocated by init_nonlinear_terms are kept,
    ! and the deallocation code below is disabled.

    select case (nonlinear_mode_switch)
    case (nonlinear_mode_none)
       ! init_nonlinear_terms allocates nothing in this mode
       return
    case default
       ! disabled clean-up, kept for reference:
!    deallocate (ba)
!    deallocate (gb)
!    deallocate (bracket)
!    alloc = .true.
    end select

  end subroutine finish_nl_terms

  subroutine reset_init
    ! Put the module back into its pre-initialisation state: add_nl goes
    ! back to returning early, and the next call to init_nonlinear_terms
    ! runs in full again.  Then invoke the tear-down hook.

    initializing = .true.
    initialized = .false.

    call finish_nl_terms

  end subroutine reset_init

  subroutine finish_init
    ! Signal the end of the start-up phase: once "initializing" is cleared,
    ! add_nl evaluates the nonlinear term instead of returning early with
    ! dt_cfl = 1.e4.

    initializing = .false.

  end subroutine finish_init

end module nonlinear_terms


