module transforms

  ! Distributed Fourier transforms of the three-component field array.
  !
  ! Forward path: the local field array a(:,:,:,1:3) is gathered into the
  ! x-layout (plan p2x) and transformed along x with the real-to-complex
  ! routine scfft.  The result is gathered into the y-layout (plan x2y) and
  ! transformed along y with ccfft.  The 3-D path also gathers into the
  ! z-layout (plan y2z) and transforms along z with ccfft.  The inverse
  ! paths run the same chain backwards using scatter, ccfft and the
  ! complex-to-real routine csfft.
  !
  ! scfft, csfft and ccfft are external routines (presumably the Cray
  ! scientific library FFTs -- NOTE(review): confirm).  They are called with
  ! isign = 0 during setup, +1 for forward and -1 for inverse transforms,
  ! always with scale factor 1.  NOTE(review): a forward/inverse pair is
  ! therefore presumably not normalized here; confirm where callers rescale.

  use redistribute, only: redist_type
  use par, only: nx0_tot, ny0_tot, nz0_tot
  use par, only: nx_fft, ny_fft, nz_fft
  implicit none

  public :: init_transforms
  public :: transform_2d, inverse_2d
  public :: transform_3d, inverse_3d

  ! Redistribution plans: block layout -> x-layout, x -> y, y -> z.
  type (redist_type) :: p2x, x2y, y2z

  ! Module staging buffers, one transform per column.  Each column is one
  ! entry longer than the transform length (nx0_tot, ny0_tot, nz0_tot)
  ! passed to the FFT routines.
  real,    dimension (nx0_tot+1, nx_fft) :: x_data
  complex, dimension (nx0_tot+1, nx_fft) :: kx_data
  complex, dimension (ny0_tot+1, ny_fft) :: y_data
  complex, dimension (nz0_tot+1, nz_fft) :: z_data

  ! Scratch (work_*) and table (table_*) sizes for the FFT routines,
  ! selected per build target by the preprocessor branch below.
# ifdef T3E

  integer, parameter :: nwork_x = 2 * nx0_tot
  integer, parameter :: nwork_y = 4 * ny0_tot
  integer, parameter :: nwork_z = 4 * nz0_tot

  integer, parameter :: ntable_x = 2 * nx0_tot
  integer, parameter :: ntable_y = 2 * ny0_tot
  integer, parameter :: ntable_z = 2 * nz0_tot

# else

  integer, parameter :: nwork_x = 4 + 4*nx0_tot
  integer, parameter :: nwork_y = 8 * ny0_tot
  integer, parameter :: nwork_z = 8 * nz0_tot

  integer, parameter :: ntable_x = 100 + 4*nx0_tot
  integer, parameter :: ntable_y = 100 + 8*ny0_tot
  integer, parameter :: ntable_z = 100 + 8*nz0_tot

# endif

  real, dimension (nwork_x) :: work_x
  real, dimension (nwork_y) :: work_y

  real, dimension (ntable_x) :: table_x
  real, dimension (ntable_y) :: table_y

#  if NZ != 1
  real, dimension (nwork_z) :: work_z
  real, dimension (ntable_z) :: table_z
#  endif

  private

contains

  ! Set up the FFT tables and redistribution plans for every direction.
  ! Only the first call does any work; later calls return immediately.
  subroutine init_transforms

    logical, save :: initialized = .false.

    if (initialized) return
    initialized = .true.

    call init_x_fft
    call init_y_fft

#  if NZ != 1
    call init_z_fft
#  endif

  end subroutine init_transforms

  ! Build plan p2x and the x-direction table (table_x) of length nx0_tot.
  subroutine init_x_fft

    call init_x_redist
    ! Setup call (isign = 0); table_x is then reused by x_transform and
    ! inverse_x.
    call scfft(0, nx0_tot, 0., 0., 0., table_x, work_x, 0)

  end subroutine init_x_fft

  ! Build plan x2y and the y-direction table (table_y) of length ny0_tot.
  subroutine init_y_fft

    call init_y_redist
    ! Setup call (isign = 0); table_y is then reused by y_transform and
    ! inverse_y.
    call ccfft(0, ny0_tot, 0., 0., 0., table_y, work_y, 0)

  end subroutine init_y_fft

  ! Build plan y2z and the z-direction table (table_z) of length nz0_tot.
  ! The body is empty when z support is compiled out.
  subroutine init_z_fft

#  if NZ != 1
    call init_z_redist
    ! Setup call (isign = 0); table_z is then reused by z_transform and
    ! inverse_z.
    call ccfft(0, nz0_tot, 0., 0., 0., table_z, work_z, 0)
#  endif

  end subroutine init_z_fft

  ! Build plan p2x: local field array a(i, j, k, iv) -> x-layout columns.
  !
  ! The routine visits every block (iblock, jblock, kblock) of layout lo,
  ! every component iv = 1..3 and every interior point i = 3..nx0+2,
  ! j = 3..ny0+2, k = k3..nzm2.  For each, pidx2xidx gives the x-layout
  ! row ix and the x-layout index jx.
  !  * If this processor owns the block in lo, (i, j, k, iv) is appended to
  !    from_list(ip), where ip owns jx in x_lo.
  !  * If this processor owns jx in x_lo, (ix, idx(x_lo, jx)) is appended to
  !    to_list(ip), where ip owns the block in lo.
  ! A counting pass sizes the lists.  A filling pass with the identical loop
  ! nest then records entries in the same order (presumably so that the
  ! sending and receiving lists line up inside init_redist).
  ! Only the first call does any work.
  subroutine init_x_redist

    use par, only: nx0, ny0, k3, nzm2
    use layouts, only: init_layouts, lo, x_lo, idx_local, proc_id, idx
    use layouts, only: pidx2xidx
    use mp, only: npe
    use redistribute, only: index_list_type, init_redist, delete_list

    type (index_list_type), dimension (0:npe-1) :: to_list, from_list
    integer, dimension (0:npe-1) :: nn_to, nn_from
    logical, save :: initialized = .false.
    integer :: i, j, k, iblock, jblock, kblock, ix, jx, iv, ip, n
    integer, dimension (4) :: from_low
    integer, dimension (2) :: to_high
    integer, dimension (1) :: from_high
    integer :: to_low

    if (initialized) return
    initialized = .true.

    call init_layouts

    ! Pass 1: count the elements exchanged with each processor.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do k = k3, nzm2
                   do j = 3, ny0+2
                      do i = 3, nx0+2
                         call pidx2xidx (i, j, k, iv, iblock, jblock, kblock, ix, jx)
                         if (idx_local (lo, iblock, jblock, kblock)) then
                            ip = proc_id (x_lo, jx)
                            nn_from(ip) = nn_from(ip) + 1
                         end if
                         if (idx_local (x_lo, jx)) then
                            ip = proc_id (lo, iblock, jblock, kblock)
                            nn_to(ip) = nn_to(ip) + 1
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Allocate the index lists.  from_list entries hold the four indices of
    ! a(i, j, k, iv); to_list entries hold (row, local column).
    do ip = 0, npe-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip) % first (nn_from(ip)))
          allocate (from_list(ip) % second(nn_from(ip)))
          allocate (from_list(ip) % third (nn_from(ip)))
          allocate (from_list(ip) % fourth(nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip) % first (nn_to(ip)))
          allocate (to_list(ip) % second(nn_to(ip)))
       end if
    end do

    ! Pass 2: same loop nest, now recording the local indices.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do k = k3, nzm2
                   do j = 3, ny0+2
                      do i = 3, nx0+2
                         call pidx2xidx (i, j, k, iv, iblock, jblock, kblock, ix, jx)
                         if (idx_local (lo, iblock, jblock, kblock)) then
                            ip = proc_id (x_lo, jx)
                            n = nn_from(ip) + 1
                            nn_from(ip) = n
                            from_list(ip) % first(n)  = i
                            from_list(ip) % second(n) = j
                            from_list(ip) % third(n)  = k
                            from_list(ip) % fourth(n) = iv
                         end if
                         if (idx_local (x_lo, jx)) then
                            ip = proc_id (lo, iblock, jblock, kblock)
                            n = nn_to(ip) + 1
                            nn_to(ip) = n
                            to_list(ip) % first(n)  = ix
                            to_list(ip) % second(n) = idx (x_lo, jx)
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Lower index bounds of the source and destination arrays.
    from_low = 1
    to_low = 1

    ! Upper bounds of the x-layout destination.  NOTE(review): the original
    ! authors doubted that init_redist reads to_high; it is set for
    ! completeness.
    to_high(1) = x_lo % nx0_tot + 1
    to_high(2) = x_lo % nx_fft

    from_high = 0  ! not used in this case

    ! 'r' presumably selects real data (a and x_data are real); the other
    ! plans pass 'c' for complex data.
    call init_redist (p2x, 'r', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    call delete_list (to_list)
    call delete_list (from_list)

  end subroutine init_x_redist

  ! Build plan x2y: x-layout (kx_data after the x FFT) -> y-layout columns.
  !
  ! For every interior point of every block and component, xidx2yidx gives
  ! the x-layout position (ix, jx) and the y-layout position (iy, jy).
  !  * If this processor owns jx in x_lo, (ix, idx(x_lo, jx)) is appended to
  !    from_list(ip), where ip owns jy in y_lo.
  !  * If this processor owns jy in y_lo, (iy, idx(y_lo, jy)) is appended to
  !    to_list(ip), where ip owns jx in x_lo.
  ! The lists are counted, allocated, then filled with the identical loop
  ! nest.  Only the first call does any work.
  subroutine init_y_redist

    use par, only: nx0, ny0, k3, nzm2
    use layouts, only: init_layouts, x_lo, y_lo, lo, idx_local, idx, proc_id
    use layouts, only: xidx2yidx
    use mp, only: npe
    use redistribute, only: index_list_type, init_redist, delete_list

    type (index_list_type), dimension (0:npe-1) :: to_list, from_list
    integer, dimension (0:npe-1) :: nn_to, nn_from
    logical, save :: initialized = .false.
    integer :: i, j, k, iblock, jblock, kblock, ix, jx, iy, jy, iv, ip, n
    integer, dimension (2) :: from_low
    integer, dimension (2) :: to_high
    integer, dimension (1) :: from_high
    integer :: to_low

    if (initialized) return
    initialized = .true.

    call init_layouts

    ! Pass 1: count the elements exchanged with each processor.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do k = k3, nzm2
                   do j = 3, ny0+2
                      do i = 3, nx0+2
                         call xidx2yidx (i, j, k, iv, iblock, jblock, kblock, ix, jx, iy, jy)
                         if (idx_local (x_lo, jx)) then
                            ip = proc_id (y_lo, jy)
                            nn_from(ip) = nn_from(ip) + 1
                         end if
                         if (idx_local (y_lo, jy)) then
                            ip = proc_id (x_lo, jx)
                            nn_to(ip) = nn_to(ip) + 1
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Allocate the index lists; every entry is (row, local column).
    do ip = 0, npe-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip) % first (nn_from(ip)))
          allocate (from_list(ip) % second(nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip) % first (nn_to(ip)))
          allocate (to_list(ip) % second(nn_to(ip)))
       end if
    end do

    ! Pass 2: same loop nest, now recording the local indices.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do k = k3, nzm2
                   do j = 3, ny0+2
                      do i = 3, nx0+2
                         call xidx2yidx (i, j, k, iv, iblock, jblock, kblock, ix, jx, iy, jy)
                         if (idx_local (x_lo, jx)) then
                            ip = proc_id (y_lo, jy)
                            n = nn_from(ip) + 1
                            nn_from(ip) = n
                            from_list(ip) % first(n)  = ix
                            from_list(ip) % second(n) = idx (x_lo, jx)
                         end if
                         if (idx_local (y_lo, jy)) then
                            ip = proc_id (x_lo, jx)
                            n = nn_to(ip) + 1
                            nn_to(ip) = n
                            to_list(ip) % first(n)  = iy
                            to_list(ip) % second(n) = idx (y_lo, jy)
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Lower index bounds of the source and destination arrays.
    from_low = 1
    to_low = 1

    ! Upper bounds of the y-layout destination.  NOTE(review): the original
    ! authors doubted that init_redist reads to_high; it is set for
    ! completeness.
    to_high(1) = y_lo % ny0_tot + 1
    to_high(2) = y_lo % ny_fft

    from_high = 0  ! not used in this case

    call init_redist (x2y, 'c', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    call delete_list (to_list)
    call delete_list (from_list)

  end subroutine init_y_redist

  ! Build plan y2z: y-layout columns -> z-layout columns.
  !
  ! For every interior point of every block and component, yidx2zidx gives
  ! the y-layout position (iy, jy) and the z-layout position (iz, jz).
  !  * If this processor owns jy in y_lo, (iy, idx(y_lo, jy)) is appended to
  !    from_list(ip), where ip owns jz in z_lo.
  !  * If this processor owns jz in z_lo, (iz, idx(z_lo, jz)) is appended to
  !    to_list(ip), where ip owns jy in y_lo.
  ! Unlike the other plans, the point loops run i outermost (the original
  ! authors marked this as an optimization attempt).  The counting and
  ! filling passes must keep the same order.  Only the first call does any
  ! work.
  subroutine init_z_redist

    use par, only: nx0, ny0, k3, nzm2
    use layouts, only: init_layouts, y_lo, z_lo, lo, idx_local, proc_id, idx
    use layouts, only: yidx2zidx
    use mp, only: npe
    use redistribute, only: index_list_type, init_redist, delete_list

    type (index_list_type), dimension (0:npe-1) :: to_list, from_list
    integer, dimension (0:npe-1) :: nn_to, nn_from
    logical, save :: initialized = .false.
    integer :: i, j, k, iblock, jblock, kblock, iy, jy, iz, jz, iv, n, ip
    integer, dimension (2) :: from_low
    integer, dimension (2) :: to_high
    integer, dimension (1) :: from_high
    integer :: to_low

    if (initialized) return
    initialized = .true.

    call init_layouts

    ! Pass 1: count the elements exchanged with each processor.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do i = 3, nx0+2
                   do k = k3, nzm2
                      do j = 3, ny0+2
                         call yidx2zidx (i, j, k, iv, iblock, jblock, kblock, iy, jy, iz, jz)
                         if (idx_local (y_lo, jy)) then
                            ip = proc_id (z_lo, jz)
                            nn_from(ip) = nn_from(ip) + 1
                         end if
                         if (idx_local (z_lo, jz)) then
                            ip = proc_id (y_lo, jy)
                            nn_to(ip) = nn_to(ip) + 1
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Allocate the index lists; every entry is (row, local column).
    do ip = 0, npe-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip) % first (nn_from(ip)))
          allocate (from_list(ip) % second(nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip) % first (nn_to(ip)))
          allocate (to_list(ip) % second(nn_to(ip)))
       end if
    end do

    ! Pass 2: same loop nest, now recording the local indices.
    nn_to = 0
    nn_from = 0
    do iv = 1, 3
       do kblock = 0, lo % nzblock-1
          do jblock = 0, lo % nyblock-1
             do iblock = 0, lo % nxblock-1
                do i = 3, nx0+2
                   do k = k3, nzm2
                      do j = 3, ny0+2
                         call yidx2zidx (i, j, k, iv, iblock, jblock, kblock, iy, jy, iz, jz)
                         if (idx_local (y_lo, jy)) then
                            ip = proc_id (z_lo, jz)
                            n = nn_from(ip) + 1
                            nn_from(ip) = n
                            from_list(ip) % first(n)  = iy
                            from_list(ip) % second(n) = idx (y_lo, jy)
                         end if
                         if (idx_local (z_lo, jz)) then
                            ip = proc_id (y_lo, jy)
                            n = nn_to(ip) + 1
                            nn_to(ip) = n
                            to_list(ip) % first(n)  = iz
                            to_list(ip) % second(n) = idx (z_lo, jz)
                         end if
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

    ! Lower index bounds of the source and destination arrays.
    from_low = 1
    to_low = 1

    ! Upper bounds of the z-layout destination.  NOTE(review): the original
    ! authors doubted that init_redist reads to_high; it is set for
    ! completeness.
    to_high(1) = z_lo % nz0_tot + 1
    to_high(2) = z_lo % nz_fft

    from_high = 0  ! not used in this case

    call init_redist (y2z, 'c', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    call delete_list (to_list)
    call delete_list (from_list)

  end subroutine init_z_redist

  ! Real-to-complex FFT (isign = +1, length nx0_tot) of each column of x,
  ! written to the matching column of kx.
  subroutine x_transform (x, kx)
    real,    dimension (nx0_tot+1, nx_fft), intent (in)  :: x
    complex, dimension (nx0_tot+1, nx_fft), intent (out) :: kx
    integer :: n

    do n = 1, nx_fft
       call scfft (1, nx0_tot, 1., x(:,n), kx(:,n), table_x, work_x, 0)
    end do

  end subroutine x_transform

  ! In-place complex FFT (isign = +1, length ny0_tot) of each column of y.
  subroutine y_transform (y)
    complex, dimension (ny0_tot+1, ny_fft), intent (inout) :: y
    integer :: n

    do n = 1, ny_fft
       call ccfft (1, ny0_tot, 1., y(:,n), y(:,n), table_y, work_y, 0)
    end do

  end subroutine y_transform

  ! In-place complex FFT (isign = +1, length nz0_tot) of each column of z.
  ! Does nothing when z support is compiled out.
  subroutine z_transform (z)
    complex, dimension (nz0_tot+1, nz_fft), intent (inout) :: z
    integer :: n

#  if NZ != 1
    do n = 1, nz_fft
       call ccfft (1, nz0_tot, 1., z(:,n), z(:,n), table_z, work_z, 0)
    end do
#  endif

  end subroutine z_transform

  ! Complex-to-real FFT (isign = -1, length nx0_tot) of each column of kx,
  ! written to the matching column of x.
  subroutine inverse_x (kx, x)
    complex, dimension (nx0_tot+1, nx_fft), intent (in)  :: kx
    real,    dimension (nx0_tot+1, nx_fft), intent (out) :: x
    integer :: n

    do n = 1, nx_fft
       call csfft (-1, nx0_tot, 1., kx(:,n), x(:,n), table_x, work_x, 0)
    end do

  end subroutine inverse_x

  ! In-place complex FFT (isign = -1, length ny0_tot) of each column of y.
  subroutine inverse_y (y)
    complex, dimension (ny0_tot+1, ny_fft), intent (inout) :: y
    integer :: n

    do n = 1, ny_fft
       call ccfft (-1, ny0_tot, 1., y(:,n), y(:,n), table_y, work_y, 0)
    end do

  end subroutine inverse_y

  ! In-place complex FFT (isign = -1, length nz0_tot) of each column of z.
  ! Does nothing when z support is compiled out.
  subroutine inverse_z (z)
    complex, dimension (nz0_tot+1, nz_fft), intent (inout) :: z
    integer :: n

#  if NZ != 1
    do n = 1, nz_fft
       call ccfft (-1, nz0_tot, 1., z(:,n), z(:,n), table_z, work_z, 0)
    end do
#  endif

  end subroutine inverse_z

  ! Forward 2-D transform (x then y) of the field array.
  !
  ! a      - local field array a(nx+1, ny+1, nz, 3); gathered via plan p2x.
  ! y_data - result in the y-layout, one y transform per column.  This dummy
  !          hides the module buffer of the same name.  Its last row
  !          (ny0_tot+1) is set to zero.
  ! The module buffers x_data and kx_data are overwritten.  Expects
  ! init_transforms to have been called (plans and FFT tables are built
  ! there).
  subroutine transform_2d (a, y_data)

    use par, only: nx, ny, nz
    use redistribute, only: gather

    real,    dimension (nx+1, ny+1, nz, 3) :: a
    complex, dimension (ny0_tot+1, ny_fft) :: y_data

    y_data (ny0_tot+1,:) = 0.

    call gather (p2x, a, x_data)
    call x_transform (x_data, kx_data)

    call gather (x2y, kx_data, y_data)
    call y_transform (y_data)

  end subroutine transform_2d

  ! Inverse of transform_2d: inverse y then inverse x, scattered into a.
  !
  ! a      - local field array receiving the result via plan p2x.
  ! y_data - y-layout input.  This dummy hides the module buffer of the
  !          same name, and it is overwritten in place by the inverse y
  !          transform.
  ! The module buffers kx_data and x_data are overwritten.
  subroutine inverse_2d (a, y_data)

    use par, only: nx, ny, nz
    use redistribute, only: scatter

    real,    dimension (nx+1, ny+1, nz, 3) :: a
    complex, dimension (ny0_tot+1, ny_fft) :: y_data

    call inverse_y (y_data)
    call scatter (x2y, y_data, kx_data)

    call inverse_x (kx_data, x_data)
    call scatter (p2x, x_data, a)

  end subroutine inverse_2d

  ! Forward 3-D transform (x, then y, then z) of the field array.
  !
  ! a      - local field array a(nx+1, ny+1, nz, 3); gathered via plan p2x.
  ! z_data - result in the z-layout, one z transform per column.  This dummy
  !          hides the module buffer of the same name.  Its last row
  !          (nz0_tot+1) is set to zero.
  ! The module buffers x_data, kx_data and y_data are overwritten.
  ! NOTE(review): unlike transform_2d, the last row of the intermediate
  ! y_data buffer is not cleared here.  Also, when z support is compiled
  ! out, init_transforms never builds y2z and z_transform is a no-op, so
  ! this routine is presumably not meant to be called in such builds.
  subroutine transform_3d (a, z_data)

    use par, only: nx, ny, nz
    use redistribute, only: gather

    real,    dimension (nx+1, ny+1, nz, 3) :: a
    complex, dimension (nz0_tot+1, nz_fft) :: z_data

    z_data (nz0_tot+1,:) = 0.

    call gather (p2x, a, x_data)
    call x_transform (x_data, kx_data)

    ! y_data here is the module staging buffer.
    call gather (x2y, kx_data, y_data)
    call y_transform (y_data)

    call gather (y2z, y_data, z_data)
    call z_transform (z_data)

  end subroutine transform_3d

  ! Inverse of transform_3d: inverse z, y and x, scattered into a.
  !
  ! a      - local field array receiving the result via plan p2x.
  ! z_data - z-layout input.  This dummy hides the module buffer of the
  !          same name, and it is overwritten in place by the inverse z
  !          transform.
  ! The module buffers y_data, kx_data and x_data are overwritten.
  subroutine inverse_3d (a, z_data)

    use par, only: nx, ny, nz
    use redistribute, only: scatter

    real,    dimension (nx+1, ny+1, nz, 3) :: a
    complex, dimension (nz0_tot+1, nz_fft) :: z_data

    call inverse_z (z_data)
    ! y_data here is the module staging buffer.
    call scatter (y2z, z_data, y_data)

    call inverse_y (y_data)
    call scatter (x2y, y_data, kx_data)

    call inverse_x (kx_data, x_data)
    call scatter (p2x, x_data, a)

  end subroutine inverse_3d

end module transforms
