21 #include <boost/algorithm/string.hpp>
24 #include <unordered_map>
26 using namespace Catalog_Namespace;
29 auto origin =
get(session_id);
34 throw std::runtime_error(
"No session with id " + session_id);
44 return boost::iequals(user_name, session_ptr->get_currentUser().userName);
50 return boost::iequals(db_name, session_ptr->getCatalog().getCurrentDB().dbName);
58 const auto dbname = session_ptr->getCatalog().getCurrentDB().dbName;
59 LOG(
INFO) <<
"User " << session_ptr->get_currentUser().userLoggable()
60 <<
" disconnected from database " << dbname
61 <<
" with public_session_id: " << session_ptr->get_public_session_id();
68 int idle_session_duration,
69 int max_session_duration) {
73 time_t last_used_time = session_ptr->get_last_used_time();
74 time_t
start_time = session_ptr->get_start_time();
75 const auto current_session_duration = time(0) - last_used_time;
76 if (current_session_duration > idle_session_duration) {
77 LOG(
INFO) <<
"Session " << session_ptr->get_public_session_id() <<
" idle duration "
78 << current_session_duration <<
" seconds exceeds maximum idle duration "
79 << idle_session_duration <<
" seconds. Invalidating session.";
82 const auto total_session_duration = time(0) -
start_time;
83 if (total_session_duration > max_session_duration) {
84 LOG(
INFO) <<
"Session " << session_ptr->get_public_session_id() <<
" total duration "
85 << total_session_duration
86 <<
" seconds exceeds maximum total session duration "
87 << max_session_duration <<
" seconds. Invalidating session.";
99 return session_ptr->get_currentUser().userName == user_name;
105 return session_ptr->get_public_session_id() == public_id;
107 if (sessions.empty()) {
117 int max_session_duration,
120 : idle_session_duration_(idle_session_duration)
121 , max_session_duration_(max_session_duration)
122 , capacity_(capacity > 0 ? capacity : INT_MAX)
123 , disconnect_callback_(disconnect_callback) {}
126 std::shared_ptr<Catalog>
cat,
129 if (
int(sessions_.size()) >= capacity_) {
130 std::vector<SessionInfoPtr> expired_sessions;
131 for (
auto it = sessions_.begin(); it != sessions_.end(); it++) {
132 if (isSessionExpired(it->second, idle_session_duration_, max_session_duration_)) {
133 expired_sessions.push_back(it->second);
136 for (
auto& session_ptr : expired_sessions) {
138 disconnect_callback_(session_ptr);
139 eraseUnlocked(session_ptr->get_session_id());
140 }
catch (
const std::exception& e) {
141 eraseUnlocked(session_ptr->get_session_id());
146 if (
int(sessions_.size()) < capacity_) {
149 if (sessions_.count(session_id) != 0) {
152 auto session_ptr = std::make_shared<Catalog_Namespace::SessionInfo>(
153 cat, user_meta, device, session_id);
154 sessions_[session_id] = session_ptr;
159 throw std::runtime_error(
"Too many active sessions");
164 auto session_ptr = getUnlocked(session_id);
167 session_ptr, idle_session_duration_, max_session_duration_)) {
169 disconnect_callback_(session_ptr);
170 eraseUnlocked(session_ptr->get_session_id());
171 }
catch (
const std::exception& e) {
172 eraseUnlocked(session_ptr->get_session_id());
177 session_ptr->update_last_used_time();
187 for (
auto it = sessions_.begin(); it != sessions_.end();) {
188 if (predicate(it->second)) {
189 it = sessions_.erase(it);
203 sessions_.erase(session_id);
207 return session_ptr.use_count() > 2;
211 if (
auto session_it = sessions_.find(session_id); session_it != sessions_.end()) {
212 return session_it->second;
221 std::vector<SessionInfoPtr> out;
223 for (
auto& [_, session] : sessions_) {
225 if (predicate(session)) {
226 out.push_back(session);
233 std::unordered_map<std::string, SessionInfoPtr>
sessions_;
242 const std::string& base_path,
244 int idle_session_duration,
245 int max_session_duration,
248 return std::make_unique<CachedSessionStore>(
249 idle_session_duration, max_session_duration, capacity, disconnect_callback);
std::lock_guard< T > lock_guard
void erase(const std::string &session_id)
virtual SessionInfoPtr getUnlocked(const std::string &session_id)=0
std::function< void(SessionInfoPtr &session)> DisconnectCallback
const int max_session_duration_
std::vector< SessionInfoPtr > getIf(std::function< bool(const SessionInfoPtr &)> predicate) override
std::unordered_map< std::string, SessionInfoPtr > sessions_
virtual void eraseUnlocked(const std::string &session_id)=0
const int idle_session_duration_
virtual DisconnectCallback getDisconnectCallback()=0
void eraseUnlocked(const std::string &session_id) override
~CachedSessionStore() override
std::shared_lock< T > shared_lock
This file contains the class specification and related data structures for Catalog.
SessionInfo getSessionCopy(const std::string &session_id)
virtual heavyai::shared_mutex & getLock()=0
bool isSessionExpired(const SessionInfoPtr &session_ptr, int idle_session_duration, int max_session_duration)
virtual bool isSessionInUse(const SessionInfoPtr &session_ptr)=0
virtual void eraseIf(std::function< bool(const SessionInfoPtr &)> predicate)=0
static std::unique_ptr< SessionsStore > create(const std::string &base_path, size_t n_workers, int idle_session_duration, int max_session_duration, int capacity, DisconnectCallback disconnect_callback)
DisconnectCallback getDisconnectCallback() override
const size_t SESSION_ID_LENGTH
virtual std::vector< SessionInfoPtr > getIf(std::function< bool(const SessionInfoPtr &)> predicate)=0
void eraseIf(std::function< bool(const SessionInfoPtr &)> predicate) override
CachedSessionStore(int idle_session_duration, int max_session_duration, int capacity, DisconnectCallback disconnect_callback)
void eraseByUser(const std::string &user_name)
DisconnectCallback disconnect_callback_
SessionInfoPtr add(const Catalog_Namespace::UserMetadata &user_meta, std::shared_ptr< Catalog > cat, ExecutorDeviceType device) override
SessionInfoPtr getUnlocked(const std::string &session_id) override
void disconnect(const std::string session_id)
SessionInfoPtr getByPublicID(const std::string &public_id)
std::shared_timed_mutex shared_mutex
heavyai::shared_mutex & getLock() override
std::vector< SessionInfoPtr > getAllSessions()
bool isSessionInUse(const SessionInfoPtr &session_ptr) override
heavyai::shared_mutex mtx_
void eraseByDB(const std::string &db_name)
std::vector< SessionInfoPtr > getUserSessions(const std::string &user_name)
std::shared_ptr< SessionInfo > SessionInfoPtr