Wave functions

Work in progress!

Wave functions in Octopus

In Octopus, the wave functions are referred to as states.

They are handled by the module states_abst_oct_m, which is defined in states/states_abst.F90.

The exact way in which the states are stored in memory is flexible and depends on the optimization (and accelerator) settings in the input file. In particular, the order of indices depends on the PACKED setting, i.e., on whether the states are stored in packed form.
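To make the difference in index order concrete, the Python sketch below computes where the value of state ist at grid point ip sits in memory in each layout. It is a model, not Octopus code: it assumes that an unpacked batch holds an array ff(1:np, 1:nst) and a packed batch an array ff_pack(1:nst, 1:np), both in Fortran (column-major) order, and it ignores any padding of the packed array. The function names are hypothetical.

```python
# Model of the two memory layouts (assumption: unpacked = ff(1:np, 1:nst),
# packed = ff_pack(1:nst, 1:np), column-major, 1-based indices, no padding).


def unpacked_offset(ip: int, ist: int, np_: int, nst: int) -> int:
    """0-based offset of ff(ip, ist): the grid points of one state are contiguous."""
    return (ip - 1) + (ist - 1) * np_


def packed_offset(ip: int, ist: int, np_: int, nst: int) -> int:
    """0-based offset of ff_pack(ist, ip): all states at one grid point are contiguous."""
    return (ist - 1) + (ip - 1) * nst


if __name__ == "__main__":
    np_, nst = 5, 3
    pairs = [(ip, ist) for ip in range(1, np_ + 1) for ist in range(1, nst + 1)]
    # Walk through memory in order and record which (ip, ist) is stored there.
    unpacked = sorted((unpacked_offset(ip, ist, np_, nst), (ip, ist)) for ip, ist in pairs)
    packed = sorted((packed_offset(ip, ist, np_, nst), (ip, ist)) for ip, ist in pairs)
    print("unpacked:", [pair for _, pair in unpacked[:4]])
    print("packed:  ", [pair for _, pair in packed[:4]])
```

With np = 5 and nst = 3, the first four memory locations hold (ip, ist) = (1, 1), (2, 1), (3, 1), (4, 1) in the unpacked layout, but (1, 1), (1, 2), (1, 3), (2, 1) in the packed layout: packing makes the values of all states at one grid point contiguous, which is the kind of access pattern that benefits vectorized or GPU code.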

States are stored in a hierarchy of ‘containers’. The concepts involved in this hierarchy include groups, batches and blocks.

The abstract class

The top-level data structure describing the states is:

  type, abstract :: states_abst_t
    private
    type(type_t), public  :: wfs_type         !< real (TYPE_FLOAT) or complex (TYPE_CMPLX) wavefunctions
    integer, public  :: nst                   !< Number of states in each irreducible subspace
    logical, public  :: packed

  contains

    procedure(nullify),    deferred :: nullify
    procedure(pack),       deferred :: pack
    procedure(unpack),     deferred :: unpack
    procedure(write_info), deferred :: write_info
    procedure(set_zero),   deferred :: set_zero
    procedure, non_overridable      :: are_packed
    procedure, non_overridable      :: get_type
  end type states_abst_t

This structure mainly contains metadata describing how the states are represented, and it defines the interface of the class through its deferred procedures. As it is an abstract class, it does not contain any information about the actual system.
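The split between deferred and non-overridable procedures can be illustrated with a Python analogue based on the standard abc module. This is only an analogy, not a translation of Octopus code: the class names StatesAbst and ToyStates are hypothetical, only three of the deferred procedures are reproduced, and Python cannot forbid overriding are_packed() or get_type() the way non_overridable does in Fortran.

```python
# Python analogue (not Octopus code) of an abstract states class:
# 'deferred' procedures become abstract methods that every concrete
# subclass must implement; 'non_overridable' accessors are ordinary
# methods implemented once in the base class.
from abc import ABC, abstractmethod

TYPE_FLOAT, TYPE_CMPLX = "float", "complex"  # stand-ins for the type_t values


class StatesAbst(ABC):
    def __init__(self, wfs_type: str, nst: int):
        self.wfs_type = wfs_type  # real or complex wave functions
        self.nst = nst            # number of states per irreducible subspace
        self.packed = False       # current memory layout

    # Deferred procedures: no implementation in the abstract class.
    @abstractmethod
    def pack(self): ...

    @abstractmethod
    def unpack(self): ...

    @abstractmethod
    def set_zero(self): ...

    # Counterparts of the non_overridable procedures (not enforced in Python).
    def are_packed(self) -> bool:
        return self.packed

    def get_type(self) -> str:
        return self.wfs_type


class ToyStates(StatesAbst):
    """Hypothetical concrete class storing one list of values per state."""

    def __init__(self, nst: int, np_: int):
        super().__init__(TYPE_FLOAT, nst)
        self.psi = [[1.0] * np_ for _ in range(nst)]

    def pack(self):
        self.packed = True   # toy version: only records the layout flag

    def unpack(self):
        self.packed = False  # toy version: no data is actually reordered

    def set_zero(self):
        self.psi = [[0.0] * len(row) for row in self.psi]
```

Trying to instantiate StatesAbst directly raises a TypeError, just as a Fortran abstract type cannot be instantiated; only a subclass that implements every deferred procedure can be.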

The states class for electrons

The class states_elec_t specializes the abstract class. It contains more data specific to the electronic system, as well as pointers to other quantities that are common to all states, such as the density and the current.

Definition of "states_elec_t"

The dimensions object st%d (of type states_elec_dim_t) contains a number of variables. The most relevant for this discussion is dim, the dimension of a single state: 1 for spinless states and 2 for spinors.

  type states_elec_dim_t
    ! Components are public by default
    integer :: dim                  !< Dimension of the state (one, or two for spinors)
    integer :: nik                  !< Number of irreducible subspaces
    integer :: ispin                !< spin mode (unpolarized, spin-polarized, spinors)
    integer :: nspin                !< dimension of rho (1, 2 or 4)
    integer :: spin_channels        !< 1 or 2, whether spin is or not considered.
    FLOAT, allocatable  :: kweights(:)   !< weights for the k-point integrations
    type(distributed_t) :: kpt
    integer :: block_size
    integer :: orth_method = 0
    logical :: pack_states
    FLOAT   :: cl_states_mem
  contains
    procedure :: get_spin_index => states_elec_dim_get_spin_index
    procedure :: get_kpoint_index => states_elec_dim_get_kpoint_index
  end type states_elec_dim_t
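The components ispin, nspin, dim and spin_channels are not independent. The Python sketch below summarizes how we understand them to relate for the three spin modes, and mimics what get_spin_index() and get_kpoint_index() return. The numeric values of the mode constants and the ordering of the quantum index iq (k-point and spin interleaved, spin fastest, in spin-polarized runs) are assumptions, not taken from this page.

```python
# Sketch of the relation between the spin mode and states_elec_dim_t components.
# Assumptions: mode constants 1, 2, 3; in spin-polarized runs the quantum
# index iq runs over (k-point, spin) pairs with the spin index fastest.
UNPOLARIZED, SPIN_POLARIZED, SPINORS = 1, 2, 3

# spin mode -> (dim, nspin, spin_channels)
SPIN_MODES = {
    UNPOLARIZED: (1, 1, 1),     # scalar states, one density component
    SPIN_POLARIZED: (1, 2, 2),  # scalar states, separate up/down densities
    SPINORS: (2, 4, 2),         # two-component spinors, four density components
}


def nik_from_kpoints(ispin: int, nkpoints: int) -> int:
    """Number of irreducible subspaces (quantum indices)."""
    return 2 * nkpoints if ispin == SPIN_POLARIZED else nkpoints


def get_spin_index(ispin: int, iq: int) -> int:
    """Spin channel (1 or 2) of the 1-based quantum index iq."""
    return 1 + (iq - 1) % 2 if ispin == SPIN_POLARIZED else 1


def get_kpoint_index(ispin: int, iq: int) -> int:
    """1-based k-point index of the 1-based quantum index iq."""
    return 1 + (iq - 1) // 2 if ispin == SPIN_POLARIZED else iq


if __name__ == "__main__":
    for iq in range(1, nik_from_kpoints(SPIN_POLARIZED, 2) + 1):
        print(iq, get_kpoint_index(SPIN_POLARIZED, iq), get_spin_index(SPIN_POLARIZED, iq))
```

For example, under these assumptions a spin-polarized calculation with 4 k-points has nik = 8, and iq = 3 corresponds to the second k-point with spin up.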

The wave functions themselves are stored in

    type(states_elec_group_t)     :: group

whose type is defined in the module states_elec_group_oct_m in src/states/states_elec_group.F90.

The group contains all wave functions, grouped together in blocks (batches). They are organised in a two-dimensional array psib of wfs_elec_t objects, a type which extends batch_t:

  type states_elec_group_t
    ! Components are public by default
    type(wfs_elec_t), allocatable :: psib(:, :)            !< A set of wave-functions blocks
    integer                  :: nblocks               !< The number of blocks
    integer                  :: block_start           !< The lowest index of local blocks
    integer                  :: block_end             !< The highest index of local blocks
    integer, allocatable     :: iblock(:, :)          !< A map, that for each state index, returns the index of block containing it
    integer, allocatable     :: block_range(:, :)     !< Each block contains states from block_range(:, 1) to block_range(:, 2)
    integer, allocatable     :: block_size(:)         !< The number of states in each block.
    logical, allocatable     :: block_is_local(:, :)  !< It is true if the block is in this node.
    integer, allocatable     :: block_node(:)         !< The node that contains each block
    integer, allocatable     :: rma_win(:, :)         !< The MPI window for one side communication
    logical                  :: block_initialized = .false. !< For keeping track of the blocks to avoid memory leaks
  end type states_elec_group_t

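The indexing is as follows: psib(ib, iqn), where ib is the block index and iqn the k-point index (more precisely, the index of the irreducible subspace, cf. nik). See below for the routine states_elec_init_block(st, mesh, verbose, skip, packed), which creates the group object. On a given node, only the wave functions of local blocks are available; the group object, however, contains the complete information on how the batches are distributed over the nodes. Each element of psib is a wfs_elec_t, which extends batch_t by the index ik and the flag has_phase: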

  type, extends(batch_t) :: wfs_elec_t
    private
    integer, public :: ik
    logical, public :: has_phase
  contains
    procedure :: clone_to => wfs_elec_clone_to
    procedure :: clone_to_array => wfs_elec_clone_to_array
    procedure :: copy_to => wfs_elec_copy_to
    procedure :: check_compatibility_with => wfs_elec_check_compatibility_with
    procedure :: end => wfs_elec_end
  end type wfs_elec_t

Creating the wave functions

Several steps of the initialization of the states_t object are performed by routines called from system_init():

states_init(): parses the states-related input variables and allocates memory for some bookkeeping variables. It does not allocate any memory for the states themselves.

states_distribute_nodes():

states_density_init(): allocates memory for the density (rho) and the core density (rho_core).

states_exec_init():

  1. Fills in the block size (st%d%block_size);
  2. Finds out whether or not to pack the states (st%d%pack_states);
  3. Finds out the orthogonalization method (st%d%orth_method).

Memory for the actual wave functions is allocated in states_elec_allocate_wfns(), which is called from the corresponding *_run() routines, such as scf_run() or td_run().

  subroutine states_elec_allocate_wfns(st, mesh, wfs_type, skip, packed)
    type(states_elec_t),    intent(inout)   :: st
    type(mesh_t),           intent(in)      :: mesh
    type(type_t), optional, intent(in)      :: wfs_type
    logical,      optional, intent(in)      :: skip(:)
    logical,      optional, intent(in)      :: packed

    PUSH_SUB(states_elec_allocate_wfns)

    if (present(wfs_type)) then
      ASSERT(wfs_type == TYPE_FLOAT .or. wfs_type == TYPE_CMPLX)
      st%wfs_type = wfs_type
    end if

    call states_elec_init_block(st, mesh, skip = skip, packed=packed)
    call states_elec_set_zero(st)

    POP_SUB(states_elec_allocate_wfns)
  end subroutine states_elec_allocate_wfns

The routine states_elec_init_block() initializes the components of st%group that describe how the states are distributed in blocks:

st%group%nblocks: the number of blocks into which the states are divided. Note that this is the total number of blocks, regardless of how many are actually stored on each node.
st%group%block_start: on each node, the index of the first local block.
st%group%block_end: on each node, the index of the last local block. If the states are not parallelized, then block_start is 1 and block_end is st%group%nblocks.
st%group%iblock(1:st%nst, 1:st%d%nik): for each state and k-point, the index of the block that contains it (it remains 0 for skipped states and for k-points not handled by this node).
st%group%block_is_local(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end): st%group%block_is_local(ib, iqn) is .true. if block ib of k-point iqn is stored on the running node.
st%group%block_range(1:st%group%nblocks, 1:2): block ib contains the states from st%group%block_range(ib, 1) to st%group%block_range(ib, 2).
st%group%block_size(1:st%group%nblocks): block ib contains st%group%block_size(ib) states.
st%group%block_initialized: it should be .false. on entry, and is .true. after the routine returns.

The set of batches st%group%psib(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end) contains the blocks themselves.
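As an illustration of how these arrays work together, the following Python sketch (a hypothetical helper, not an Octopus routine) finds the block containing state ist at k-point index iqn, and the position of the state inside that block. The Fortran arrays are modelled as dictionaries keyed by their 1-based indices.

```python
# Hypothetical helper: locate state ist of k-point index iqn in the blocks.
# The Fortran arrays are modelled as dicts keyed by 1-based indices.


def locate_state(ist, iqn, iblock, block_range, block_is_local):
    """Return (ib, position of ist within block ib), or None if unavailable.

    iblock[(ist, iqn)] is 0 (or missing) for states not assigned to a block;
    block_is_local[(ib, iqn)] is False if block ib is stored on another node.
    """
    ib = iblock.get((ist, iqn), 0)
    if ib == 0 or not block_is_local.get((ib, iqn), False):
        return None
    first, last = block_range[(ib, 1)], block_range[(ib, 2)]
    assert first <= ist <= last, "iblock and block_range are inconsistent"
    return ib, ist - first + 1


if __name__ == "__main__":
    # Five states, block size 2, one k-point: blocks 1-2, 3-4 and 5.
    ranges = [(1, 2), (3, 4), (5, 5)]
    block_range, iblock = {}, {}
    for ib, (b0, b1) in enumerate(ranges, start=1):
        block_range[(ib, 1)], block_range[(ib, 2)] = b0, b1
        for ist in range(b0, b1 + 1):
            iblock[(ist, 1)] = ib
    block_is_local = {(ib, 1): True for ib in range(1, len(ranges) + 1)}
    print(locate_state(4, 1, iblock, block_range, block_is_local))  # (2, 2)
```

In this example, with five states, a block size of two and one k-point, state 4 is the second state of block 2.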

  subroutine states_elec_init_block(st, mesh, verbose, skip, packed)
    type(states_elec_t),           intent(inout) :: st
    type(mesh_t),                  intent(in)    :: mesh
    logical, optional,             intent(in)    :: verbose
    logical, optional,             intent(in)    :: skip(:)
    logical, optional,             intent(in)    :: packed

    integer :: ib, iqn, ist, istmin, istmax
    logical :: same_node, verbose_, packed_
    integer, allocatable :: bstart(:), bend(:)

    PUSH_SUB(states_elec_init_block)

    SAFE_ALLOCATE(bstart(1:st%nst))
    SAFE_ALLOCATE(bend(1:st%nst))
    SAFE_ALLOCATE(st%group%iblock(1:st%nst, 1:st%d%nik))

    st%group%iblock = 0

    verbose_ = optional_default(verbose, .true.)
    packed_ = optional_default(packed, .false.)

    !In case we have a list of state to skip, we do not allocate them
    istmin = 1
    if (present(skip)) then
      do ist = 1, st%nst
        if (.not. skip(ist)) then
          istmin = ist
          exit
        end if
      end do
    end if

    istmax = st%nst
    if (present(skip)) then
      do ist = st%nst, istmin, -1
        if (.not. skip(ist)) then
          istmax = ist
          exit
        end if
      end do
    end if

    if (present(skip) .and. verbose_) then
      call messages_write('Info: Allocating states from ')
      call messages_write(istmin, fmt = 'i8')
      call messages_write(' to ')
      call messages_write(istmax, fmt = 'i8')
      call messages_info()
    end if

    ! count and assign blocks
    ib = 0
    st%group%nblocks = 0
    bstart(1) = istmin
    do ist = istmin, istmax
      ib = ib + 1

      st%group%iblock(ist, st%d%kpt%start:st%d%kpt%end) = st%group%nblocks + 1

      same_node = .true.
      if (st%parallel_in_states .and. ist /= istmax) then
        ! We have to avoid that states that are in different nodes end
        ! up in the same block
        same_node = (st%node(ist + 1) == st%node(ist))
      end if

      if (ib == st%d%block_size .or. ist == istmax .or. .not. same_node) then
        ib = 0
        st%group%nblocks = st%group%nblocks + 1
        bend(st%group%nblocks) = ist
        if (ist /= istmax) bstart(st%group%nblocks + 1) = ist + 1
      end if
    end do

    SAFE_ALLOCATE(st%group%psib(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end))

    SAFE_ALLOCATE(st%group%block_is_local(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end))
    st%group%block_is_local = .false.
    st%group%block_start  = -1
    st%group%block_end    = -2  ! this will make that loops block_start:block_end do not run if not initialized

    do ib = 1, st%group%nblocks
      if (bstart(ib) >= st%st_start .and. bend(ib) <= st%st_end) then
        if (st%group%block_start == -1) st%group%block_start = ib
        st%group%block_end = ib
        do iqn = st%d%kpt%start, st%d%kpt%end
          st%group%block_is_local(ib, iqn) = .true.

          if (states_are_real(st)) then
            call dwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
              special=.true., packed=packed_)
          else
            call zwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
              special=.true., packed=packed_)
          end if

        end do
      end if
    end do

    SAFE_ALLOCATE(st%group%block_range(1:st%group%nblocks, 1:2))
    SAFE_ALLOCATE(st%group%block_size(1:st%group%nblocks))

    st%group%block_range(1:st%group%nblocks, 1) = bstart(1:st%group%nblocks)
    st%group%block_range(1:st%group%nblocks, 2) = bend(1:st%group%nblocks)
    st%group%block_size(1:st%group%nblocks) = bend(1:st%group%nblocks) - bstart(1:st%group%nblocks) + 1

    st%group%block_initialized = .true.

    SAFE_ALLOCATE(st%group%block_node(1:st%group%nblocks))

    ASSERT(allocated(st%node))
    ASSERT(all(st%node >= 0) .and. all(st%node < st%mpi_grp%size))

    do ib = 1, st%group%nblocks
      st%group%block_node(ib) = st%node(st%group%block_range(ib, 1))
      ASSERT(st%group%block_node(ib) == st%node(st%group%block_range(ib, 2)))
    end do

    if (verbose_) then
      call messages_write('Info: Blocks of states')
      call messages_info()
      do ib = 1, st%group%nblocks
        call messages_write('      Block ')
        call messages_write(ib, fmt = 'i8')
        call messages_write(' contains ')
        call messages_write(st%group%block_size(ib), fmt = 'i8')
        call messages_write(' states')
        if (st%group%block_size(ib) > 0) then
          call messages_write(':')
          call messages_write(st%group%block_range(ib, 1), fmt = 'i8')
          call messages_write(' - ')
          call messages_write(st%group%block_range(ib, 2), fmt = 'i8')
        end if
        call messages_info()
      end do
    end if

!!$!!!!DEBUG
!!$    ! some debug output that I will keep here for the moment
!!$    if (mpi_grp_is_root(mpi_world)) then
!!$      print*, "NST       ", st%nst
!!$      print*, "BLOCKSIZE ", st%d%block_size
!!$      print*, "NBLOCKS   ", st%group%nblocks
!!$
!!$      print*, "==============="
!!$      do ist = 1, st%nst
!!$        print*, st%node(ist), ist, st%group%iblock(ist, 1)
!!$      end do
!!$      print*, "==============="
!!$
!!$      do ib = 1, st%group%nblocks
!!$        print*, ib, bstart(ib), bend(ib)
!!$      end do
!!$
!!$    end if
!!$!!!!ENDOFDEBUG

    SAFE_DEALLOCATE_A(bstart)
    SAFE_DEALLOCATE_A(bend)
    POP_SUB(states_elec_init_block)
  end subroutine states_elec_init_block
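The counting loop and the selection of local blocks in the routine above can be reproduced in a few lines of Python, which makes it easy to try out different block sizes and state distributions. This is a simplified model: partition_blocks() and local_block_range() are hypothetical names, node is a list giving the node of each state (standing in for st%node and st%parallel_in_states), and skip follows the optional argument of the same name.

```python
# Simplified model of the block bookkeeping in states_elec_init_block.
# States are numbered from 1 as in Fortran; node[ist - 1] is the node of state ist.


def partition_blocks(nst, block_size, node=None, skip=None):
    """Split the states into blocks; return a list of (bstart, bend) pairs."""
    # Only the first and last non-skipped states matter: skipped states in
    # between are still allocated, as in the Fortran loops over skip.
    kept = [ist for ist in range(1, nst + 1) if skip is None or not skip[ist - 1]]
    istmin, istmax = (kept[0], kept[-1]) if kept else (1, nst)

    blocks = []
    bstart, count = istmin, 0
    for ist in range(istmin, istmax + 1):
        count += 1
        # A block must not contain states that live on different nodes.
        same_node = True
        if node is not None and ist != istmax:
            same_node = node[ist] == node[ist - 1]
        if count == block_size or ist == istmax or not same_node:
            blocks.append((bstart, ist))
            bstart, count = ist + 1, 0
    return blocks


def local_block_range(blocks, st_start, st_end):
    """Return (block_start, block_end) for the blocks stored on this node.

    A block is local if all its states lie in st_start..st_end. The
    defaults -1 and -2 make a Fortran loop block_start:block_end empty.
    """
    block_start, block_end = -1, -2
    for ib, (b0, b1) in enumerate(blocks, start=1):
        if b0 >= st_start and b1 <= st_end:
            if block_start == -1:
                block_start = ib
            block_end = ib
    return block_start, block_end


if __name__ == "__main__":
    blocks = partition_blocks(7, 3, node=[0, 0, 0, 0, 1, 1, 1])
    print(blocks)                              # [(1, 3), (4, 4), (5, 7)]
    print([b1 - b0 + 1 for b0, b1 in blocks])  # block sizes: [3, 1, 3]
    print(local_block_range(blocks, 5, 7))     # node 1 owns block 3: (3, 3)
```

For instance, seven states with block_size = 3 split into the blocks 1-3, 4-6 and 7; if states 1-4 live on node 0 and states 5-7 on node 1, the blocks become 1-3, 4 and 5-7, because a block is never allowed to span two nodes.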

The memory for the actual wave functions is set up by batch_init_empty() and X(batch_allocate)(), where X() stands for the d (real) or z (complex) prefix. These routines, together with the related X(batch_add_state)(), show most clearly how the different memory blocks are related.

batch_init_empty() allocates the arrays states (of type batch_state_t) and states_linear (of type batch_states_l_t), and nullifies the pointers within these types. Note that no memory for the actual wave functions has been allocated at this point.
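For spinors (dim = 2), each state has two components, and a "linear" view of a batch treats every component as a separate function, so a batch of nst states has nst*dim linear entries. The Python sketch below shows one such mapping between the pair (ist, idim) and a linear index. The ordering, with the component index running fastest, is our assumption about the Octopus convention, and the function names are hypothetical.

```python
# Sketch: mapping between (state, spin component) and a "linear" index in
# which each spinor component counts as a separate function.
# Assumption: the component index idim runs fastest. Indices are 1-based.


def ist_idim_to_linear(ist: int, idim: int, dim: int) -> int:
    """Linear index of component idim of the ist-th state in a batch."""
    return (ist - 1) * dim + idim


def linear_to_ist_idim(ilinear: int, dim: int):
    """Inverse mapping: linear index -> (ist, idim)."""
    ist0, idim0 = divmod(ilinear - 1, dim)
    return ist0 + 1, idim0 + 1


if __name__ == "__main__":
    nst, dim = 3, 2
    for ilinear in range(1, nst * dim + 1):
        print(ilinear, linear_to_ist_idim(ilinear, dim))
```

For spinless states (dim = 1) the two indices coincide, so the distinction between the two views only matters for spinors.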

Questions:

How are the different objects pointing to states related?

What is the difference between batch_add_state and batch_add_state_linear?