1use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use mas_data_model::UserRegistrationToken;
9use mas_storage::{
10 Clock, Page, Pagination,
11 user::{UserRegistrationTokenFilter, UserRegistrationTokenRepository},
12};
13use rand::RngCore;
14use sea_query::{Condition, Expr, PostgresQueryBuilder, Query, enum_def};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21 DatabaseInconsistencyError,
22 errors::DatabaseError,
23 filter::{Filter, StatementExt},
24 iden::UserRegistrationTokens,
25 pagination::QueryBuilderExt,
26 tracing::ExecuteExt,
27};
28
29pub struct PgUserRegistrationTokenRepository<'c> {
32 conn: &'c mut PgConnection,
33}
34
35impl<'c> PgUserRegistrationTokenRepository<'c> {
36 pub fn new(conn: &'c mut PgConnection) -> Self {
39 Self { conn }
40 }
41}
42
/// Raw row shape read from the `user_registration_tokens` table.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, and
/// `#[enum_def]` generates the `UserRegistrationTokenLookupIden` enum that the
/// dynamically-built `list` query uses as column aliases. Field names must
/// therefore match the selected columns exactly.
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserRegistrationTokenLookup {
    user_registration_token_id: Uuid,
    token: String,
    // Signed in the database; converted with `u32::try_from` in the `TryFrom`
    // impl, which rejects negative values as an inconsistency.
    usage_limit: Option<i32>,
    // Signed in the database; narrowed with `try_into` in the `TryFrom` impl.
    times_used: i32,
    created_at: DateTime<Utc>,
    last_used_at: Option<DateTime<Utc>>,
    expires_at: Option<DateTime<Utc>>,
    revoked_at: Option<DateTime<Utc>>,
}
55
56impl Filter for UserRegistrationTokenFilter {
57 #[expect(clippy::too_many_lines)]
58 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
59 sea_query::Condition::all()
60 .add_option(self.has_been_used().map(|has_been_used| {
61 if has_been_used {
62 Expr::col((
63 UserRegistrationTokens::Table,
64 UserRegistrationTokens::TimesUsed,
65 ))
66 .gt(0)
67 } else {
68 Expr::col((
69 UserRegistrationTokens::Table,
70 UserRegistrationTokens::TimesUsed,
71 ))
72 .eq(0)
73 }
74 }))
75 .add_option(self.is_revoked().map(|is_revoked| {
76 if is_revoked {
77 Expr::col((
78 UserRegistrationTokens::Table,
79 UserRegistrationTokens::RevokedAt,
80 ))
81 .is_not_null()
82 } else {
83 Expr::col((
84 UserRegistrationTokens::Table,
85 UserRegistrationTokens::RevokedAt,
86 ))
87 .is_null()
88 }
89 }))
90 .add_option(self.is_expired().map(|is_expired| {
91 if is_expired {
92 Condition::all()
93 .add(
94 Expr::col((
95 UserRegistrationTokens::Table,
96 UserRegistrationTokens::ExpiresAt,
97 ))
98 .is_not_null(),
99 )
100 .add(
101 Expr::col((
102 UserRegistrationTokens::Table,
103 UserRegistrationTokens::ExpiresAt,
104 ))
105 .lt(Expr::val(self.now())),
106 )
107 } else {
108 Condition::any()
109 .add(
110 Expr::col((
111 UserRegistrationTokens::Table,
112 UserRegistrationTokens::ExpiresAt,
113 ))
114 .is_null(),
115 )
116 .add(
117 Expr::col((
118 UserRegistrationTokens::Table,
119 UserRegistrationTokens::ExpiresAt,
120 ))
121 .gte(Expr::val(self.now())),
122 )
123 }
124 }))
125 .add_option(self.is_valid().map(|is_valid| {
126 let valid = Condition::all()
127 .add(
129 Condition::any()
130 .add(
131 Expr::col((
132 UserRegistrationTokens::Table,
133 UserRegistrationTokens::UsageLimit,
134 ))
135 .is_null(),
136 )
137 .add(
138 Expr::col((
139 UserRegistrationTokens::Table,
140 UserRegistrationTokens::TimesUsed,
141 ))
142 .lt(Expr::col((
143 UserRegistrationTokens::Table,
144 UserRegistrationTokens::UsageLimit,
145 ))),
146 ),
147 )
148 .add(
150 Expr::col((
151 UserRegistrationTokens::Table,
152 UserRegistrationTokens::RevokedAt,
153 ))
154 .is_null(),
155 )
156 .add(
158 Condition::any()
159 .add(
160 Expr::col((
161 UserRegistrationTokens::Table,
162 UserRegistrationTokens::ExpiresAt,
163 ))
164 .is_null(),
165 )
166 .add(
167 Expr::col((
168 UserRegistrationTokens::Table,
169 UserRegistrationTokens::ExpiresAt,
170 ))
171 .gte(Expr::val(self.now())),
172 ),
173 );
174
175 if is_valid { valid } else { valid.not() }
176 }))
177 }
178}
179
180impl TryFrom<UserRegistrationTokenLookup> for UserRegistrationToken {
181 type Error = DatabaseInconsistencyError;
182
183 fn try_from(res: UserRegistrationTokenLookup) -> Result<Self, Self::Error> {
184 let id = Ulid::from(res.user_registration_token_id);
185
186 let usage_limit = res
187 .usage_limit
188 .map(u32::try_from)
189 .transpose()
190 .map_err(|e| {
191 DatabaseInconsistencyError::on("user_registration_tokens")
192 .column("usage_limit")
193 .row(id)
194 .source(e)
195 })?;
196
197 let times_used = res.times_used.try_into().map_err(|e| {
198 DatabaseInconsistencyError::on("user_registration_tokens")
199 .column("times_used")
200 .row(id)
201 .source(e)
202 })?;
203
204 Ok(UserRegistrationToken {
205 id,
206 token: res.token,
207 usage_limit,
208 times_used,
209 created_at: res.created_at,
210 last_used_at: res.last_used_at,
211 expires_at: res.expires_at,
212 revoked_at: res.revoked_at,
213 })
214 }
215}
216
217#[async_trait]
218impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> {
219 type Error = DatabaseError;
220
221 #[tracing::instrument(
222 name = "db.user_registration_token.list",
223 skip_all,
224 fields(
225 db.query.text,
226 ),
227 err,
228 )]
229 async fn list(
230 &mut self,
231 filter: UserRegistrationTokenFilter,
232 pagination: Pagination,
233 ) -> Result<Page<UserRegistrationToken>, Self::Error> {
234 let (sql, values) = Query::select()
235 .expr_as(
236 Expr::col((
237 UserRegistrationTokens::Table,
238 UserRegistrationTokens::UserRegistrationTokenId,
239 )),
240 UserRegistrationTokenLookupIden::UserRegistrationTokenId,
241 )
242 .expr_as(
243 Expr::col((UserRegistrationTokens::Table, UserRegistrationTokens::Token)),
244 UserRegistrationTokenLookupIden::Token,
245 )
246 .expr_as(
247 Expr::col((
248 UserRegistrationTokens::Table,
249 UserRegistrationTokens::UsageLimit,
250 )),
251 UserRegistrationTokenLookupIden::UsageLimit,
252 )
253 .expr_as(
254 Expr::col((
255 UserRegistrationTokens::Table,
256 UserRegistrationTokens::TimesUsed,
257 )),
258 UserRegistrationTokenLookupIden::TimesUsed,
259 )
260 .expr_as(
261 Expr::col((
262 UserRegistrationTokens::Table,
263 UserRegistrationTokens::CreatedAt,
264 )),
265 UserRegistrationTokenLookupIden::CreatedAt,
266 )
267 .expr_as(
268 Expr::col((
269 UserRegistrationTokens::Table,
270 UserRegistrationTokens::LastUsedAt,
271 )),
272 UserRegistrationTokenLookupIden::LastUsedAt,
273 )
274 .expr_as(
275 Expr::col((
276 UserRegistrationTokens::Table,
277 UserRegistrationTokens::ExpiresAt,
278 )),
279 UserRegistrationTokenLookupIden::ExpiresAt,
280 )
281 .expr_as(
282 Expr::col((
283 UserRegistrationTokens::Table,
284 UserRegistrationTokens::RevokedAt,
285 )),
286 UserRegistrationTokenLookupIden::RevokedAt,
287 )
288 .from(UserRegistrationTokens::Table)
289 .apply_filter(filter)
290 .generate_pagination(
291 (
292 UserRegistrationTokens::Table,
293 UserRegistrationTokens::UserRegistrationTokenId,
294 ),
295 pagination,
296 )
297 .build_sqlx(PostgresQueryBuilder);
298
299 let tokens = sqlx::query_as_with::<_, UserRegistrationTokenLookup, _>(&sql, values)
300 .traced()
301 .fetch_all(&mut *self.conn)
302 .await?
303 .into_iter()
304 .map(TryInto::try_into)
305 .collect::<Result<Vec<_>, _>>()?;
306
307 let page = pagination.process(tokens);
308
309 Ok(page)
310 }
311
312 #[tracing::instrument(
313 name = "db.user_registration_token.count",
314 skip_all,
315 fields(
316 db.query.text,
317 user_registration_token.filter = ?filter,
318 ),
319 err,
320 )]
321 async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result<usize, Self::Error> {
322 let (sql, values) = Query::select()
323 .expr(
324 Expr::col((
325 UserRegistrationTokens::Table,
326 UserRegistrationTokens::UserRegistrationTokenId,
327 ))
328 .count(),
329 )
330 .from(UserRegistrationTokens::Table)
331 .apply_filter(filter)
332 .build_sqlx(PostgresQueryBuilder);
333
334 let count: i64 = sqlx::query_scalar_with(&sql, values)
335 .traced()
336 .fetch_one(&mut *self.conn)
337 .await?;
338
339 count
340 .try_into()
341 .map_err(DatabaseError::to_invalid_operation)
342 }
343
344 #[tracing::instrument(
345 name = "db.user_registration_token.lookup",
346 skip_all,
347 fields(
348 db.query.text,
349 user_registration_token.id = %id,
350 ),
351 err,
352 )]
353 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistrationToken>, Self::Error> {
354 let res = sqlx::query_as!(
355 UserRegistrationTokenLookup,
356 r#"
357 SELECT user_registration_token_id,
358 token,
359 usage_limit,
360 times_used,
361 created_at,
362 last_used_at,
363 expires_at,
364 revoked_at
365 FROM user_registration_tokens
366 WHERE user_registration_token_id = $1
367 "#,
368 Uuid::from(id)
369 )
370 .traced()
371 .fetch_optional(&mut *self.conn)
372 .await?;
373
374 let Some(res) = res else {
375 return Ok(None);
376 };
377
378 Ok(Some(res.try_into()?))
379 }
380
381 #[tracing::instrument(
382 name = "db.user_registration_token.find_by_token",
383 skip_all,
384 fields(
385 db.query.text,
386 token = %token,
387 ),
388 err,
389 )]
390 async fn find_by_token(
391 &mut self,
392 token: &str,
393 ) -> Result<Option<UserRegistrationToken>, Self::Error> {
394 let res = sqlx::query_as!(
395 UserRegistrationTokenLookup,
396 r#"
397 SELECT user_registration_token_id,
398 token,
399 usage_limit,
400 times_used,
401 created_at,
402 last_used_at,
403 expires_at,
404 revoked_at
405 FROM user_registration_tokens
406 WHERE token = $1
407 "#,
408 token
409 )
410 .traced()
411 .fetch_optional(&mut *self.conn)
412 .await?;
413
414 let Some(res) = res else {
415 return Ok(None);
416 };
417
418 Ok(Some(res.try_into()?))
419 }
420
421 #[tracing::instrument(
422 name = "db.user_registration_token.add",
423 skip_all,
424 fields(
425 db.query.text,
426 user_registration_token.token = %token,
427 ),
428 err,
429 )]
430 async fn add(
431 &mut self,
432 rng: &mut (dyn RngCore + Send),
433 clock: &dyn mas_storage::Clock,
434 token: String,
435 usage_limit: Option<u32>,
436 expires_at: Option<DateTime<Utc>>,
437 ) -> Result<UserRegistrationToken, Self::Error> {
438 let created_at = clock.now();
439 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
440
441 let usage_limit_i32 = usage_limit
442 .map(i32::try_from)
443 .transpose()
444 .map_err(DatabaseError::to_invalid_operation)?;
445
446 sqlx::query!(
447 r#"
448 INSERT INTO user_registration_tokens
449 (user_registration_token_id, token, usage_limit, created_at, expires_at)
450 VALUES ($1, $2, $3, $4, $5)
451 "#,
452 Uuid::from(id),
453 &token,
454 usage_limit_i32,
455 created_at,
456 expires_at,
457 )
458 .traced()
459 .execute(&mut *self.conn)
460 .await?;
461
462 Ok(UserRegistrationToken {
463 id,
464 token,
465 usage_limit,
466 times_used: 0,
467 created_at,
468 last_used_at: None,
469 expires_at,
470 revoked_at: None,
471 })
472 }
473
474 #[tracing::instrument(
475 name = "db.user_registration_token.use_token",
476 skip_all,
477 fields(
478 db.query.text,
479 user_registration_token.id = %token.id,
480 ),
481 err,
482 )]
483 async fn use_token(
484 &mut self,
485 clock: &dyn Clock,
486 token: UserRegistrationToken,
487 ) -> Result<UserRegistrationToken, Self::Error> {
488 let now = clock.now();
489 let new_times_used = sqlx::query_scalar!(
490 r#"
491 UPDATE user_registration_tokens
492 SET times_used = times_used + 1,
493 last_used_at = $2
494 WHERE user_registration_token_id = $1 AND revoked_at IS NULL
495 RETURNING times_used
496 "#,
497 Uuid::from(token.id),
498 now,
499 )
500 .traced()
501 .fetch_one(&mut *self.conn)
502 .await?;
503
504 let new_times_used = new_times_used
505 .try_into()
506 .map_err(DatabaseError::to_invalid_operation)?;
507
508 Ok(UserRegistrationToken {
509 times_used: new_times_used,
510 last_used_at: Some(now),
511 ..token
512 })
513 }
514
515 #[tracing::instrument(
516 name = "db.user_registration_token.revoke",
517 skip_all,
518 fields(
519 db.query.text,
520 user_registration_token.id = %token.id,
521 ),
522 err,
523 )]
524 async fn revoke(
525 &mut self,
526 clock: &dyn Clock,
527 mut token: UserRegistrationToken,
528 ) -> Result<UserRegistrationToken, Self::Error> {
529 let revoked_at = clock.now();
530 let res = sqlx::query!(
531 r#"
532 UPDATE user_registration_tokens
533 SET revoked_at = $2
534 WHERE user_registration_token_id = $1
535 "#,
536 Uuid::from(token.id),
537 revoked_at,
538 )
539 .traced()
540 .execute(&mut *self.conn)
541 .await?;
542
543 DatabaseError::ensure_affected_rows(&res, 1)?;
544
545 token.revoked_at = Some(revoked_at);
546
547 Ok(token)
548 }
549}
550
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_storage::{
        Clock as _, Pagination, clock::MockClock, user::UserRegistrationTokenFilter,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Exercise every filter criterion, alone and combined, against four
    /// tokens: unused, used, expired and revoked.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_list_and_count(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Token 1: never used, no expiry, not revoked
        let token1 = repo
            .user_registration_token()
            .add(&mut rng, &clock, "token1".to_owned(), None, None)
            .await
            .unwrap();

        // Token 2: used once
        let token2 = repo
            .user_registration_token()
            .add(&mut rng, &clock, "token2".to_owned(), None, None)
            .await
            .unwrap();
        let token2 = repo
            .user_registration_token()
            .use_token(&clock, token2)
            .await
            .unwrap();

        // Token 3: expired a day ago
        let past_time = clock.now() - Duration::days(1);
        let token3 = repo
            .user_registration_token()
            .add(&mut rng, &clock, "token3".to_owned(), None, Some(past_time))
            .await
            .unwrap();

        // Token 4: revoked
        let token4 = repo
            .user_registration_token()
            .add(&mut rng, &clock, "token4".to_owned(), None, None)
            .await
            .unwrap();
        let token4 = repo
            .user_registration_token()
            .revoke(&clock, token4)
            .await
            .unwrap();

        let empty_filter = UserRegistrationTokenFilter::new(clock.now());
        let page = repo
            .user_registration_token()
            .list(empty_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);

        let count = repo
            .user_registration_token()
            .count(empty_filter)
            .await
            .unwrap();
        assert_eq!(count, 4);

        let used_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(true);
        let page = repo
            .user_registration_token()
            .list(used_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 1);
        assert_eq!(page.edges[0].id, token2.id);

        let unused_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(false);
        let page = repo
            .user_registration_token()
            .list(unused_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 3);

        let expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(true);
        let page = repo
            .user_registration_token()
            .list(expired_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 1);
        assert_eq!(page.edges[0].id, token3.id);

        let not_expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(false);
        let page = repo
            .user_registration_token()
            .list(not_expired_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 3);

        let revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(true);
        let page = repo
            .user_registration_token()
            .list(revoked_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 1);
        assert_eq!(page.edges[0].id, token4.id);

        let not_revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(false);
        let page = repo
            .user_registration_token()
            .list(not_revoked_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 3);

        // Valid tokens are exactly the unused and the used-but-unlimited ones
        let valid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(true);
        let page = repo
            .user_registration_token()
            .list(valid_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 2);
        assert!(page.edges.iter().any(|t| t.id == token1.id));
        assert!(page.edges.iter().any(|t| t.id == token2.id));

        // Invalid tokens are exactly the expired and the revoked ones
        let invalid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(false);
        let page = repo
            .user_registration_token()
            .list(invalid_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 2);
        assert!(page.edges.iter().any(|t| t.id == token3.id));
        assert!(page.edges.iter().any(|t| t.id == token4.id));

        let combined_filter = UserRegistrationTokenFilter::new(clock.now())
            .with_been_used(false)
            .with_revoked(true);
        let page = repo
            .user_registration_token()
            .list(combined_filter, Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 1);
        assert_eq!(page.edges[0].id, token4.id);

        // Pagination truncates the page and reports that more rows exist
        let page = repo
            .user_registration_token()
            .list(empty_filter, Pagination::first(2))
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 2);
        assert!(page.has_next_page);
    }

    /// Exercise single-token lookups and the use/revoke lifecycle.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_lookup_use_and_revoke(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let token = repo
            .user_registration_token()
            .add(&mut rng, &clock, "secret".to_owned(), Some(3), None)
            .await
            .unwrap();
        assert_eq!(token.times_used, 0);
        assert_eq!(token.usage_limit, Some(3));

        // Lookup by ID
        let looked_up = repo
            .user_registration_token()
            .lookup(token.id)
            .await
            .unwrap()
            .expect("the token was just inserted");
        assert_eq!(looked_up.id, token.id);
        assert_eq!(looked_up.token, "secret");
        assert_eq!(looked_up.usage_limit, Some(3));

        // Unknown IDs are not an error
        assert!(
            repo.user_registration_token()
                .lookup(Ulid::nil())
                .await
                .unwrap()
                .is_none()
        );

        // Using the token bumps the counter and records the last use
        let token = repo
            .user_registration_token()
            .use_token(&clock, token)
            .await
            .unwrap();
        assert_eq!(token.times_used, 1);
        assert!(token.last_used_at.is_some());

        // Lookup by token value sees the persisted counter
        let found = repo
            .user_registration_token()
            .find_by_token("secret")
            .await
            .unwrap()
            .expect("the token was just inserted");
        assert_eq!(found.id, token.id);
        assert_eq!(found.times_used, 1);
        assert!(found.last_used_at.is_some());

        // Unknown token values are not an error
        assert!(
            repo.user_registration_token()
                .find_by_token("unknown")
                .await
                .unwrap()
                .is_none()
        );

        // A revoked token can no longer be used
        let token = repo
            .user_registration_token()
            .revoke(&clock, token)
            .await
            .unwrap();
        assert!(token.revoked_at.is_some());
        assert!(
            repo.user_registration_token()
                .use_token(&clock, token)
                .await
                .is_err()
        );
    }
}