23 #include <shared_mutex>
25 #include <unordered_map>
27 #include <tbb/parallel_for.h>
28 #include <tbb/task_arena.h>
44 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
49 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
50 const auto& cached_data_itr =
data_cache_.find(key);
54 return num_bytes == cached_data_itr->second->num_bytes;
60 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
61 const auto& cached_data_itr =
data_cache_.find(key);
63 const std::string error_msg =
"Data for key " + key +
" not found in cache.";
64 throw std::runtime_error(error_msg);
66 copyData(reinterpret_cast<int8_t*>(dest_buffer),
67 cached_data_itr->second->data_buffer,
68 cached_data_itr->second->num_bytes);
73 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
74 const auto& cached_data_itr =
data_cache_.find(key);
76 const std::string error_msg{
"Data for key " + key +
" not found in cache."};
77 throw std::runtime_error(error_msg);
79 return *
reinterpret_cast<const T*
>(cached_data_itr->second->data_buffer);
84 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
85 const auto& cached_data_itr =
data_cache_.find(key);
89 return reinterpret_cast<const T* const
>(cached_data_itr->second->data_buffer);
95 const size_t num_elements) {
97 const size_t num_bytes(num_elements *
sizeof(
T));
98 auto cache_data = std::make_shared<CacheDataTf>(num_bytes);
99 copyData(cache_data->data_buffer, reinterpret_cast<int8_t*>(data_buffer), num_bytes);
100 std::unique_lock<std::shared_mutex> write_lock(
cache_mutex_);
101 const auto& cached_data_itr =
data_cache_.find(key);
104 const std::string warning_msg =
105 "Data for key " + key +
" already exists in cache. Replacing.";
106 std::cout << warning_msg << std::endl;
108 cached_data_itr->second.reset();
109 cached_data_itr->second = cache_data;
112 data_cache_.insert(std::make_pair(key, cache_data));
118 void copyData(int8_t*
dest,
const int8_t* source,
const size_t num_bytes)
const {
120 std::memcpy(dest, source, num_bytes);
124 const size_t num_threads =
125 (num_bytes + max_bytes_per_thread - 1) / max_bytes_per_thread;
127 tbb::blocked_range<size_t>(0, num_threads, 1),
128 [&](
const tbb::blocked_range<size_t>& r) {
129 const size_t end_chunk_idx = r.end();
130 for (
size_t chunk_idx = r.begin(); chunk_idx != end_chunk_idx; ++chunk_idx) {
131 const size_t start_byte = chunk_idx * max_bytes_per_thread;
132 const size_t length =
133 std::min(start_byte + max_bytes_per_thread, num_bytes) - start_byte;
134 std::memcpy(dest + start_byte, source + start_byte, length);
139 std::unordered_map<std::string, std::shared_ptr<CacheDataTf>>
data_cache_;
144 template <
typename T>
148 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
153 std::shared_lock<std::shared_mutex> read_lock(
cache_mutex_);
154 const auto& cached_data_itr =
data_cache_.find(key);
156 const std::string error_msg{
"Data for key " + key +
" not found in cache."};
157 throw std::runtime_error(error_msg);
159 return cached_data_itr->second;
163 std::unique_lock<std::shared_mutex> write_lock(
cache_mutex_);
164 const auto& cached_data_itr =
data_cache_.find(key);
167 const std::string warning_msg =
168 "Data for key " + key +
" already exists in cache. Replacing.";
169 std::cout << warning_msg << std::endl;
171 cached_data_itr->second.reset();
172 cached_data_itr->second = data;
bool isKeyCachedAndSameLength(const std::string &key, const size_t num_bytes) const
void copyData(int8_t *dest, const int8_t *source, const size_t num_bytes) const
bool isKeyCached(const std::string &key) const
std::unordered_map< std::string, std::shared_ptr< T > > data_cache_
void putDataForKey(const std::string &key, T *const data_buffer, const size_t num_elements)
static constexpr bool debug_print_
bool isKeyCached(const std::string &key) const
void putDataForKey(const std::string &key, std::shared_ptr< T > const data)
void getDataForKey(const std::string &key, T *dest_buffer) const
std::shared_ptr< T > getDataForKey(const std::string &key) const
const T & getDataRefForKey(const std::string &key) const
const T * getDataPtrForKey(const std::string &key) const
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
const size_t parallel_copy_min_bytes
#define DEBUG_TIMER(name)
std::shared_mutex cache_mutex_
CacheDataTf(const size_t num_bytes)
static constexpr bool debug_print_
std::shared_timed_mutex shared_mutex
std::shared_mutex cache_mutex_
std::unordered_map< std::string, std::shared_ptr< CacheDataTf > > data_cache_