Octopus
isdf_serial.F90
Go to the documentation of this file.
1!! Copyright (C) 2024 - 2025. Alexander Buccheri
2!!
3!! This program is free software; you can redistribute it and/or modify
4!! it under the terms of the GNU General Public License as published by
5!! the Free Software Foundation; either version 2, or (at your option)
6!! any later version.
7!!
8!! This program is distributed in the hope that it will be useful,
9!! but WITHOUT ANY WARRANTY; without even the implied warranty of
10!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11!! GNU General Public License for more details.
12!!
13!! You should have received a copy of the GNU General Public License
14!! along with this program; if not, write to the Free Software
15!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
16!! 02110-1301, USA.
17
18#include "global.h"
19
22 use, intrinsic :: iso_fortran_env, only: real64, int64
23 use batch_oct_m
25 use blas_oct_m
26 use debug_oct_m
30 use global_oct_m
31 use grid_oct_m
32 use ions_oct_m
40 use math_oct_m
41 use mesh_oct_m
44 use mpi_oct_m, only: mpi_world
51 use space_oct_m
54 use xc_cam_oct_m, only: xc_cam_t
55
implicit none
private

! Public API of the serial ISDF reference implementation.
! The entity list was missing after the continuation marker; these are the three
! procedures this file's own cross-reference marks as "subroutine, public".
public ::                               &
  isdf_serial_interpolation_vectors,    &
  isdf_serial_ace_compute_potentials,   &
  quantify_error_and_visualise

! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised systems
!> Fixed index used for st%kweights(ik), the second index of st%occ(:, ik) and the
!! k-point/spin argument of states_elec_get_state. Hard-coded to 1 until the TODO is resolved.
integer, parameter :: ik = 1
67
68contains
69
!> @brief Construct the ISDF interpolation vectors and the intermediate quantities required downstream.
!!
!! Steps, in the order they are executed below:
!!  1. Collate the first isdf%n_ks_states states into phi, shape (n_states, np).
!!  2. Sample phi at the interpolation points, giving phi_mu, shape (n_states, n_int).
!!  3. Build the quasi-density matrix P_r_mu = phi^T phi_mu, shape (np, n_int).
!!  4. Square P_r_mu element-wise in place, giving ZC^T.
!!  5. Sample the rows of ZC^T at the interpolation points, giving CC^T, shape (n_int, n_int).
!!  6. Invert CC^T with an SVD-based inverse and symmetrise the result.
!!  7. isdf_vectors = [ZC^T] [CC^T]^{-1}, shape (np, n_int).
!!  8. Rebuild P_r_mu with the occupations absorbed into it.
!!
!! The routine aborts (messages_fatal / messages_not_implemented) when run state- or
!! domain-parallel, for nspin > 1, for complex states, or when the first batch is not packed.
!!
!! @param[in]  isdf          ISDF options; n_ks_states and check_n_interp are read here.
!! @param[in]  namespace     Namespace used for messages and debug output files.
!! @param[in]  mesh          Mesh; mesh%np sets the number of grid points.
!! @param[in]  st            Electronic states.
!! @param[in]  indices       Grid indices of the interpolation points, size n_int.
!! @param[out] phi_mu        States at the interpolation points, \f$ \varphi_i(\mathbf{r}_\mu) \f$,
!!                           with size (n_states, n_int).
!! @param[out] P_r_mu        \f$P_{\mathbf{r},\mu}\f$ with occupations absorbed, with size (np, n_int).
!! @param[out] isdf_vectors  Interpolation vectors, with size (np, n_int).
subroutine isdf_serial_interpolation_vectors(isdf, namespace, mesh, st, indices, phi_mu, P_r_mu, isdf_vectors)
  type(isdf_options_t),      intent(in ) :: isdf
  type(namespace_t),         intent(in ) :: namespace
  class(mesh_t),             intent(in ) :: mesh
  type(states_elec_t),       intent(in ) :: st
  integer(int64), contiguous, intent(in ) :: indices(:)

  real(real64), allocatable, intent(out) :: phi_mu(:, :)
  real(real64), allocatable, intent(out) :: P_r_mu(:, :)
  real(real64), allocatable, intent(out) :: isdf_vectors(:, :)

  real(real64), allocatable :: phi(:, :), cct(:, :)
  integer :: n_states, n_int, rank
  logical :: data_is_packed

  push_sub_with_profile(isdf_serial_interpolation_vectors)
  ! NOTE(review): the purpose of this call is not evident from this file - confirm it is intended.
  call messages_write(1)

  ! Reference serial implementation not parallel in states or domain
  if (st%parallel_in_states .or. mesh%parallel_in_domains) then
    message(1) = "Serial ISDF called when running state or domain-parallel"
    call messages_fatal(1)
  endif

  ! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised systems
  if (st%d%nspin > 1) then
    call messages_not_implemented("ISDF Serial for SPIN_POLARIZED and SPINOR calculations", namespace)
  endif

  ! TODO(Alex) Issue #1196 Template ISDF handle both real and complex states
  if (.not. states_are_real(st)) then
    call messages_not_implemented("ISDF Serial handling of complex states", namespace)
  endif

  ! TODO(Alex) Implement algorithms for unpacked data structure
  ! Only the status of the first local block is inspected.
  data_is_packed = st%group%psib(st%group%block_start, 1)%status() == batch_packed

  if (.not. data_is_packed) then
    message(1) = "Serial ISDF only implemented for BATCH_PACKED"
    call messages_fatal(1)
  endif

  n_states = isdf%n_ks_states
  n_int = size(indices)

  call collate_batches_get_state(mesh, st, n_states, phi)
  if (debug%info) call output_matrix(namespace, "phi_r_serial.txt", phi)

  safe_allocate(phi_mu(1:n_states, 1:n_int))
  call sample_phi_at_centroids(phi, indices, phi_mu)
  if (debug%info) call output_matrix(namespace, "phi_mu_serial.txt", phi_mu)

  safe_allocate(p_r_mu(1:mesh%np, 1:n_int))
  call construct_density_matrix_packed(phi, phi_mu, p_r_mu)
  if (debug%info) call output_matrix(namespace, "p_r_mu_serial.txt", p_r_mu)

  ! Mutate P_r_mu to [ZC^T] = P_r_mu o P_r_mu
  call construct_zct(p_r_mu)
  if (debug%info) call output_matrix(namespace, "zct_serial.txt", p_r_mu)

  ! [CC^T] = ZC^T[indices, :]
  safe_allocate(cct(1:n_int, 1:n_int))
  call construct_cct(indices, p_r_mu, cct)
  if (debug%info) call output_matrix(namespace, "cct_serial.txt", cct)
  assert(is_symmetric(cct))

  ! Note, in principle one could use lalg_pseudo_inverse and just add an optional
  ! arg for returning the rank, but annoyingly the criterion in that is > rather than >=.
  ! Quantify the optimal number of interpolation points to use.
  ! If the rank of CC^T is <= ISDFNpoints, this does not give any indication how many
  ! additional points are required. It is only indicative when oversampling.
  if (isdf%check_n_interp) then
    rank = lalg_matrix_rank_svd(cct, preserve_mat=.true.)
    ! I0 rather than a fixed width, so a rank above 9999 is not printed as '****'
    write(message(1),'(a, I0)') "ISDF Serial: Rank of CC^T is ", rank
    if (rank < n_int) then
      write(message(2),'(a)') " - This rank is the optimal ISDFNpoints to run the calculation with"
    else
      write(message(2),'(a)') " - This suggests that ISDFNpoints is either optimal, or could be larger"
    endif
    call messages_info(2, namespace=namespace)
  endif

  ! [CC^T]^{-1}, mutating cct in-place
  ! NOTE, if the number of interpolation points exceeds the rank of CC^T, CC^T is by definition
  ! ill-conditioned, and requires inverting with the pseudo-inverse (SVD).
  ! Tests show that this is a much better solution than regularisation of either the diagonals
  ! or off-diagonals of CC^T. If one limits the number of interpolation points, there is no problem
  ! with inversion, but this limits the total accuracy achievable with the method.

  ! As CC^T and its inverse should be symmetric, ultimately want to:
  ! * Only compute the inverse of the upper (or lower) triangle
  ! * Modify the GEMM operation below to only use the upper (or lower) triangle of [CC^T]^{-1}
  write(message(1),'(a)') "ISDF Serial: Inverting [CC^T]"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Invert [CC^T] and symmetrise
  call lalg_svd_inverse(n_int, n_int, cct)
  call symmetrize_matrix(n_int, cct)

  ! Compute interpolation vectors, [ZC^T] [CC^T]^{-1}
  safe_allocate(isdf_vectors(1:mesh%np, 1:n_int))
  ! CC^T is by definition symmetric, implying [CC^T]^{-1} also is
  call lalg_gemm(mesh%np, n_int, n_int, 1.0_real64, p_r_mu, cct, 0.0_real64, isdf_vectors)
  if (debug%info) call output_matrix(namespace, "isdf_serial.txt", isdf_vectors)
  safe_deallocate_a(cct)

  ! Rebuild P_r_mu, with occupation numbers absorbed into it. Used in construction of W_ace
  call construct_density_matrix_with_occ_packed(st, phi, phi_mu, p_r_mu)
  safe_deallocate_a(phi)
  if (debug%info) call output_matrix(namespace, "OccP_r_mu_serial.txt", p_r_mu)

  pop_sub_with_profile(isdf_serial_interpolation_vectors)

end subroutine isdf_serial_interpolation_vectors
189
190
!> @brief Gather states 1..max_state into a single 2D array of shape (max_state, mesh%np).
!!
!! Blocks are visited from block 1 up to the block containing max_state, and the states in
!! each block are copied one at a time with states_elec_get_state. Looping per block makes
!! applying the max_state cut-off straightforward.
!!
!! @param[in]  mesh       Mesh; mesh%np sets the second dimension of psi.
!! @param[in]  st         Electronic states to read from.
!! @param[in]  max_state  Index of the last state to copy; asserted to lie in [1, st%nst].
!! @param[out] psi        Allocated here with shape (max_state, mesh%np).
subroutine collate_batches_get_state(mesh, st, max_state, psi)
  class(mesh_t),             intent(in ) :: mesh
  type(states_elec_t),       intent(in ) :: st
  integer,                   intent(in ) :: max_state
  real(real64), allocatable, intent(out) :: psi(:, :)

  integer :: row, jb, ist, first_state, last_state, last_block

  push_sub_with_profile(collate_batches_get_state)

  assert(max_state > 0 .and. max_state <= st%nst)

  safe_allocate(psi(1:max_state, 1:mesh%np))

  ! Block that holds the last requested state
  last_block = st%group%iblock(max_state)

  ! row counts the states copied so far, i.e. the next free row of psi
  row = 0
  do jb = 1, last_block
    ! Normalisation did not affect condition number of CC^T matrix
    !call dmesh_batch_normalize(mesh, st%group%psib(jb, ik))
    first_state = states_elec_block_min(st, jb)
    ! Truncate the final block at max_state
    last_state = min(states_elec_block_max(st, jb), max_state)
    do ist = first_state, last_state
      row = row + 1
      call states_elec_get_state(st, mesh, st%d%dim, ist, ik, psi(row, :))
    end do
  end do

  pop_sub_with_profile(collate_batches_get_state)

end subroutine collate_batches_get_state
224
225
!> @brief Sample KS states at the centroid (interpolation) points.
!!
!! For every interpolation point ic, column indices(ic) of phi_r is copied into
!! column ic of phi_mu.
!!
!! @param[in]  phi_r    States on the grid, with shape (nst, number of grid points).
!! @param[in]  indices  Grid index of each interpolation point; every entry must be a valid
!!                      second-dimension index of phi_r (asserted).
!! @param[out] phi_mu   States at the interpolation points, with shape (nst, size(indices)).
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
  real(real64),   contiguous, intent(in ) :: phi_r(:, :)
  integer(int64), contiguous, intent(in ) :: indices(:)
  real(real64),   contiguous, intent(out) :: phi_mu(:, :)

  integer :: ic, nst, n_int

  push_sub_with_profile(sample_phi_at_centroids)

  write(message(1),'(a)') "ISDF Serial: Sampling phi(r) at mu"
  call messages_info(1, debug_only=.true.)

  nst = size(phi_r, 1)
  assert(size(phi_mu, 1) == nst)

  n_int = size(indices)
  assert(size(phi_mu, 2) == n_int)

  ! An out-of-range index would otherwise silently read outside phi_r.
  ! all() of an empty array is true, so n_int == 0 passes.
  assert(all(indices >= 1_int64))
  assert(all(indices <= size(phi_r, 2, kind=int64)))

  ! Gather one contiguous column (all states at one grid point) per interpolation point
  do ic = 1, n_int
    phi_mu(1:nst, ic) = phi_r(1:nst, indices(ic))
  end do

  pop_sub_with_profile(sample_phi_at_centroids)

end subroutine sample_phi_at_centroids
256
257
!> @brief Form ZC^T in place as the element-wise square of the quasi-density matrix.
!!
!! @param[inout] zct  In: quasi-density matrix.
!!                    Out: contraction of the Z and C^T matrices, i.e. every element squared.
subroutine construct_zct(zct)
  real(real64), contiguous, intent(inout) :: zct(:, :)

  integer :: irow, icol

  push_sub_with_profile(construct_zct)

  write(message(1),'(a)') "ISDF Serial: Constructing ZC^T"
  call messages_info(1, debug_only=.true.)

  ! Every element is independent, so both loops are collapsed into one parallel iteration space.
  ! The first index is innermost, matching column-major storage.
  !$omp parallel do collapse(2)
  do icol = 1, size(zct, 2)
    do irow = 1, size(zct, 1)
      zct(irow, icol) = zct(irow, icol)**2
    end do
  end do
  !$omp end parallel do

  pop_sub_with_profile(construct_zct)

end subroutine construct_zct
290
291
!> @brief Construct CC^T by sampling the rows of ZC^T at the interpolation points.
!!
!! cct(mu, nu) = zct(indices(mu), nu), i.e. the first dimension of ZC^T is masked.
!!
!! @param[in]  indices  Grid index of each interpolation point, size n_int.
!! @param[in]  zct      ZC^T; asserted to have more than n_int rows and exactly n_int columns.
!! @param[out] cct      CC^T; asserted to have shape (n_int, n_int).
subroutine construct_cct(indices, zct, cct)
  integer(int64), contiguous, intent(in ) :: indices(:)
  real(real64),   contiguous, intent(in ) :: zct(:, :)

  real(real64),   contiguous, intent(out) :: cct(:, :)

  integer :: icol, n_int

  push_sub_with_profile(construct_cct)

  write(message(1),'(a)') "ISDF Serial: Constructing CC^T by sampling ZC^T"
  call messages_info(1, debug_only=.true.)

  n_int = size(indices)
  assert(all(shape(cct) == [n_int, n_int]))
  assert(size(zct, 1) > n_int)
  assert(size(zct, 2) == n_int)

  ! Mask ZC^T to obtain CC^T: gather the sampled rows of each column with a vector subscript
  do icol = 1, n_int
    cct(1:n_int, icol) = zct(indices, icol)
  end do

  pop_sub_with_profile(construct_cct)

end subroutine construct_cct
324
325
!> @brief Construct the quasi-density matrix P_r_mu = phi^T phi_mu, with shape (np, n_int).
!!
!! "Packed" refers to the expected layout of phi, with the state index first.
!!
!! @param[in]  phi     States on the grid, of shape (m_states, np).
!! @param[in]  phi_mu  States at the interpolation points, of shape (m_states, n_int).
!! @param[out] P_r_mu  Result; asserted to have shape (np, n_int).
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
  real(real64), contiguous, intent(in ) :: phi(:, :)
  real(real64), contiguous, intent(in ) :: phi_mu(:, :)

  real(real64), contiguous, intent(out) :: P_r_mu(:, :)

  integer :: np
  integer :: n_int
  integer :: m_states

  push_sub_with_profile(construct_density_matrix_packed)

  write(message(1),'(a)') "ISDF Serial: Constructing P_r_mu"
  call messages_info(1, debug_only=.true.)

  m_states = size(phi, 1)
  np = size(phi, 2)
  n_int = size(phi_mu, 2)

  assert(size(phi_mu, 1) == m_states)
  assert(size(p_r_mu, 1) == np)
  assert(size(p_r_mu, 2) == n_int)

  ! Contract over the state index, P = phi^T @ phi_mu, with shape (np, m_states) (m_states, n_int)
  call lalg_gemm(phi, phi_mu, p_r_mu, transa='T')

  pop_sub_with_profile(construct_density_matrix_packed)

end subroutine construct_density_matrix_packed
364
365
!> @brief Construct the quasi-density matrix with occupations and k-weight absorbed into it.
!!
!! P_r_mu = phi^T (f o phi_mu), where row ist of phi_mu is scaled by
!! st%kweights(ik) * st%occ(ist, ik) before the contraction over states.
!!
!! @param[in]  st      Electronic states; only st%kweights and st%occ are read.
!! @param[in]  phi     States on the grid, of shape (m_states, np).
!! @param[in]  phi_mu  States at the interpolation points, of shape (m_states, n_int).
!! @param[out] P_r_mu  Result; asserted to have shape (np, n_int).
subroutine construct_density_matrix_with_occ_packed(st, phi, phi_mu, P_r_mu)
  type(states_elec_t), intent(in ) :: st
  real(real64),        intent(in ) :: phi(:, :)
  real(real64),        intent(in ) :: phi_mu(:, :)

  real(real64),        intent(out) :: P_r_mu(:, :)

  integer :: np, n_int, m_states, imu, ist
  real(real64), allocatable :: focc_phi_mu(:, :)

  ! Profiling/debug label now matches the routine name, as in every other routine of this file
  push_sub_with_profile(construct_density_matrix_with_occ_packed)

  write(message(1),'(a)') "ISDF Serial: Constructing P_r_mu with occupations"
  call messages_info(1, debug_only=.true.)

  m_states = size(phi_mu, 1)
  np = size(phi, 2)
  n_int = size(phi_mu, 2)

  assert(size(phi, 1) == m_states)
  assert(size(p_r_mu, 1) == np)
  assert(size(p_r_mu, 2) == n_int)

  ! Element-wise multiply the occupations with the smaller of the two arrays
  safe_allocate(focc_phi_mu(1:m_states, 1:n_int))
  do imu = 1, n_int
    do ist = 1, m_states
      focc_phi_mu(ist, imu) = st%kweights(ik) * st%occ(ist, ik) * phi_mu(ist, imu)
    enddo
  enddo

  ! Contract over the state index, P = phi^T @ focc_phi_mu, with shape (np, m_states) (m_states, n_int)
  call lalg_gemm(phi, focc_phi_mu, p_r_mu, transa='T')
  safe_deallocate_a(focc_phi_mu)

  pop_sub_with_profile(construct_density_matrix_with_occ_packed)

end subroutine construct_density_matrix_with_occ_packed
405
406
!> @brief Wrapper quantifying the error in the ISDF expansion of the pair-product basis.
!!
!! Builds the exact pair products from the states, builds the approximate pair products from
!! the interpolation vectors, optionally writes both sets as cube files, and writes the
!! per-pair and mean errors to "isdf_error_serial.txt" (root process only).
!!
!! @param[in]    isdf          ISDF options; n_ks_states sets the number of states used.
!! @param[in]    namespace     Namespace for messages and cube output.
!! @param[in]    st            Electronic states.
!! @param[in]    space         Space, forwarded to the cube output.
!! @param[in]    mesh          Mesh.
!! @param[in]    ions          Ions, forwarded to the cube output.
!! @param[in]    indices       Grid indices of the interpolation points.
!! @param[inout] isdf_vectors  Interpolation vectors. NOTE: deallocated by this routine once
!!                             the approximate products have been built.
!! @param[in]    output_cubes  If true, write exact and approximate products as cube files.
subroutine quantify_error_and_visualise(isdf, namespace, st, space, mesh, ions, indices, isdf_vectors, output_cubes)
  type(isdf_options_t),       intent(in)    :: isdf
  type(namespace_t),          intent(in)    :: namespace
  type(states_elec_t),        intent(in)    :: st
  class(space_t),             intent(in)    :: space
  class(mesh_t),              intent(in)    :: mesh
  class(ions_t), pointer,     intent(in)    :: ions
  integer(int64), contiguous, intent(in)    :: indices(:)
  real(real64), allocatable,  intent(inout) :: isdf_vectors(:, :)
  logical,                    intent(in)    :: output_cubes

  real(real64), allocatable :: exact_products(:, :), approx_products(:, :)
  real(real64), allocatable :: phi(:, :), phi_mu(:, :)
  real(real64), allocatable :: pair_error(:)
  integer      :: n_occ, n_products, n_int, ist, jst, ipair, io_unit
  real(real64) :: mean_error

  push_sub_with_profile(quantify_error_and_visualise)

  write(message(1),'(a)') "ISDF Serial: Computing exact pair products"
  call messages_info(1, debug_only=.true.)

  ! Rebuild phi matrix
  n_occ = isdf%n_ks_states
  call collate_batches_get_state(mesh, st, n_occ, phi)
  assert(size(phi, 2) == mesh%np)

  ! Exact products: one row per (i, j) pair of states
  n_products = n_occ * n_occ
  safe_allocate(exact_products(1:n_products, 1:mesh%np))
  call column_wise_khatri_rao_product(phi, phi, exact_products)

  if (output_cubes) then
    call generate_product_state_cubes(namespace, space, mesh, ions, "exact_product_", &
      exact_products)
  end if

  write(message(1),'(a)') "ISDF Serial Test: Computing approximate pair products"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Rebuild phi_mu, again only for occupied states
  n_int = size(indices)
  safe_allocate(phi_mu(1:n_occ, 1:n_int))
  call sample_phi_at_centroids(phi, indices, phi_mu)
  safe_deallocate_a(phi)

  safe_allocate(approx_products(1:n_products, 1:mesh%np))
  call approximate_pair_products(phi_mu, isdf_vectors, approx_products)
  ! if (debug%info) call output_matrix(namespace, "approx_product_blas.txt", approx_products)

  ! Inputs of the approximation are no longer needed; this releases the caller's isdf_vectors
  safe_deallocate_a(phi_mu)
  safe_deallocate_a(isdf_vectors)

  if (output_cubes) then
    call generate_product_state_cubes(namespace, space, mesh, ions, "approx_product_", &
      approx_products)
  end if

  ! Quantify the error
  safe_allocate(pair_error(1:n_products))
  call error_in_product_basis(mesh, exact_products, approx_products, pair_error, mean_error)
  safe_deallocate_a(exact_products)
  safe_deallocate_a(approx_products)

  if (mpi_world%is_root()) then
    open(newunit=io_unit, file="isdf_error_serial.txt")
    write(io_unit, *) 'Mean error', mean_error
    ! Pair counter runs with the second state index fastest
    ipair = 0
    do ist = 1, n_occ
      do jst = 1, n_occ
        ipair = ipair + 1
        write(io_unit, *) ist, jst, pair_error(ipair)
      end do
    end do
    close(io_unit)
  end if

  safe_deallocate_a(pair_error)

  pop_sub_with_profile(quantify_error_and_visualise)

end subroutine quantify_error_and_visualise
489
490
!> @brief Construct approximate pair products from the ISDF interpolation vectors.
!!
!! product_basis = [psi_mu (column-wise Khatri-Rao) psi_mu] [zeta]^T, contracting over the
!! interpolation-point index.
!!
!! @param[in]  psi_mu         States at the interpolation points; first dimension is the state index.
!! @param[in]  zeta           Interpolation vectors, of shape (np, n_int).
!! @param[out] product_basis  Approximate products; asserted to have shape (size(psi_mu, 1)**2, np).
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
  real(real64), contiguous, intent(in ) :: psi_mu(:, :)
  real(real64), contiguous, intent(in ) :: zeta(:, :)
  real(real64), contiguous, intent(out) :: product_basis(:, :)

  real(real64), allocatable :: pair_at_mu(:, :)
  integer :: n_pairs, n_points, n_grid

  push_sub_with_profile(approximate_pair_products)

  n_pairs = size(psi_mu, 1)**2
  n_grid = size(zeta, 1)
  n_points = size(zeta, 2)

  assert(size(product_basis, 1) == n_pairs)
  assert(size(product_basis, 2) == n_grid)

  ! Pair products evaluated at the interpolation points only
  safe_allocate(pair_at_mu(1:n_pairs, 1:n_points))
  call column_wise_khatri_rao_product(psi_mu, psi_mu, pair_at_mu)

  ! Contract product_basis = [pair_at_mu] [zeta]^T over interpolation vector index
  call lalg_gemm(pair_at_mu, zeta, product_basis, transb='T')

  safe_deallocate_a(pair_at_mu)

  pop_sub_with_profile(approximate_pair_products)

end subroutine approximate_pair_products
529
530
!> @brief Quantify the error in the product basis expansion.
!!
!! For each pair ij:
!!   error(ij) = sqrt( mesh%volume_element * sum_ip (product_basis(ij, ip) - approx_product_basis(ij, ip))**2 )
!! and mean_error is the arithmetic mean of error(:).
!!
!! Degenerate inputs are handled explicitly: with zero grid points every error is zero, and
!! with zero pairs mean_error is returned as zero instead of evaluating 0/0.
!!
!! @param[in]  mesh                  Mesh; np and volume_element are read.
!! @param[in]  product_basis         Exact products, of shape (mn_states, np).
!! @param[in]  approx_product_basis  Approximate products, same shape as product_basis.
!! @param[out] error                 Per-pair error, size mn_states.
!! @param[out] mean_error            Mean of error over all pairs.
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
  class(mesh_t),            intent(in ) :: mesh
  real(real64), contiguous, intent(in ) :: product_basis(:, :)
  real(real64), contiguous, intent(in ) :: approx_product_basis(:, :)

  real(real64), contiguous, intent(out) :: error(:)
  real(real64),             intent(out) :: mean_error

  integer :: mn_states, np, ij, ip

  push_sub_with_profile(error_in_product_basis)

  mn_states = size(product_basis, 1)
  np = size(product_basis, 2)

  ! product_basis shape is not as expected
  assert(mesh%np == np)

  ! Two arrays should be the same dimensions
  assert(all(shape(product_basis) == shape(approx_product_basis)))

  ! error should be allocated, and with the correct size
  assert(size(error) == mn_states)

  ! Accumulate from zero over all grid points, rather than seeding from grid point 1, so that
  ! an empty grid does not read out of bounds. Adding to an exact zero leaves the sums unchanged.
  error(1:mn_states) = 0.0_real64

  ! Grid point outermost: the inner loop runs over the contiguous first index
  do ip = 1, np
    do ij = 1, mn_states
      error(ij) = error(ij) + (product_basis(ij, ip) - approx_product_basis(ij, ip))**2
    enddo
  enddo

  mean_error = 0.0_real64
  do ij = 1, mn_states
    error(ij) = sqrt(mesh%volume_element * error(ij))
    mean_error = mean_error + error(ij)
  enddo

  ! Avoid 0/0 when there are no pair products
  if (mn_states > 0) mean_error = mean_error / real(mn_states, real64)

  pop_sub_with_profile(error_in_product_basis)

end subroutine error_in_product_basis
584
585
!> @brief Helper to output a set of pair-product states as cube files under "./cubes".
!!
!! Row j + (i - 1) * m_states of data is written to the file "<file_prefix><i>_<j>", where
!! m_states = int(sqrt(size(data, 1))).
!!
!! @param[in] namespace    Namespace forwarded to the output routine.
!! @param[in] space        Space forwarded to the output routine.
!! @param[in] mesh         Mesh forwarded to the output routine.
!! @param[in] ions         Ions; positions and atoms are forwarded to the output routine.
!! @param[in] file_prefix  Prefix of every output file name.
!! @param[in] data         Product states; first dimension is the pair index, of size m_states**2.
!! @param[in] limits       Optional upper bounds: limits(1) for j (inner), limits(2) for i (outer).
!!                         Both default to m_states.
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
  type(namespace_t),        intent(in) :: namespace
  class(space_t),           intent(in) :: space
  class(mesh_t),            intent(in) :: mesh
  class(ions_t), pointer,   intent(in) :: ions
  character(len=*),         intent(in) :: file_prefix
  real(real64), contiguous, intent(in) :: data(:, :)
  integer, optional,        intent(in) :: limits(2)

  integer            :: m_states, n_inner, n_outer, ist, jst, ipair, ierr
  character(len=4)   :: ist_label, jst_label
  character(len=120) :: cube_name

  ! product basis size is currently defined as (m_states * m_states)
  m_states = int(sqrt(real(size(data, 1), real64)))

  ! Default to all pairs, unless the caller restricts the ranges
  n_inner = m_states
  n_outer = m_states
  if (present(limits)) then
    n_inner = limits(1)
    n_outer = limits(2)
  end if

  do ist = 1, n_outer
    write(ist_label, '(I4)') ist
    do jst = 1, n_inner
      ! Row of data holding the (ist, jst) product
      ipair = jst + (ist - 1) * m_states
      write(jst_label, '(I4)') jst
      cube_name = trim(adjustl(file_prefix)) // trim(adjustl(ist_label)) // '_' // trim(adjustl(jst_label))
      ! NOTE(review): ierr returned by the output routine is not inspected.
      call dio_function_output(option__outputformat__cube, "./cubes", trim(adjustl(cube_name)), namespace, space, mesh, &
        data(ipair, :), unit_one, ierr, pos=ions%pos, atoms=ions%atom)
    end do
  end do

end subroutine generate_product_state_cubes
625
626
!> @brief ISDF wrapper computing the interpolation vectors and using them to build the
!! exchange potential applied to the states, W, returned as a set of batches.
!!
!! Sequence: fetch the interpolation point indices from exxop%isdf%centroids, build the
!! interpolation vectors, build W_ace as a bare array, then copy it into Vx_on_st.
!!
!! @param[in]  exxop      Exchange operator, supplying the ISDF instance (interpolation points),
!!                        the Poisson solver and the cam parameters. An ISDF instance is not
!!                        passed directly so this API is consistent with the other
!!                        "compute_potential" routines.
!! @param[in]  namespace  Namespace for messages and debug output.
!! @param[in]  space      Space; only used here to assert the system is not periodic.
!! @param[in]  mesh       Mesh.
!! @param[in]  st         Electronic states.
!! @param[out] Vx_on_st   W_ace stored with the memory layout of st.
!! @param[in]  kpoints    k-points; only used here to assert a gamma-only calculation.
subroutine isdf_serial_ace_compute_potentials(exxop, namespace, space, mesh, st, Vx_on_st, kpoints)
  type(exchange_operator_t), intent(in ) :: exxop
  type(namespace_t),         intent(in ) :: namespace
  class(space_t),            intent(in ) :: space
  class(mesh_t),             intent(in ) :: mesh
  type(states_elec_t),       intent(in ) :: st
  type(kpoints_t),           intent(in ) :: kpoints

  type(states_elec_t),       intent(out) :: Vx_on_st

  real(real64), allocatable :: psi_mu(:, :), P_r_mu(:, :), W_ace(:, :), isdf_vectors(:, :)
  integer(int64), allocatable :: indices(:)

  ! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised and periodic systems
  assert(kpoints%gamma_only())
  assert(.not. space%is_periodic())
  assert(st%d%nspin == 1)

  ! Allocated on assignment
  indices = exxop%isdf%centroids%global_mesh_indices()

  call isdf_serial_interpolation_vectors(exxop%isdf, namespace, mesh, st, &
    indices, psi_mu, p_r_mu, isdf_vectors)

  safe_allocate(w_ace(1:mesh%np, exxop%isdf%n_ks_states))
  call isdf_serial_ace_w_unpacked(namespace, p_r_mu, isdf_vectors, psi_mu, exxop%psolver, exxop%cam, st, w_ace)
  safe_deallocate_a(psi_mu)
  safe_deallocate_a(p_r_mu)
  safe_deallocate_a(isdf_vectors)

  call isdf_serial_ace_batch_w(exxop%isdf, st, w_ace, vx_on_st)
  safe_deallocate_a(w_ace)

end subroutine isdf_serial_ace_compute_potentials
679
681
!> @brief Compute the action of the exchange potential on KS states for
!! adaptively-compressed exchange, as a bare array.
!!
!! With V(:, mu) the Poisson solution for interpolation vector mu, the result is
!!   W_ace(r, i) = - weight * sum_mu P_r_mu(r, mu) * V(r, mu) * psi_mu(i, mu)
!! where weight = cam%alpha / st%smear%el_per_state.
!!
!! Aborts with messages_fatal if an external kernel would be required
!! (st%nik > st%d%spin_channels, or cam%omega > m_epsilon).
!!
!! @param[in]  namespace       Namespace for messages and the Poisson solver.
!! @param[in]  P_r_mu          Quasi-density matrix, of shape (np, n_int).
!! @param[in]  isdf_vectors    Interpolation vectors; asserted to have the shape of P_r_mu.
!! @param[in]  psi_mu          States at the interpolation points, of shape (nst, n_int).
!! @param[in]  poisson_solver  Poisson solver.
!! @param[in]  cam             CAM parameters; alpha and omega are read.
!! @param[in]  st              Electronic states; nik, d%spin_channels and smear%el_per_state are read.
!! @param[out] W_ace           Result; asserted to have shape (np, nst).
subroutine isdf_serial_ace_w_unpacked(namespace, P_r_mu, isdf_vectors, psi_mu, poisson_solver, cam, st, W_ace)
  type(namespace_t),   intent(in ) :: namespace
  real(real64),        intent(in ), contiguous :: P_r_mu(:, :)
  real(real64),        intent(in ), contiguous :: isdf_vectors(:, :)
  real(real64),        intent(in ), contiguous :: psi_mu(:, :)
  type(poisson_t),     intent(in ) :: poisson_solver
  type(xc_cam_t),      intent(in ) :: cam
  type(states_elec_t), intent(in ) :: st

  real(real64),        intent(out), contiguous :: W_ace(:, :)

  integer      :: i_mu, ist, np, n_int, nst
  real(real64) :: scaled_psi, weight
  real(real64), allocatable :: V_r_nu(:, :)

  push_sub_with_profile(isdf_serial_ace_w_unpacked)

  ! Number of states defines the number used in ISDF, which is typically ~ N occupied states
  nst = size(psi_mu, 1)
  np = size(p_r_mu, 1)
  n_int = size(p_r_mu, 2)

  assert(all(shape(p_r_mu) == shape(isdf_vectors)))
  assert(size(psi_mu, 2) == n_int)
  assert(size(w_ace, 1) == np)
  ! Implies a size issue with either W_ace or psi_mu
  assert(size(w_ace, 2) == nst)

  ! Cases that would need an external kernel are rejected
  if (st%nik > st%d%spin_channels .or. cam%omega > m_epsilon) then
    message(1) = "External kernel not supported in ISDF"
    call messages_fatal(1)
  end if
  weight = cam%alpha / st%smear%el_per_state

  ! One Poisson solve per interpolation vector
  safe_allocate(v_r_nu(1:np, 1:n_int))
  call profiling_in('isdf_potential')
  do i_mu = 1, n_int
    call dpoisson_solve(poisson_solver, namespace, v_r_nu(:, i_mu), isdf_vectors(:, i_mu), all_nodes=.false.)
  end do
  call profiling_out("isdf_potential")

  write(message(1),'(a)') "ISDF: Writing V from isdf_ace_w_unpacked"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Initialise elements of W_ace with data from the first interpolation point
  i_mu = 1
  do ist = 1, nst
    scaled_psi = weight * psi_mu(ist, i_mu)
    w_ace(1:np, ist) = - (p_r_mu(1:np, i_mu) * v_r_nu(1:np, i_mu) * scaled_psi)
  end do

  ! Construct W_ace: accumulate the remaining interpolation points
  do i_mu = 2, n_int
    do ist = 1, nst
      scaled_psi = weight * psi_mu(ist, i_mu)
      w_ace(1:np, ist) = w_ace(1:np, ist) - (p_r_mu(1:np, i_mu) * v_r_nu(1:np, i_mu) * scaled_psi)
    end do
  end do

  safe_deallocate_a(v_r_nu)

  pop_sub_with_profile(isdf_serial_ace_w_unpacked)

end subroutine isdf_serial_ace_w_unpacked
763
764
!> @brief Put the bare array representation of W into batches.
!!
!! Vx_on_st is created as a copy of st and zeroed. Column ist of W_ace is then written to
!! the matching state of the matching block, for the local blocks up to the one containing
!! min(isdf%n_ks_states, st%st_end).
!!
!! @param[in]  isdf      ISDF options; n_ks_states is read.
!! @param[in]  st        Electronic states providing the layout; asserted to have d%dim == 1.
!! @param[in]  W_ace     W as a bare array; second dimension asserted to equal isdf%n_ks_states.
!! @param[out] Vx_on_st  States object holding W.
subroutine isdf_serial_ace_batch_w(isdf, st, W_ace, Vx_on_st)
  type(isdf_options_t), intent(in ) :: isdf
  type(states_elec_t),  intent(in ) :: st
  real(real64),         intent(in ), contiguous :: w_ace(:, :)

  type(states_elec_t),  intent(out) :: vx_on_st

  integer :: ist, ib, np, last_state, first_in_block, last_in_block, last_block

  push_sub_with_profile(isdf_serial_ace_batch_w)

  assert(size(w_ace, 2) == isdf%n_ks_states)
  assert(st%d%dim == 1)

  ! Copy memory layout, without data
  call states_elec_copy(vx_on_st, st)
  call states_elec_set_zero(vx_on_st)
  np = size(w_ace, 1)

  ! Ensure we do not go beyond the total number of occupied states
  last_state = min(isdf%n_ks_states, st%st_end)
  last_block = min(st%group%block_end, st%group%iblock(last_state))

  do ib = st%group%block_start, last_block
    first_in_block = states_elec_block_min(st, ib)
    last_in_block = min(states_elec_block_max(st, ib), last_state)
    ! ist is the global state index; its position inside the block is ist - first_in_block + 1
    do ist = first_in_block, last_in_block
      call batch_set_state(vx_on_st%group%psib(ib, 1), ist - first_in_block + 1, np, w_ace(:, ist))
    end do
  end do

  pop_sub_with_profile(isdf_serial_ace_batch_w)

end subroutine isdf_serial_ace_batch_w
806
807end module isdf_serial_oct_m
808
809!! Local Variables:
810!! mode: f90
811!! coding: utf-8
812!! End:
There are several ways to call batch_set_state and batch_get_state:
Definition: batch_ops.F90:218
Matrix-matrix multiplication plus matrix.
Definition: lalg_basic.F90:229
double sqrt(double __x) __attribute__((__nothrow__
This module implements batches of mesh functions.
Definition: batch.F90:135
integer, parameter, public batch_packed
functions are stored in CPU memory, in transposed (packed) order
Definition: batch.F90:286
This module implements common operations on batches of mesh functions.
Definition: batch_ops.F90:118
This module contains interfaces for BLAS routines You should not use these routines directly....
Definition: blas.F90:120
type(debug_t), save, public debug
Definition: debug.F90:158
subroutine, public column_wise_khatri_rao_product(y, x, z)
Column-wise Kronecker product.
real(real64), parameter, public m_epsilon
Definition: global.F90:216
This module implements the underlying real-space grid.
Definition: grid.F90:119
subroutine, public dio_function_output(how, dir, fname, namespace, space, mesh, ff, unit, ierr, pos, atoms, grp, root)
Top-level IO routine for functions defined on the mesh.
Serial prototype for benchmarking and validating ISDF implementation.
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
Helper function to output a set of pair product states.
subroutine construct_density_matrix_with_occ_packed(st, phi, phi_mu, P_r_mu)
subroutine collate_batches_get_state(mesh, st, max_state, psi)
Loop over states per block, which makes applying the maximum state limit much simpler. Use this to com...
subroutine isdf_serial_ace_w_unpacked(namespace, P_r_mu, isdf_vectors, psi_mu, poisson_solver, cam, st, W_ace)
Compute the action of the exchange potential on KS states for adaptively-compressed exchange.
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
Sample KS states at centroid points.
subroutine, public isdf_serial_interpolation_vectors(isdf, namespace, mesh, st, indices, phi_mu, P_r_mu, isdf_vectors)
Construct interpolative separable density fitting (ISDF) vectors and other intermediate quantities re...
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
Quantify the error in the product basis expansion.
subroutine, public isdf_serial_ace_compute_potentials(exxop, namespace, space, mesh, st, Vx_on_st, kpoints)
ISDF wrapper computing interpolation points and vectors, which are used to build the potential used ...
subroutine construct_zct(zct)
Construct the product of Z and C matrices from the element-wise product of the quasi-density matrix.
subroutine construct_cct(indices, zct, cct)
Construct the product $CC^T$ from $ZC^T$ by masking the first dimension of $ZC^T$.
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
Construct the density matrix with shape (np, n_int). Denoted packed, because it expects phi i...
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
Construct a set of approximate pair products using the ISDF interpolation vectors.
subroutine, public quantify_error_and_visualise(isdf, namespace, st, space, mesh, ions, indices, isdf_vectors, output_cubes)
Wrapper for quantifying the error in the expansion of the product basis.
subroutine isdf_serial_ace_batch_w(isdf, st, W_ace, Vx_on_st)
Put the bare array representation of W into a batch.
subroutine, public output_matrix(namespace, fname, matrix)
Helper routine to output a 2D matrix.
Definition: isdf_utils.F90:151
This module is intended to contain "only mathematical" functions and procedures.
Definition: math.F90:117
logical function, public is_symmetric(a, tol)
Check if a 2D array is symmetric.
Definition: math.F90:1452
This module defines functions over batches of mesh functions.
Definition: mesh_batch.F90:118
This module defines the meshes, which are used in Octopus.
Definition: mesh.F90:120
subroutine, public messages_not_implemented(feature, namespace)
Definition: messages.F90:1068
character(len=256), dimension(max_lines), public message
to be output by fatal, warning
Definition: messages.F90:162
subroutine, public messages_fatal(no_lines, only_root_writes, namespace)
Definition: messages.F90:410
subroutine, public messages_info(no_lines, iunit, debug_only, stress, all_nodes, namespace)
Definition: messages.F90:594
type(mpi_grp_t), public mpi_world
Definition: mpi.F90:272
subroutine, public dpoisson_solve(this, namespace, pot, rho, all_nodes, kernel, reset)
Calculates the Poisson equation. Given the density returns the corresponding potential.
Definition: poisson.F90:862
subroutine, public profiling_out(label)
Increment out counter and sum up difference between entry and exit time.
Definition: profiling.F90:631
subroutine, public profiling_in(label, exclude)
Increment in counter and save entry time.
Definition: profiling.F90:554
pure logical function, public states_are_real(st)
integer pure function, public states_elec_block_max(st, ib)
return index of last state in block ib
subroutine, public states_elec_copy(stout, stin, exclude_wfns, exclude_eigenval, special)
make a (selective) copy of a states_elec_t object
integer pure function, public states_elec_block_min(st, ib)
return index of first state in block ib
subroutine, public states_elec_set_zero(st)
Explicitly set all wave functions in the states to zero.
This module provides routines for communicating states when using states parallelization.
This module defines the unit system, used for input and output.
type(unit_t), public unit_one
some special units required for particular quantities
Describes mesh distribution to nodes.
Definition: mesh.F90:187
The states_elec_t class contains all electronic wave functions.
Coulomb-attenuating method parameters, used in the partitioning of the Coulomb potential into a short...
Definition: xc_cam.F90:141
int true(void)