Skip to content

Commit 0aa6c81

Browse files
committed
Revert "Make handle_msg more performant ? #475"
1 parent 6604618 commit 0aa6c81

File tree

2 files changed

+56
-84
lines changed

2 files changed

+56
-84
lines changed

scylla-server/src/rule_structs.rs

Lines changed: 54 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ use serde_with::serde_as;
1212
use serde_with::DurationSeconds;
1313
use std::borrow::Borrow;
1414
use std::hash::Hash;
15-
use std::sync::Arc;
1615
use std::time::Duration;
1716
use tokio::sync::RwLock;
18-
use tokio::task::JoinSet;
1917
use tracing::trace;
2018
use tracing::warn;
2119

@@ -350,7 +348,7 @@ impl RuleManager {
350348

351349
/// Handles a new socket message, returning a RuleNotification for one to many clients if action should be taken
352350
pub async fn handle_msg(
353-
self: &Arc<Self>,
351+
&self,
354352
data: &ClientData,
355353
) -> Result<Option<Vec<(ClientId, RuleNotification)>>, RuleManagerError> {
356354
// Read from topic to rule index and drop lock immediately
@@ -362,90 +360,65 @@ impl RuleManager {
362360
}
363361
};
364362

365-
let data = Arc::new(data.clone());
366-
let notifications = rule_ids
367-
.into_iter()
368-
.fold(JoinSet::new(), |mut set, rule_id| {
369-
let data = data.clone();
370-
let self_ref = self.clone();
371-
set.spawn(async move {
372-
let triggered_fut = async {
373-
if let Some(rule) = self_ref.rules.write().await.get_mut(&rule_id) {
374-
if let Some(triggered) = rule.tick(&data.values) {
375-
Ok(triggered)
376-
} else {
377-
Err(RuleManagerError::RuleFailure)
378-
}
363+
let mut notifications: Vec<(ClientId, RuleNotification)> = Vec::new();
364+
for rule_id in rule_ids {
365+
let (triggered_result, clients_result) = {
366+
// Future for if rule was triggered
367+
let triggered_future = async {
368+
if let Some(rule) = self.rules.write().await.get_mut(&rule_id) {
369+
if let Some(triggered) = rule.tick(&data.values) {
370+
Ok(triggered)
379371
} else {
380-
trace!("Could not find rule in rules map: {}", rule_id);
381-
Err(RuleManagerError::NoMatchingRule)
382-
}
383-
};
384-
385-
// Future for getting subscribed clients
386-
let clients_fut = async {
387-
self_ref
388-
.subscriptions
389-
.read()
390-
.await
391-
.get_left(&rule_id)
392-
.cloned()
393-
};
394-
395-
tokio::pin!(triggered_fut);
396-
tokio::pin!(clients_fut);
397-
398-
// Check which operation finished first
399-
let rule_id = rule_id.clone();
400-
tokio::select! {
401-
triggered_result = &mut triggered_fut => {
402-
match triggered_result {
403-
Ok(true) => (Ok(true), clients_fut.await, rule_id),
404-
Ok(false) => (Ok(false), None, rule_id),
405-
_ => (triggered_result, None, rule_id),
406-
}
407-
},
408-
clients_result = &mut clients_fut => {
409-
match clients_result {
410-
Some(_) => (triggered_fut.await, clients_result, rule_id),
411-
None => (Ok(false), None, rule_id)
412-
}
372+
Err(RuleManagerError::RuleFailure)
413373
}
374+
} else {
375+
trace!("Could not find rule in rules map: {}", rule_id);
376+
Err(RuleManagerError::NoMatchingRule)
414377
}
415-
});
416-
set
417-
})
418-
.join_all()
419-
.await
420-
.into_iter()
421-
.filter_map(|(triggered_res, clients_op, rule_id)| {
422-
let Ok(triggered) = triggered_res else {
423-
return Some(Err(triggered_res.unwrap_err()));
424378
};
425379

426-
if !triggered || clients_op.is_none() {
427-
None
428-
} else {
429-
Some(Ok(clients_op
430-
.unwrap()
431-
.into_iter()
432-
.map(|client| {
433-
(
434-
client,
435-
RuleNotification {
436-
id: rule_id.clone(),
437-
topic: Topic(data.name.clone()),
438-
values: data.values.clone(),
439-
time: data.timestamp,
440-
},
441-
)
442-
})
443-
.collect::<Vec<_>>()))
380+
// Future for getting subscribed clients
381+
let clients_future =
382+
async { self.subscriptions.read().await.get_left(&rule_id).cloned() };
383+
384+
tokio::pin!(triggered_future);
385+
tokio::pin!(clients_future);
386+
387+
// Check which operation finished first
388+
tokio::select! {
389+
triggered_result = &mut triggered_future => {
390+
match triggered_result? {
391+
true => (Ok(true), clients_future.await),
392+
false => (Ok(false), None),
393+
}
394+
},
395+
clients_result = &mut clients_future => {
396+
match clients_result {
397+
Some(_) => (triggered_future.await, clients_result),
398+
None => (Ok(false), None)
399+
}
400+
}
444401
}
445-
})
446-
.flatten()
447-
.flatten()
448-
.collect::<Vec<_>>();
402+
};
403+
404+
let triggered = triggered_result?;
405+
if !triggered || clients_result.is_none() {
406+
continue;
407+
}
408+
409+
// Push notifications for all clients who are subscribed to this rule
410+
for client in clients_result.unwrap() {
411+
notifications.push((
412+
client.clone(),
413+
RuleNotification {
414+
id: rule_id.clone(),
415+
topic: Topic(data.name.clone()),
416+
values: data.values.clone(),
417+
time: data.timestamp,
418+
},
419+
));
420+
}
421+
}
449422

450423
if notifications.is_empty() {
451424
Ok(None)

scylla-server/tests/rule_structs_test.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use chrono::Utc;
22
use scylla_server::rule_structs::*;
33
use scylla_server::ClientData;
4-
use std::sync::Arc;
54
use tokio::task::JoinSet;
65

76
#[tokio::test]
@@ -85,7 +84,7 @@ async fn test_delete_client_success() -> Result<(), RuleManagerError> {
8584

8685
#[tokio::test]
8786
async fn test_handle_msg_rule_triggered() -> Result<(), RuleManagerError> {
88-
let rule_manager = Arc::new(RuleManager::new());
87+
let rule_manager = RuleManager::new();
8988
let client = ClientId("test_client".to_string());
9089

9190
let rule = Rule::new(
@@ -124,7 +123,7 @@ async fn test_handle_msg_rule_triggered() -> Result<(), RuleManagerError> {
124123

125124
#[tokio::test]
126125
async fn test_handle_msg_multiple_clients_same_rule() -> Result<(), RuleManagerError> {
127-
let rule_manager = Arc::new(RuleManager::new());
126+
let rule_manager = RuleManager::new();
128127
let client1 = ClientId("client1".to_string());
129128
let client2 = ClientId("client2".to_string());
130129

0 commit comments

Comments
 (0)