Skip to content

Commit

Permalink
fix: dev pose snapshot generator
Browse files Browse the repository at this point in the history
  • Loading branch information
DongGeon0908 committed Aug 26, 2024
1 parent 5a72325 commit a6d82f5
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package com.hero.alignlab.domain.dev.application

import com.hero.alignlab.common.extension.executes
import com.hero.alignlab.config.database.TransactionTemplates
import com.hero.alignlab.domain.dev.model.request.DevPoseSnapshotRequest
import com.hero.alignlab.domain.pose.application.PoseSnapshotService
import com.hero.alignlab.domain.pose.domain.PoseSnapshot
import com.hero.alignlab.domain.pose.domain.vo.PoseType
import com.hero.alignlab.domain.pose.model.PoseSnapshotModel
import com.hero.alignlab.domain.pose.model.request.PoseSnapshotRequest
import com.hero.alignlab.event.model.LoadPoseSnapshot
import org.springframework.context.ApplicationEventPublisher
import org.springframework.stereotype.Service
import java.math.BigDecimal
import java.time.temporal.ChronoUnit

@Service
class DevPoseSnapshotService(
private val poseSnapshotService: PoseSnapshotService,
private val txTemplates: TransactionTemplates,
private val publisher: ApplicationEventPublisher,
) {
suspend fun create(request: DevPoseSnapshotRequest) {
val daysBetween = ChronoUnit.DAYS.between(request.fromDate, request.toDate).toInt()
val dailySnapshots = request.dailyCount

for (day in 0..daysBetween) {
val currentDate = request.fromDate.plusDays(day.toLong())
for (count in 1..dailySnapshots) {
val poseSnapshotRequest = generatePoseSnapshotRequest(request.uid)

txTemplates.writer.executes {
val createdPoseSnapshot = poseSnapshotService.saveSync(
PoseSnapshot(
uid = request.uid,
score = poseSnapshotRequest.snapshot.score,
type = poseSnapshotRequest.type,
imageUrl = poseSnapshotRequest.imageUrl,
)
)

LoadPoseSnapshot(createdPoseSnapshot, poseSnapshotRequest.snapshot.keypoints)
.run { publisher.publishEvent(this) }
}
}
}
}

private fun generatePoseSnapshotRequest(uid: Long): PoseSnapshotRequest {
return PoseSnapshotRequest(
snapshot = PoseSnapshotModel(
keypoints = listOf(
PoseSnapshotModel.KeyPoint(
x = BigDecimal("340.15104727778066"),
y = BigDecimal("317.0014378682798"),
name = PoseSnapshotModel.PosePosition.nose,
confidence = BigDecimal("0.4566487967967987")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("377.45328131535103"),
y = BigDecimal("255.51991796000226"),
name = PoseSnapshotModel.PosePosition.left_eye,
confidence = BigDecimal("0.7127981781959534")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("296.70069464439877"),
y = BigDecimal("255.8227834988443"),
name = PoseSnapshotModel.PosePosition.right_eye,
confidence = BigDecimal("0.7658332586288452")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("409.9420757516539"),
y = BigDecimal("283.3670408973556"),
name = PoseSnapshotModel.PosePosition.left_ear,
confidence = BigDecimal("0.6566842794418335")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("243.45096777595006"),
y = BigDecimal("293.0522368514335"),
name = PoseSnapshotModel.PosePosition.right_ear,
confidence = BigDecimal("0.6061005592346191")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("474.296262666009"),
y = BigDecimal("461.63234326659443"),
name = PoseSnapshotModel.PosePosition.left_shoulder,
confidence = BigDecimal("0.3419668972492218")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("212.29553071932938"),
y = BigDecimal("405.9508591849053"),
name = PoseSnapshotModel.PosePosition.right_shoulder,
confidence = BigDecimal("0.1849706918001175")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("536.7644776153909"),
y = BigDecimal("472.4922462422612"),
name = PoseSnapshotModel.PosePosition.left_elbow,
confidence = BigDecimal("0.13967803120613098")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("115.94839192631409"),
y = BigDecimal("479.2769055909309"),
name = PoseSnapshotModel.PosePosition.right_elbow,
confidence = BigDecimal("0.14616511762142181")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("477.22051860157825"),
y = BigDecimal("463.2419869517527"),
name = PoseSnapshotModel.PosePosition.left_wrist,
confidence = BigDecimal("0.11746863275766373")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("115.32584645311717"),
y = BigDecimal("475.06100195679534"),
name = PoseSnapshotModel.PosePosition.right_wrist,
confidence = BigDecimal("0.06731665879487991")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("434.72166853076027"),
y = BigDecimal("481.99154533645805"),
name = PoseSnapshotModel.PosePosition.left_hip,
confidence = BigDecimal("0.05242015793919563")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("206.51050888234235"),
y = BigDecimal("487.54411662426895"),
name = PoseSnapshotModel.PosePosition.right_hip,
confidence = BigDecimal("0.04528944566845894")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("633.1560611828228"),
y = BigDecimal("357.2434379965817"),
name = PoseSnapshotModel.PosePosition.left_knee,
confidence = BigDecimal("0.09410975873470306")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("245.58417560462436"),
y = BigDecimal("473.55431375372086"),
name = PoseSnapshotModel.PosePosition.right_knee,
confidence = BigDecimal("0.22151082754135132")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("512.0801633054434"),
y = BigDecimal("321.2637690639215"),
name = PoseSnapshotModel.PosePosition.left_ankle,
confidence = BigDecimal("0.032857201993465424")
),
PoseSnapshotModel.KeyPoint(
x = BigDecimal("265.4604333833633"),
y = BigDecimal("355.20512497446424"),
name = PoseSnapshotModel.PosePosition.right_ankle,
confidence = BigDecimal("0.05365283787250519")
)
),
score = BigDecimal("0.5373632567269462")
),
type = PoseType.CHIN_UTP,
imageUrl = null
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.hero.alignlab.domain.dev.model.request

import java.time.LocalDateTime

data class DevPoseSnapshotRequest(
val uid: Long,
val fromDate: LocalDateTime,
val toDate: LocalDateTime,
val dailyCount: Int,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.hero.alignlab.domain.dev.resource

import com.hero.alignlab.common.extension.wrapVoid
import com.hero.alignlab.config.dev.DevResourceCheckConfig.Companion.devResource
import com.hero.alignlab.config.swagger.SwaggerTag.DEV_TAG
import com.hero.alignlab.domain.dev.application.DevPoseSnapshotService
import com.hero.alignlab.domain.dev.model.request.DevPoseSnapshotRequest
import io.swagger.v3.oas.annotations.Operation
import io.swagger.v3.oas.annotations.tags.Tag
import org.springframework.http.MediaType
import org.springframework.web.bind.annotation.*

@Tag(name = DEV_TAG)
@RestController
@RequestMapping(produces = [MediaType.APPLICATION_JSON_VALUE])
class DevPoseSnapshotResource(
private val devPoseSnapshotService: DevPoseSnapshotService,
) {
@Operation(summary = "[DEV] 포즈 스냅샷 생성")
@PostMapping("/api/dev/v1/pose-snapshots")
suspend fun createPoseSnapshots(
@RequestHeader("X-HERO-DEV-TOKEN") token: String,
@RequestBody request: DevPoseSnapshotRequest,
) = devResource(token) {
devPoseSnapshotService.create(request).wrapVoid()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,10 @@ class GroupUserService(
fun deleteSync(groupUserId: Long) {
groupUserRepository.deleteById(groupUserId)
}

suspend fun findByUid(uid: Long): GroupUser? {
return withContext(Dispatchers.IO) {
groupUserRepository.findByUid(uid)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ interface GroupUserRepository : JpaRepository<GroupUser, Long> {
fun existsByUid(uid: Long): Boolean

fun countByCreatedAtBetween(startAt: LocalDateTime, endAt: LocalDateTime): Long

fun findByUid(uid: Long): GroupUser?
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package com.hero.alignlab.domain.pose.application

import com.hero.alignlab.domain.pose.domain.PoseKeyPointSnapshot
import com.hero.alignlab.domain.pose.infrastructure.PoseKeyPointSnapshotRepository
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
import org.springframework.jdbc.core.namedparam.SqlParameterSourceUtils
import org.springframework.stereotype.Service
import org.springframework.transaction.annotation.Transactional

@Service
class PoseKeyPointSnapshotService(
private val heroNamedParameterJdbcTemplate: NamedParameterJdbcTemplate
private val heroNamedParameterJdbcTemplate: NamedParameterJdbcTemplate,
private val poseKeyPointSnapshotRepository: PoseKeyPointSnapshotRepository,
) {
@Transactional
fun bulkSave(poseKeyPointSnapshots: List<PoseKeyPointSnapshot>) {
Expand All @@ -21,4 +23,9 @@ class PoseKeyPointSnapshotService(

heroNamedParameterJdbcTemplate.batchUpdate(sql, batchParams)
}

@Transactional
fun saveAllSync(poseKeyPointSnapshots: List<PoseKeyPointSnapshot>): List<PoseKeyPointSnapshot> {
return poseKeyPointSnapshotRepository.saveAll(poseKeyPointSnapshots)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ enum class PoseType(val nameKor: String) {
;

companion object {
val BAD_POSE = setOf(TURTLE_NECK, SHOULDER_TWIST, CHIN_UTP, TURTLE_NECK)
val BAD_POSE = setOf(TURTLE_NECK, SHOULDER_TWIST, CHIN_UTP, TAILBONE_SIT)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import com.hero.alignlab.common.extension.coExecuteOrNull
import com.hero.alignlab.config.database.TransactionTemplates
import com.hero.alignlab.domain.group.application.GroupUserScoreService
import com.hero.alignlab.domain.group.application.GroupUserService
import com.hero.alignlab.domain.group.domain.GroupUserScore
import com.hero.alignlab.domain.pose.application.PoseCountService
import com.hero.alignlab.domain.pose.application.PoseKeyPointSnapshotService
import com.hero.alignlab.domain.pose.domain.PoseCount
Expand Down Expand Up @@ -34,8 +33,8 @@ class PoseSnapshotListener(
PoseKeyPointSnapshot(
poseSnapshotId = event.poseSnapshot.id,
position = keyPoint.name.toPosition(),
x = keyPoint.x,
y = keyPoint.y,
x = keyPoint.x,
confidence = keyPoint.confidence
)
}
Expand Down Expand Up @@ -63,23 +62,21 @@ class PoseSnapshotListener(
.values
.sum()

val groupUsers = groupUserService.findAllByUid(event.poseSnapshot.uid)
val groupUserScore = groupUserScoreService.findAllByGroupUserIdIn(groupUsers.map { it.id })
.associateBy { it.groupUserId }
val groupUser = groupUserService.findByUid(event.poseSnapshot.uid)
val groupUserScore = groupUserScoreService.findByUidOrNull(event.poseSnapshot.uid)

val needToUpdateScores = groupUsers.map { groupUser ->
groupUserScore[groupUser.id] ?: GroupUserScore(
groupId = groupUser.groupId,
groupUserId = groupUser.id,
uid = groupUser.uid,
score = score
)
val updatedGroupUserScore = when (groupUser != null && groupUserScore != null) {
true -> groupUserScore.apply {
this.score = score
}

false -> null
}

txTemplates.writer.coExecuteOrNull {
poseKeyPointSnapshotService.bulkSave(keyPoints)
poseKeyPointSnapshotService.saveAllSync(keyPoints)
poseCountService.saveSync(poseCount)
groupUserScoreService.saveAllSync(needToUpdateScores)
updatedGroupUserScore?.run { groupUserScoreService.saveSync(this) }
}
}
}
Expand Down

0 comments on commit a6d82f5

Please sign in to comment.