1use async_trait::async_trait;
10use chrono::Duration;
11use mas_data_model::Clock;
12use mas_storage::queue::{QueueWorkerRepository, Worker};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20pub struct PgQueueWorkerRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgQueueWorkerRepository<'c> {
26    #[must_use]
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34#[async_trait]
35impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
36    type Error = DatabaseError;
37
38    #[tracing::instrument(
39        name = "db.queue_worker.register",
40        skip_all,
41        fields(
42            worker.id,
43            db.query.text,
44        ),
45        err,
46    )]
47    async fn register(
48        &mut self,
49        rng: &mut (dyn RngCore + Send),
50        clock: &dyn Clock,
51    ) -> Result<Worker, Self::Error> {
52        let now = clock.now();
53        let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
54        tracing::Span::current().record("worker.id", tracing::field::display(worker_id));
55
56        sqlx::query!(
57            r#"
58                INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
59                VALUES ($1, $2, $2)
60            "#,
61            Uuid::from(worker_id),
62            now,
63        )
64        .traced()
65        .execute(&mut *self.conn)
66        .await?;
67
68        Ok(Worker { id: worker_id })
69    }
70
71    #[tracing::instrument(
72        name = "db.queue_worker.heartbeat",
73        skip_all,
74        fields(
75            %worker.id,
76            db.query.text,
77        ),
78        err,
79    )]
80    async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
81        let now = clock.now();
82        let res = sqlx::query!(
83            r#"
84                UPDATE queue_workers
85                SET last_seen_at = $2
86                WHERE queue_worker_id = $1 AND shutdown_at IS NULL
87            "#,
88            Uuid::from(worker.id),
89            now,
90        )
91        .traced()
92        .execute(&mut *self.conn)
93        .await?;
94
95        DatabaseError::ensure_affected_rows(&res, 1)?;
97
98        Ok(())
99    }
100
101    #[tracing::instrument(
102        name = "db.queue_worker.shutdown",
103        skip_all,
104        fields(
105            %worker.id,
106            db.query.text,
107        ),
108        err,
109    )]
110    async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
111        let now = clock.now();
112        let res = sqlx::query!(
113            r#"
114                UPDATE queue_workers
115                SET shutdown_at = $2
116                WHERE queue_worker_id = $1
117            "#,
118            Uuid::from(worker.id),
119            now,
120        )
121        .traced()
122        .execute(&mut *self.conn)
123        .await?;
124
125        DatabaseError::ensure_affected_rows(&res, 1)?;
126
127        let res = sqlx::query!(
129            r#"
130                DELETE FROM queue_leader
131                WHERE queue_worker_id = $1
132            "#,
133            Uuid::from(worker.id),
134        )
135        .traced()
136        .execute(&mut *self.conn)
137        .await?;
138
139        if res.rows_affected() > 0 {
141            sqlx::query!(
142                r#"
143                    NOTIFY queue_leader_stepdown
144                "#,
145            )
146            .traced()
147            .execute(&mut *self.conn)
148            .await?;
149        }
150
151        Ok(())
152    }
153
154    #[tracing::instrument(
155        name = "db.queue_worker.shutdown_dead_workers",
156        skip_all,
157        fields(
158            db.query.text,
159        ),
160        err,
161    )]
162    async fn shutdown_dead_workers(
163        &mut self,
164        clock: &dyn Clock,
165        threshold: Duration,
166    ) -> Result<(), Self::Error> {
167        let now = clock.now();
171        sqlx::query!(
172            r#"
173                UPDATE queue_workers
174                SET shutdown_at = $1
175                WHERE shutdown_at IS NULL
176                  AND last_seen_at < $2
177            "#,
178            now,
179            now - threshold,
180        )
181        .traced()
182        .execute(&mut *self.conn)
183        .await?;
184
185        Ok(())
186    }
187
188    #[tracing::instrument(
189        name = "db.queue_worker.remove_leader_lease_if_expired",
190        skip_all,
191        fields(
192            db.query.text,
193        ),
194        err,
195    )]
196    async fn remove_leader_lease_if_expired(
197        &mut self,
198        _clock: &dyn Clock,
199    ) -> Result<(), Self::Error> {
200        sqlx::query!(
203            r#"
204                DELETE FROM queue_leader
205                WHERE expires_at < NOW()
206            "#,
207        )
208        .traced()
209        .execute(&mut *self.conn)
210        .await?;
211
212        Ok(())
213    }
214
215    #[tracing::instrument(
216        name = "db.queue_worker.try_get_leader_lease",
217        skip_all,
218        fields(
219            %worker.id,
220            db.query.text,
221        ),
222        err,
223    )]
224    async fn try_get_leader_lease(
225        &mut self,
226        clock: &dyn Clock,
227        worker: &Worker,
228    ) -> Result<bool, Self::Error> {
229        let now = clock.now();
230        let res = sqlx::query!(
239            r#"
240                INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
241                VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
242                ON CONFLICT (active)
243                DO UPDATE SET expires_at = EXCLUDED.expires_at
244                WHERE queue_leader.queue_worker_id = $2
245            "#,
246            now,
247            Uuid::from(worker.id)
248        )
249        .traced()
250        .execute(&mut *self.conn)
251        .await?;
252
253        let am_i_the_leader = res.rows_affected() == 1;
256
257        Ok(am_i_the_leader)
258    }
259}