module layouts

  ! Require explicit declarations everywhere in the module, including the
  ! contained procedures that do not repeat the statement themselves.
  implicit none

  ! Default accessibility is left unchanged (public); the statements below
  ! list the names that make up the intended interface of the module.
  public :: init_layouts
  public :: proc_id, idx_local
  public :: p_layout_type
  public :: x_layout_type
  public :: y_layout_type
  public :: z_layout_type
  public :: lo, x_lo, y_lo, z_lo
  public :: pidx2xidx, xidx2yidx, yidx2zidx
  public :: idx, j_idx, i_idx, midx

  type :: p_layout_type
! Processor-block layout.  init_layouts fills iproc, the three block
! coordinates and the three block counts.
     integer :: iproc      ! processor number (copied from module mp)
     integer :: iblock     ! block coordinates of this processor
     integer :: jblock
     integer :: kblock
     integer :: nxblock    ! number of blocks along each direction
     integer :: nyblock
     integer :: nzblock
! NOTE(review): the six bounds below are never assigned in this module;
! presumably set elsewhere -- confirm before relying on them.
     integer :: imin0
     integer :: imax0
     integer :: jmin0
     integer :: jmax0
     integer :: kmin0
     integer :: kmax0
  end type p_layout_type

  type :: x_layout_type
! Layout of the x_data array.  All components are copied from modules
! mp and par by init_layouts.
     integer :: iproc      ! processor number
     integer :: nx0_tot    ! total extents used to build master indices
     integer :: ny0_tot
     integer :: nz0_tot
     integer :: nx0        ! per-block extents (multiplied by block numbers)
     integer :: ny0
     integer :: nz0
! Master second index i is held by processor i/nx_fft, at local
! second index mod(i, nx_fft) + 1.
     integer :: nx_fft
  end type x_layout_type

  type :: y_layout_type
! Layout of the y_data array.  All components are copied from modules
! mp and par by init_layouts.
     integer :: iproc      ! processor number
     integer :: nx0_tot    ! total extents used to build master indices
     integer :: ny0_tot
     integer :: nz0_tot
     integer :: nx0        ! per-block extents (multiplied by block numbers)
     integer :: ny0
     integer :: nz0
! Master second index i is held by processor i/ny_fft, at local
! second index mod(i, ny_fft) + 1.
     integer :: ny_fft
  end type y_layout_type

  type :: z_layout_type
! Layout of the z_data array.  All components are copied from modules
! mp and par by init_layouts.
     integer :: iproc      ! processor number
     integer :: nx0_tot    ! total extents used to build master indices
     integer :: ny0_tot
     integer :: nz0_tot
     integer :: nx0        ! per-block extents (multiplied by block numbers)
     integer :: ny0
     integer :: nz0
! Master second index i is held by processor i/nz_fft, at local
! second index mod(i, nz_fft) + 1.
     integer :: nz_fft
  end type z_layout_type

  ! Layout instances filled in once by init_layouts.  The save attribute is
  ! spelled out because init_layouts never refills them after its first
  ! call; before Fortran 2008 an unsaved module variable was not guaranteed
  ! to stay defined once no active scoping unit referenced the module.
  type (p_layout_type), save :: lo
  type (x_layout_type), save :: x_lo
  type (y_layout_type), save :: y_lo
  type (z_layout_type), save :: z_lo

  interface proc_id
!    Processor number owning a block (p layout) or owning a master
!    second index (x/y/z layouts); selected by the layout argument type.
     module procedure proc_id_p, proc_id_x, proc_id_y, proc_id_z
  end interface proc_id

  interface i_idx
!    Extended i index recovered from a master second index of the
!    y or z layout.
     module procedure i_idx_y
     module procedure i_idx_z
  end interface i_idx

! Only the z layout is covered so far; kept as a generic because more
! specific procedures typically end up being added.
  interface j_idx
     module procedure j_idx_z
  end interface j_idx

  interface midx
!    Second (master) index of the x/y/z data array, computed from the
!    second (local) index and the number of the processor holding it.
     module procedure midx_jx
     module procedure midx_jy
     module procedure midx_jz
  end interface midx

  interface idx
!    Six-argument forms (idx_x, idx_y, idx_z): second (master) index of
!       the x/y/z data array from block numbers and local grid indices.
!    Two-argument forms (idx_jx, idx_jy, idx_jz): second (local) index
!       of the x/y/z data array from the second (master) index.
     module procedure idx_x, idx_y, idx_z
     module procedure idx_jx, idx_jy, idx_jz
  end interface idx

  interface idx_local
!    .true. when the given block (p layout) or master second index
!    (x/y/z layouts) belongs to the processor stored in the layout.
     module procedure idx_local_p
     module procedure idx_local_x
     module procedure idx_local_y
     module procedure idx_local_z
  end interface idx_local

contains
  
  subroutine init_layouts
! Fills the module-level layout structures lo, x_lo, y_lo and z_lo from
! the processor number (module mp) and the grid sizes (module par).
! Only the first call does any work; later calls return immediately.
    use mp, only: iproc
    use par, only: nx_fft, ny_fft, nz_fft
    use par, only: nx0_tot, ny0_tot, nz0_tot
    use par, only: nx0, ny0, nz0
    implicit none
! The upper-case names are not defined in this file; presumably they are
! preprocessor macros supplied at build time -- TODO confirm.
    integer, parameter :: nxblock=NXBLOCK, nyblock=NYBLOCK, nzblock=NZBLOCK
! The flag must survive between calls, so the save attribute is explicit.
    logical, save :: initialized = .false.

    if (initialized) return
    initialized = .true.

! store information about parallelism in a structure

    lo % iproc = iproc

! Block coordinates of this processor, with iblock varying fastest.
! This is the inverse of the numbering used by proc_id_p.
    lo % iblock = mod (iproc, nxblock)
    lo % jblock = mod (iproc / nxblock, nyblock)
    lo % kblock = mod (iproc / (nxblock * nyblock),nzblock)

    lo % nxblock = nxblock
    lo % nyblock = nyblock
    lo % nzblock = nzblock

! NOTE(review): lo%imin0 ... lo%kmax0 are not assigned here.

! The x, y and z layouts carry the same sizes; they differ only in the
! last component (nx_fft, ny_fft or nz_fft).
    x_lo % iproc = iproc
    x_lo % nx0_tot = nx0_tot
    x_lo % ny0_tot = ny0_tot
    x_lo % nz0_tot = nz0_tot
    x_lo % nx0 = nx0
    x_lo % ny0 = ny0
    x_lo % nz0 = nz0
    x_lo % nx_fft = nx_fft

    y_lo % iproc = iproc
    y_lo % nx0_tot = nx0_tot
    y_lo % ny0_tot = ny0_tot
    y_lo % nz0_tot = nz0_tot
    y_lo % nx0 = nx0
    y_lo % ny0 = ny0
    y_lo % nz0 = nz0
    y_lo % ny_fft = ny_fft

    z_lo % iproc = iproc
    z_lo % nx0_tot = nx0_tot
    z_lo % ny0_tot = ny0_tot
    z_lo % nz0_tot = nz0_tot
    z_lo % nx0 = nx0
    z_lo % ny0 = ny0
    z_lo % nz0 = nz0
    z_lo % nz_fft = nz_fft

  end subroutine init_layouts

! Returns the processor number which has block (iblock, jblock, kblock):
!    iblock + nxblock*(jblock + nyblock*kblock)
! so blocks are numbered with iblock varying fastest.  Only the block
! counts stored in lo are read; no bounds checking is done.
  function proc_id_p (lo, iblock, jblock, kblock)
    implicit none
    integer :: proc_id_p
    type (p_layout_type), intent (in) :: lo
    integer, intent (in) :: iblock, jblock, kblock

    proc_id_p = iblock + lo%nxblock*(jblock + lo%nyblock*kblock)

  end function proc_id_p

! Returns .true. if block (iblock, jblock, kblock) is on this processor,
! i.e. if the owner computed by proc_id equals lo%iproc.
  function idx_local_p (lo, iblock, jblock, kblock)
    implicit none
    logical :: idx_local_p
    type (p_layout_type), intent (in) :: lo
    integer, intent (in) :: iblock, jblock, kblock

    idx_local_p = lo%iproc == proc_id(lo, iblock, jblock, kblock)

  end function idx_local_p
    
! Second (master-listed) index of the x_data array for local grid
! indices j, k in block (jblock, kblock) and index iv.  The result is
! zero-based, with the block-global j position varying fastest.
  function idx_x (lo, jblock, kblock, j, k, iv) result (master)
    use par, only: numtwo
    implicit none
    type (x_layout_type), intent (in) :: lo
    integer, intent (in) :: jblock, kblock, j, k, iv
    integer :: master
    integer :: jglob, kglob, outer

! Remove the guard-cell offset (2 for j, numtwo for k) and shift by the
! block origin.  Per the original note, numtwo differs from 2 only in
! the nz=1 case, where k is always 1.
    jglob = (j - 2)      + lo%ny0 * jblock
    kglob = (k - numtwo) + lo%nz0 * kblock

    outer  = (kglob - 1) + lo%nz0_tot*(iv - 1)
    master = (jglob - 1) + lo%ny0_tot*outer
  end function idx_x

! Second (local, one-based) index of the x_data array for the second
! (master, zero-based) index jx.  Per the original note, jx runs from
! 0 to (ny0_tot*nz0_tot*3 - 1).
  function idx_jx (lo, jx) result (local)
    implicit none
    type (x_layout_type), intent (in) :: lo
    integer, intent (in) :: jx
    integer :: local

    local = 1 + mod (jx, lo%nx_fft)
  end function idx_jx

! Second (master) index of the x_data array, rebuilt from the second
! (local) index jx and the number iproc of the processor holding it.
! This undoes idx_jx and proc_id_x.
  function midx_jx (lo, jx, iproc) result (master)
    implicit none
    type (x_layout_type), intent (in) :: lo
    integer, intent (in) :: jx
    integer, intent (in) :: iproc
    integer :: master
    integer :: offset

    offset = lo%nx_fft*iproc
    master = (jx - 1) + offset
  end function midx_jx

! PE number which has x_data(:,i), where i is the second (master) index
! of the x_data array.  Integer division: each PE holds nx_fft
! consecutive master indices.
  function proc_id_x (lo, i) result (pe)
    implicit none
    type (x_layout_type), intent (in) :: lo
    integer, intent (in) :: i
    integer :: pe

    pe = i / lo%nx_fft
  end function proc_id_x
 
! .true. if the second (master-listed) index i of the x layout is held
! by the PE recorded in lo%iproc.
  function idx_local_x (lo, i) result (on_this_pe)
    implicit none
    type (x_layout_type), intent (in) :: lo
    integer, intent (in) :: i
    logical :: on_this_pe

    on_this_pe = (lo%iproc == proc_id(lo, i))
  end function idx_local_x

! Second (master-listed) index of the y_data array for local grid
! indices i, k in block (iblock, kblock) and index iv.  The result is
! zero-based, with the block-global k position varying fastest.
  function idx_y (lo, iblock, kblock, i, k, iv) result (master)
    use par, only: numtwo
    implicit none
    type (y_layout_type), intent (in) :: lo
    integer, intent (in) :: iblock, kblock, i, k, iv
    integer :: master
    integer :: iglob, kglob, outer

! Remove the guard-cell offset (2 for i, numtwo for k) and shift by the
! block origin.  Per the original note, numtwo differs from 2 only in
! the nz=1 case, where k is always 1.
    iglob = (i - 2)      + lo%nx0 * iblock
    kglob = (k - numtwo) + lo%nz0 * kblock

    outer  = (iglob - 1) + lo%nx0_tot*(iv - 1)
    master = (kglob - 1) + lo%nz0_tot*outer
  end function idx_y

! Extended i index (1 <= i <= nx0_tot) recovered from j, the second
! (master) index of the y layout: strip the fastest (k) part by integer
! division, then wrap on nx0_tot.
  function i_idx_y (lo, j) result (i_ext)
    implicit none
    type (y_layout_type), intent (in) :: lo
    integer, intent (in) :: j
    integer :: i_ext
    integer :: quotient

    quotient = j / lo%nz0_tot
    i_ext = mod (quotient, lo%nx0_tot) + 1
  end function i_idx_y

! Second (local, one-based) index of the y_data array for the second
! (master, zero-based) index jy.
  function idx_jy (lo, jy) result (local)
    implicit none
    type (y_layout_type), intent (in) :: lo
    integer, intent (in) :: jy
    integer :: local

    local = 1 + mod (jy, lo%ny_fft)
  end function idx_jy

! Second (master) index of the y_data array, rebuilt from the second
! (local) index jy and the number iproc of the processor holding it.
! This undoes idx_jy and proc_id_y.
  function midx_jy (lo, jy, iproc) result (master)
    implicit none
    type (y_layout_type), intent (in) :: lo
    integer, intent (in) :: jy
    integer, intent (in) :: iproc
    integer :: master
    integer :: offset

    offset = lo%ny_fft*iproc
    master = (jy - 1) + offset
  end function midx_jy
  
! PE number which has y_data(:,i), where i is the second (master) index
! of the y_data array.  Integer division: each PE holds ny_fft
! consecutive master indices.
  function proc_id_y (lo, i) result (pe)
    implicit none
    type (y_layout_type), intent (in) :: lo
    integer, intent (in) :: i
    integer :: pe

    pe = i / lo%ny_fft
  end function proc_id_y
 
! .true. if y_data(:,i) is held by the PE recorded in lo%iproc, where i
! is the second (master) index of the y layout.
 function idx_local_y (lo, i) result (on_this_pe)
   implicit none
   type (y_layout_type), intent (in) :: lo
   integer, intent (in) :: i
   logical :: on_this_pe

   on_this_pe = (lo%iproc == proc_id(lo, i))
 end function idx_local_y

! Second (master) index of the z_data array for local grid indices i, j
! in block (iblock, jblock) and index iv.  The result is zero-based,
! with the block-global j position varying fastest.
  function idx_z (lo, iblock, jblock, i, j, iv) result (master)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: iblock, jblock, i, j, iv
    integer :: master
    integer :: iglob, jglob, outer

! Subtract 2 from i and j, then shift by the block origin.
    iglob = (i - 2) + lo%nx0 * iblock
    jglob = (j - 2) + lo%ny0 * jblock

    outer  = (iglob - 1) + lo%nx0_tot*(iv - 1)
    master = (jglob - 1) + lo%ny0_tot*outer
  end function idx_z

! Extended i index (1 <= i <= nx0_tot) recovered from j, the second
! (master) index of the z layout: strip the fastest (j) part by integer
! division, then wrap on nx0_tot.
  function i_idx_z (lo, j) result (i_ext)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: j
    integer :: i_ext
    integer :: quotient

    quotient = j / lo%ny0_tot
    i_ext = mod (quotient, lo%nx0_tot) + 1
  end function i_idx_z

! Extended j index (1 <= j <= ny0_tot) recovered from the second
! (master) index of the z layout, which is passed in the argument
! named j.
  function j_idx_z (lo, j) result (j_ext)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: j
    integer :: j_ext

    j_ext = mod (j, lo%ny0_tot) + 1
  end function j_idx_z

! Second (local, one-based) index of the z_data array for the second
! (master, zero-based) index jz.
  function idx_jz (lo, jz) result (local)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: jz
    integer :: local

    local = 1 + mod (jz, lo%nz_fft)
  end function idx_jz

! Second (master) index of the z_data array, rebuilt from the second
! (local) index jz and the number iproc of the processor holding it.
! This undoes idx_jz and proc_id_z.
  function midx_jz (lo, jz, iproc) result (master)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: jz
    integer, intent (in) :: iproc
    integer :: master
    integer :: offset

    offset = lo%nz_fft*iproc
    master = (jz - 1) + offset
  end function midx_jz
  
! PE number which has z_data(:,i), where i is the second (master) index
! of the z_data array.  Integer division: each PE holds nz_fft
! consecutive master indices.
  function proc_id_z (lo, i) result (pe)
    implicit none
    type (z_layout_type), intent (in) :: lo
    integer, intent (in) :: i
    integer :: pe

    pe = i / lo%nz_fft
  end function proc_id_z
 
! .true. if z_data(:,i) is held by the PE recorded in lo%iproc, where i
! is the second (master) index of the z layout.
 function idx_local_z (lo, i) result (on_this_pe)
   implicit none
   type (z_layout_type), intent (in) :: lo
   integer, intent (in) :: i
   logical :: on_this_pe

   on_this_pe = (lo%iproc == proc_id(lo, i))
 end function idx_local_z

 subroutine pidx2xidx (i, j, k, iv, iblock, jblock, kblock, ix, jx)
! Converts block-local indices (i, j, k, iv) in block
! (iblock, jblock, kblock) into the index pair (ix, jx) of the x layout:
!   ix  first index, the block-global x position
!   jx  second (master) index, from idx on the module variable x_lo
   use par, only: nx0
   implicit none
   integer, intent (in) :: i, j, k, iv, iblock, jblock, kblock
   integer, intent (out) :: ix, jx

! subtracting 2 from i eliminates x guard cells
   ix = i-2 + nx0*iblock
   jx = idx (x_lo, jblock, kblock, j, k, iv)

 end subroutine pidx2xidx

 subroutine xidx2yidx (i, j, k, iv, iblock, jblock, kblock, ix, jx, iy, jy)
! For block-local indices (i, j, k, iv) in block (iblock, jblock, kblock),
! returns the index pair in the x layout (ix, jx) and the index pair in
! the y layout (iy, jy).  The first index is the block-global position
! along the layout direction; the second is the master index from idx.
   use par, only: nx0, ny0
   implicit none
   integer, intent (in) :: i, j, k, iv, iblock, jblock, kblock
   integer, intent (out) :: ix, jx, iy, jy

! subtracting 2 from i and j eliminates x and y guard cells
   ix = i-2 + nx0*iblock
   jx = idx (x_lo, jblock, kblock, j, k, iv)

   iy = j-2 + ny0*jblock
   jy = idx (y_lo, iblock, kblock, i, k, iv)

 end subroutine xidx2yidx

 subroutine yidx2zidx (i, j, k, iv, iblock, jblock, kblock, iy, jy, iz, jz)
! For block-local indices (i, j, k, iv) in block (iblock, jblock, kblock),
! returns the index pair in the y layout (iy, jy) and the index pair in
! the z layout (iz, jz).  The first index is the block-global position
! along the layout direction; the second is the master index from idx.
   use par, only: ny0, nz0
   implicit none
   integer, intent (in) :: i, j, k, iv, iblock, jblock, kblock
   integer, intent (out) :: iy, jy, iz, jz

! subtracting 2 from j and k eliminates y and z guard cells
! This routine is only called if code is 3D, so
! no modification for 2D case is required.

! The y layout second index is built from the x and z information
! (iblock, kblock, i, k), exactly as in xidx2yidx, so both routines
! agree on where a given element sits in y_data.
   iy = j-2 + ny0*jblock
   jy = idx (y_lo, iblock, kblock, i, k, iv)

! The z layout second index is built from the x and y information
! (iblock, jblock, i, j), matching the argument list of idx_z.
   iz = k-2 + nz0*kblock
   jz = idx (z_lo, iblock, jblock, i, j, iv)

 end subroutine yidx2zidx
end module layouts
