My Project
Loading...
Searching...
No Matches
CuVector.hpp
/*
  Copyright 2022-2023 SINTEF AS

  This file is part of the Open Porous Media project (OPM).

  OPM is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  OPM is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OPM. If not, see <http://www.gnu.org/licenses/>.
*/
19#ifndef OPM_CUVECTOR_HEADER_HPP
20#define OPM_CUVECTOR_HEADER_HPP
21#include <dune/common/fvector.hh>
22#include <dune/istl/bvector.hh>
23#include <exception>
24#include <fmt/core.h>
25#include <opm/common/ErrorMacros.hpp>
26#include <opm/simulators/linalg/cuistl/detail/CuBlasHandle.hpp>
27#include <opm/simulators/linalg/cuistl/detail/safe_conversion.hpp>
28#include <vector>
29
30
31namespace Opm::cuistl
32{
33
63template <typename T>
64class CuVector
65{
66public:
67 using field_type = T;
68 using size_type = size_t;
69
70
79 CuVector(const CuVector<T>& other);
80
92 explicit CuVector(const std::vector<T>& data);
93
102 CuVector& operator=(const CuVector<T>& other);
103
111 CuVector& operator=(T scalar);
112
120 explicit CuVector(const size_t numberOfElements);
121
122
134 CuVector(const T* dataOnHost, const size_t numberOfElements);
135
139 virtual ~CuVector();
140
144 T* data();
145
149 const T* data() const;
150
158 template <int BlockDimension>
159 void copyFromHost(const Dune::BlockVector<Dune::FieldVector<T, BlockDimension>>& bvector)
160 {
161 // TODO: [perf] vector.dim() can be replaced by bvector.N() * BlockDimension
162 if (detail::to_size_t(m_numberOfElements) != bvector.dim()) {
163 OPM_THROW(std::runtime_error,
164 fmt::format("Given incompatible vector size. CuVector has size {}, \n"
165 "however, BlockVector has N() = {}, and dim = {}.",
166 m_numberOfElements,
167 bvector.N(),
168 bvector.dim()));
169 }
170 const auto dataPointer = static_cast<const T*>(&(bvector[0][0]));
171 copyFromHost(dataPointer, m_numberOfElements);
172 }
173
181 template <int BlockDimension>
182 void copyToHost(Dune::BlockVector<Dune::FieldVector<T, BlockDimension>>& bvector) const
183 {
184 // TODO: [perf] vector.dim() can be replaced by bvector.N() * BlockDimension
185 if (detail::to_size_t(m_numberOfElements) != bvector.dim()) {
186 OPM_THROW(std::runtime_error,
187 fmt::format("Given incompatible vector size. CuVector has size {},\n however, the BlockVector "
188 "has has N() = {}, and dim() = {}.",
189 m_numberOfElements,
190 bvector.N(),
191 bvector.dim()));
192 }
193 const auto dataPointer = static_cast<T*>(&(bvector[0][0]));
194 copyToHost(dataPointer, m_numberOfElements);
195 }
196
204 void copyFromHost(const T* dataPointer, size_t numberOfElements);
205
213 void copyToHost(T* dataPointer, size_t numberOfElements) const;
214
222 void copyFromHost(const std::vector<T>& data);
223
231 void copyToHost(std::vector<T>& data) const;
232
241 CuVector<T>& operator*=(const T& scalar);
242
251 CuVector<T>& axpy(T alpha, const CuVector<T>& y);
252
259 CuVector<T>& operator+=(const CuVector<T>& other);
260
267 CuVector<T>& operator-=(const CuVector<T>& other);
268
277 T dot(const CuVector<T>& other) const;
278
286 T two_norm() const;
287
293 T dot(const CuVector<T>& other, const CuVector<int>& indexSet, CuVector<T>& buffer) const;
294
300 T two_norm(const CuVector<int>& indexSet, CuVector<T>& buffer) const;
301
302
308 T dot(const CuVector<T>& other, const CuVector<int>& indexSet) const;
309
315 T two_norm(const CuVector<int>& indexSet) const;
316
317
322 size_type dim() const;
323
324
329 std::vector<T> asStdVector() const;
330
335 template <int blockSize>
336 Dune::BlockVector<Dune::FieldVector<T, blockSize>> asDuneBlockVector() const
337 {
338 OPM_ERROR_IF(dim() % blockSize != 0,
339 fmt::format("blockSize is not a multiple of dim(). Given blockSize = {}, and dim() = {}",
340 blockSize,
341 dim()));
342
343 Dune::BlockVector<Dune::FieldVector<T, blockSize>> returnValue(dim() / blockSize);
344 copyToHost(returnValue);
345 return returnValue;
346 }
347
348
362 void setZeroAtIndexSet(const CuVector<int>& indexSet);
363
364private:
365 T* m_dataOnDevice = nullptr;
366
367 // Note that we store this as int to make sure we are always cublas compatible.
368 // This gives the added benefit that a size_t to int conversion error occurs during construction.
369 const int m_numberOfElements;
370 detail::CuBlasHandle& m_cuBlasHandle;
371
372 void assertSameSize(const CuVector<T>& other) const;
373 void assertSameSize(int size) const;
374
375 void assertHasElements() const;
376};
377} // namespace Opm::cuistl
378#endif