1
use crate::{PRes, PersyError};
2
use std::{
3
    collections::{hash_map::Entry, HashMap},
4
    sync::{Arc, Condvar, Mutex, MutexGuard},
5
    time::Duration,
6
};
7

8
/// Bookkeeping kept for one key that is currently locked in a `RwLockManager`.
struct RwLockVar {
    /// `true` when the key is held by a writer, `false` when it is held by readers.
    write: bool,
    /// How many readers currently hold the key.
    read_count: u32,
    /// Condition variable that waiters clone out of the table and block on.
    cond: Arc<Condvar>,
}

impl RwLockVar {
    /// Builds an entry with the given mode and reader count and a fresh condition variable.
    fn with_state(write: bool, read_count: u32) -> RwLockVar {
        let cond = Arc::new(Condvar::new());
        RwLockVar {
            write,
            read_count,
            cond,
        }
    }

    /// Entry for a key that has just been taken by a writer (no readers).
    fn new_write() -> RwLockVar {
        Self::with_state(true, 0)
    }

    /// Entry for a key that has just been taken by its first reader.
    fn new_read() -> RwLockVar {
        Self::with_state(false, 1)
    }

    /// Registers one more reader on this key.
    fn inc_read(&mut self) {
        self.read_count += 1;
    }

    /// Unregisters one reader and returns `true` when no reader is left afterwards.
    fn dec_read(&mut self) -> bool {
        let remaining = self.read_count - 1;
        self.read_count = remaining;
        remaining == 0
    }
}
38

39
pub struct RwLockManager<T>
40
where
41
    T: std::cmp::Eq,
42
    T: std::hash::Hash,
43
    T: Clone,
44
{
45
    locks: Mutex<HashMap<T, RwLockVar>>,
46
}
47

48
impl<T> Default for RwLockManager<T>
49
where
50
    T: std::cmp::Eq,
51
    T: std::hash::Hash,
52
    T: Clone,
53
{
54 1
    fn default() -> Self {
55 1
        RwLockManager {
56 1
            locks: Mutex::new(HashMap::<T, RwLockVar>::new()),
57
        }
58 1
    }
59
}
60

61
impl<T> RwLockManager<T>
62
where
63
    T: std::cmp::Eq,
64
    T: std::hash::Hash,
65
    T: Clone,
66
{
67 1
    pub fn lock_all_write(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
68 1
        let mut locked = Vec::with_capacity(to_lock.len());
69 1
        for single in to_lock {
70 1
            let mut lock_manager = self.locks.lock()?;
71 1
            loop {
72 1
                let cond = match lock_manager.entry(single.clone()) {
73 1
                    Entry::Occupied(o) => o.get().cond.clone(),
74 1
                    Entry::Vacant(v) => {
75 1
                        let lock = RwLockVar::new_write();
76 1
                        v.insert(lock);
77 1
                        locked.push(single.clone());
78
                        break;
79 1
                    }
80 0
                };
81 1
                match cond.wait_timeout(lock_manager, timeout) {
82 1
                    Ok((guard, timedout)) => {
83 1
                        lock_manager = guard;
84 1
                        if timedout.timed_out() {
85 1
                            RwLockManager::unlock_all_write_with_guard(&mut lock_manager, &locked);
86 1
                            return Err(PersyError::TransactionTimeout);
87
                        }
88
                    }
89 0
                    Err(x) => {
90
                        // TODO: Check this, it may not be possible to unlock, but may be safe
91
                        // anyway because no-one can actually lock anything.
92 0
                        self.unlock_all_write(&locked)?;
93 0
                        return Err(PersyError::from(x));
94 0
                    }
95
                }
96 1
            }
97 1
        }
98 1
        Ok(())
99 1
    }
100 1
    pub fn lock_all_read(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
101 1
        let mut locked = Vec::with_capacity(to_lock.len());
102 1
        for single in to_lock {
103 1
            let mut lock_manager = self.locks.lock()?;
104 1
            loop {
105
                let cond;
106 1
                match lock_manager.entry(single.clone()) {
107 1
                    Entry::Occupied(mut o) => {
108 1
                        if o.get().write {
109 1
                            cond = o.get().cond.clone();
110
                        } else {
111 1
                            o.get_mut().inc_read();
112 1
                            locked.push(single.clone());
113 1
                            break;
114
                        }
115 1
                    }
116 1
                    Entry::Vacant(v) => {
117 1
                        v.insert(RwLockVar::new_read());
118 1
                        locked.push(single.clone());
119
                        break;
120
                    }
121 1
                };
122 1
                match cond.wait_timeout(lock_manager, timeout) {
123 1
                    Ok((guard, timedout)) => {
124 1
                        lock_manager = guard;
125 1
                        if timedout.timed_out() {
126 1
                            RwLockManager::unlock_all_read_with_guard(&mut lock_manager, &locked);
127 1
                            return Err(PersyError::TransactionTimeout);
128
                        }
129
                    }
130 0
                    Err(x) => {
131
                        // TODO: Check this, it may not be possible to unlock, but may be safe
132
                        // anyway because no-one can actually lock anything.
133 0
                        self.unlock_all_read(&locked)?;
134 0
                        return Err(PersyError::from(x));
135 0
                    }
136
                }
137 1
            }
138 1
        }
139 1
        Ok(())
140 1
    }
141

142 1
    fn unlock_all_read_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
143 1
        for single in to_unlock {
144 1
            if let Entry::Occupied(mut lock) = lock_manager.entry(single.clone()) {
145 1
                if lock.get_mut().dec_read() {
146 1
                    let cond = lock.get().cond.clone();
147 1
                    lock.remove();
148 1
                    cond.notify_all();
149 1
                }
150 1
            }
151 1
        }
152 1
    }
153 1
    pub fn unlock_all_read(&self, to_unlock: &[T]) -> PRes<()> {
154 1
        let mut lock_manager = self.locks.lock()?;
155 1
        RwLockManager::unlock_all_read_with_guard(&mut lock_manager, to_unlock);
156 1
        Ok(())
157 1
    }
158

159 1
    fn unlock_all_write_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
160 1
        for single in to_unlock {
161 1
            if let Some(lock) = lock_manager.remove(single) {
162 1
                lock.cond.notify_all();
163 0
            }
164 1
        }
165 1
    }
166

167 1
    pub fn unlock_all_write(&self, to_unlock: &[T]) -> PRes<()> {
168 1
        let mut lock_manager = self.locks.lock()?;
169 1
        RwLockManager::unlock_all_write_with_guard(&mut lock_manager, to_unlock);
170 1
        Ok(())
171 1
    }
172
}
173
/// Table of exclusive locks addressed by keys of type `T`.
///
/// A key is locked exactly while it has an entry in `locks`; the stored condition
/// variable is what other threads wait on until the key is released.
pub struct LockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    /// Currently locked keys, guarded by a single mutex.
    locks: Mutex<HashMap<T, Arc<Condvar>>>,
}
181

182
impl<T> Default for LockManager<T>
183
where
184
    T: std::cmp::Eq,
185
    T: std::hash::Hash,
186
    T: Clone,
187
{
188 1
    fn default() -> Self {
189 1
        LockManager {
190 1
            locks: Mutex::new(HashMap::<T, Arc<Condvar>>::new()),
191
        }
192 1
    }
193
}
194

195
impl<T> LockManager<T>
196
where
197
    T: std::cmp::Eq + std::hash::Hash + Clone,
198
{
199 1
    pub fn lock_all(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
200 1
        let mut locked = Vec::with_capacity(to_lock.len());
201 1
        for single in to_lock {
202 1
            let cond = Arc::new(Condvar::new());
203 1
            let mut lock_manager = self.locks.lock()?;
204 1
            loop {
205 1
                let cond = match lock_manager.entry(single.clone()) {
206 1
                    Entry::Occupied(o) => o.get().clone(),
207 1
                    Entry::Vacant(v) => {
208 1
                        v.insert(cond);
209 1
                        locked.push(single.clone());
210
                        break;
211
                    }
212 0
                };
213 1
                match cond.wait_timeout(lock_manager, timeout) {
214 1
                    Ok((guard, timedout)) => {
215 1
                        lock_manager = guard;
216 1
                        if timedout.timed_out() {
217 1
                            LockManager::unlock_all_with_guard(&mut lock_manager, locked.iter());
218 1
                            return Err(PersyError::TransactionTimeout);
219
                        }
220
                    }
221 0
                    Err(x) => {
222
                        // TODO: Check this, it may not be possible to unlock, but may be safe
223
                        // anyway because no-one can actually lock anything.
224 0
                        self.unlock_all(&locked)?;
225 0
                        return Err(PersyError::from(x));
226 0
                    }
227
                }
228 1
            }
229 1
        }
230 1
        Ok(())
231 1
    }
232

233 1
    fn unlock_all_with_guard<'a, Q: 'a>(
234
        lock_manager: &mut MutexGuard<HashMap<T, Arc<Condvar>>>,
235
        to_unlock: impl Iterator<Item = &'a Q>,
236
    ) where
237
        T: std::borrow::Borrow<Q>,
238
        Q: std::hash::Hash + Eq,
239
    {
240 1
        for single in to_unlock {
241 1
            if let Some(cond) = lock_manager.remove(single) {
242 1
                cond.notify_all();
243 0
            }
244 1
        }
245 1
    }
246

247
    #[inline]
248 1
    pub fn unlock_all<Q>(&self, to_unlock: &[Q]) -> PRes<()>
249
    where
250
        T: std::borrow::Borrow<Q>,
251
        Q: std::hash::Hash + Eq,
252
    {
253 1
        self.unlock_all_iter(to_unlock.iter())
254 1
    }
255

256
    #[inline]
257 1
    pub fn unlock_all_iter<'a, Q: 'a>(&self, to_unlock: impl Iterator<Item = &'a Q>) -> PRes<()>
258
    where
259
        T: std::borrow::Borrow<Q>,
260
        Q: std::hash::Hash + Eq,
261
    {
262 1
        let mut lock_manager = self.locks.lock()?;
263 1
        LockManager::unlock_all_with_guard(&mut lock_manager, to_unlock);
264 1
        Ok(())
265 1
    }
266
}
267

268
#[cfg(test)]
mod tests {
    use super::{LockManager, RwLockManager};
    use std::time::Duration;

    /// A failed multi-key lock must give back the keys it had already taken.
    #[test]
    fn test_lock_manager_unlock_if_lock_fail() {
        let patient = Duration::from_secs(1);
        let impatient = Duration::from_nanos(1);
        let manager: LockManager<_> = Default::default();

        manager.lock_all(&[5], patient).expect("no issue here");
        // Key 5 is busy, so this call times out after having taken key 1.
        let blocked = manager.lock_all(&[1, 5], impatient);
        assert!(blocked.is_err());
        // Key 1 must be free again.
        manager.lock_all(&[1], patient).expect("no issue here");
        manager.unlock_all(&[1, 5]).expect("no issue here");
    }

    /// Same roll-back check for each write/read combination of the read/write manager.
    #[test]
    fn test_rw_lock_manager_unlock_if_lock_fail() {
        let patient = Duration::from_secs(1);
        let impatient = Duration::from_nanos(1);
        let manager: RwLockManager<_> = Default::default();

        // Writer holds 5, a second writer fails on [1, 5] and must release 1.
        manager.lock_all_write(&[5], patient).expect("no issue here");
        let blocked = manager.lock_all_write(&[1, 5], impatient);
        assert!(blocked.is_err());
        manager.lock_all_write(&[1], patient).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // Writer holds 5, a reader fails on [1, 5] and must release 1.
        manager.lock_all_write(&[5], patient).expect("no issue here");
        let blocked = manager.lock_all_read(&[1, 5], impatient);
        assert!(blocked.is_err());
        manager.lock_all_write(&[1], patient).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // Reader holds 5, a writer fails on [1, 5] and must release 1.
        manager.lock_all_read(&[5], patient).expect("no issue here");
        let blocked = manager.lock_all_write(&[1, 5], impatient);
        assert!(blocked.is_err());
        manager.lock_all_read(&[1], patient).expect("no issue here");
        manager.unlock_all_read(&[1, 5]).expect("no issue here");
    }
}

Read our documentation on viewing source code .

Loading