Skip to content

Commit

Permalink
feat: registrationName in CSR + Warning fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasBousselin committed Oct 31, 2024
1 parent f74ceec commit deca20c
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.egm.stellio.search.csr.model

import org.springframework.http.HttpHeaders
import org.springframework.http.ResponseEntity
import java.util.Base64

/**
* Implements NGSILD-Warning as defined in 6.3.17
Expand All @@ -15,7 +15,7 @@ open class NGSILDWarning(
fun getHeaderMessage(): String = "$code ${getWarnAgent()} \"${getWarnText()}\""

// new line are forbidden in headers
private fun getWarnText(): String = Base64.getEncoder().encodeToString(message.toByteArray())
private fun getWarnText(): String = message.replace("\n", " ")
private fun getWarnAgent(): String = csr.registrationName ?: csr.id.toString()

companion object {
Expand Down Expand Up @@ -48,10 +48,11 @@ data class MiscellaneousPersistentWarning(
) : NGSILDWarning(MISCELLANEOUS_PERSISTENT_WARNING_CODE, message, csr)

fun ResponseEntity<*>.addWarnings(warnings: List<NGSILDWarning>?): ResponseEntity<*> {
val headers = HttpHeaders.writableHttpHeaders(this.headers)
if (!warnings.isNullOrEmpty())
this.headers.addAll(
NGSILDWarning.HEADER_NAME,
warnings.map { it.getHeaderMessage() }
)
return this
headers.addAll(NGSILDWarning.HEADER_NAME, warnings.map { it.getHeaderMessage() })

return ResponseEntity.status(this.statusCode)
.headers(headers)
.body(this.body)
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ContextSourceRegistrationService(
mode,
information,
operations,
registration_name,
observation_interval_start,
observation_interval_end,
management_interval_start,
Expand All @@ -63,6 +64,7 @@ class ContextSourceRegistrationService(
:mode,
:information,
:operations,
:registration_name,
:observation_interval_start,
:observation_interval_end,
:management_interval_start,
Expand All @@ -80,6 +82,7 @@ class ContextSourceRegistrationService(
Json.of(mapper.writeValueAsString(contextSourceRegistration.information))
)
.bind("operations", contextSourceRegistration.operations.map { it.key }.toTypedArray())
.bind("registration_name", contextSourceRegistration.registrationName)
.bind("observation_interval_start", contextSourceRegistration.observationInterval?.start)
.bind("observation_interval_end", contextSourceRegistration.observationInterval?.end)
.bind("management_interval_start", contextSourceRegistration.managementInterval?.start)
Expand Down Expand Up @@ -124,6 +127,7 @@ class ContextSourceRegistrationService(
mode,
information,
operations,
registration_name,
observation_interval_start,
observation_interval_end,
management_interval_start,
Expand Down Expand Up @@ -175,6 +179,7 @@ class ContextSourceRegistrationService(
mode,
information,
operations,
registration_name,
observation_interval_start,
observation_interval_end,
management_interval_start,
Expand Down Expand Up @@ -218,6 +223,7 @@ class ContextSourceRegistrationService(
information = mapper.readerForListOf(RegistrationInfo::class.java)
.readValue((row["information"] as Json).asString()),
operations = (row["operations"] as Array<String>).mapNotNull { Operation.fromString(it) },
registrationName = row["registration_name"] as? String,
createdAt = toZonedDateTime(row["created_at"]),
modifiedAt = toOptionalZonedDateTime(row["modified_at"]),
observationInterval = row["observation_interval_start"]?.let {
Expand All @@ -238,7 +244,7 @@ class ContextSourceRegistrationService(
suspend fun updateContextSourceStatus(
csr: ContextSourceRegistration,
success: Boolean
): Long {
) {
val updateStatement = if (success)
Update.update("status", ContextSourceRegistration.StatusType.OK.name)
.set("times_sent", csr.timesSent + 1)
Expand All @@ -248,7 +254,7 @@ class ContextSourceRegistrationService(
.set("times_failed", csr.timesFailed + 1)
.set("last_failure", ngsiLdDateTime())

return r2dbcEntityTemplate.update(
r2dbcEntityTemplate.update(
query(where("id").`is`(csr.id)),
updateStatement,
ContextSourceRegistration::class.java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package com.egm.stellio.search.csr.service

import arrow.core.*
import arrow.core.raise.either
import arrow.core.raise.iorNel
import com.egm.stellio.search.csr.model.ContextSourceRegistration
import com.egm.stellio.search.csr.model.NGSILDWarning
import com.egm.stellio.search.csr.model.RevalidationFailedWarning
Expand All @@ -28,22 +27,21 @@ object ContextSourceUtils {
fun mergeEntities(
localEntity: CompactedEntity?,
remoteEntitiesWithCSR: List<CompactedEntityWithCSR>
): IorNel<NGSILDWarning, CompactedEntity?> = iorNel {
if (localEntity == null && remoteEntitiesWithCSR.isEmpty()) return@iorNel null
): IorNel<NGSILDWarning, CompactedEntity?> {
if (localEntity == null && remoteEntitiesWithCSR.isEmpty()) return Ior.Right(null)

val mergedEntity: MutableMap<String, Any> = localEntity?.toMutableMap() ?: mutableMapOf()

remoteEntitiesWithCSR.sortedBy { (_, csr) -> csr.isAuxiliary() }
.forEach { (entity, csr) ->
mergedEntity.putAll(
getMergeNewValues(mergedEntity, entity, csr).toIor().toIorNel().bind()
)
}
val warnings = remoteEntitiesWithCSR.sortedBy { (_, csr) -> csr.isAuxiliary() }
.mapNotNull { (entity, csr) ->
getMergeNewValues(mergedEntity, entity, csr)
.onRight { mergedEntity.putAll(it) }.leftOrNull()
}.toNonEmptyListOrNull()

mergedEntity.toMap()
return if (warnings == null) Ior.Right(mergedEntity) else Ior.Both(warnings, mergedEntity)
}

private fun getMergeNewValues(
fun getMergeNewValues(
currentEntity: CompactedEntity,
remoteEntity: CompactedEntity,
csr: ContextSourceRegistration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CREATE TABLE context_source_registration
mode text NOT NULL,
information jsonb NOT NULL,
operations text[] NOT NULL,
registration_name text,
observation_interval_start timestamp with time zone,
observation_interval_end timestamp with time zone,
management_interval_start timestamp with time zone,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.egm.stellio.search.csr

import com.egm.stellio.search.csr.model.ContextSourceRegistration
import com.egm.stellio.search.csr.model.Operation
import com.egm.stellio.shared.util.ngsiLdDateTime
import com.egm.stellio.shared.util.toUri

object CsrUtils {
fun gimmeRawCSR() = ContextSourceRegistration(
id = "urn:ngsi-ld:ContextSourceRegistration:test".toUri(),
endpoint = "http://localhost:8089".toUri(),
information = emptyList(),
operations = listOf(Operation.FEDERATION_OPS),
createdAt = ngsiLdDateTime(),
)
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.egm.stellio.search.csr.service

import com.egm.stellio.search.csr.CsrUtils.gimmeRawCSR
import com.egm.stellio.search.csr.model.*
import com.egm.stellio.shared.util.*
import com.egm.stellio.shared.util.JsonUtils.serializeObject
Expand All @@ -21,13 +22,6 @@ class ContextSourceCallerTests {

private val apiaryId = "urn:ngsi-ld:Apiary:TEST"

private fun gimmeRawCSR() = ContextSourceRegistration(
id = "urn:ngsi-ld:ContextSourceRegistration:test".toUri(),
endpoint = "http://localhost:8089".toUri(),
information = emptyList(),
operations = listOf(Operation.FEDERATION_OPS),
createdAt = ngsiLdDateTime(),
)
private val emptyParams = LinkedMultiValueMap<String, String>()
private val entityWithSysAttrs =
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.egm.stellio.search.csr.service

import arrow.core.left
import arrow.core.right
import com.egm.stellio.search.csr.model.ContextSourceRegistration
import com.egm.stellio.search.csr.model.MiscellaneousWarning
import com.egm.stellio.search.csr.model.Mode
import com.egm.stellio.shared.model.CompactedAttributeInstance
import com.egm.stellio.shared.model.CompactedEntity
Expand Down Expand Up @@ -86,22 +88,6 @@ class ContextSourceUtilsTests {
assertEquals(entityWithName + entityWithLastName + entityWithSurName, mergedEntity.getOrNull())
}

@Test
fun `merge entity should call mergeAttribute or mergeTypeOrScope when keys are equal`() = runTest {
mockkObject(ContextSourceUtils) {
every { ContextSourceUtils.mergeAttribute(any(), any(), any()) } returns listOf(
nameAttribute
).right()
every { ContextSourceUtils.mergeTypeOrScope(any(), any()) } returns listOf("Beehive")
ContextSourceUtils.mergeEntities(
entityWithName,
listOf(entityWithName to auxiliaryCSR, entityWithName to inclusiveCSR)
)
verify(exactly = 2) { ContextSourceUtils.mergeAttribute(any(), any(), any()) }
verify(exactly = 2) { ContextSourceUtils.mergeTypeOrScope(any(), any()) }
}
}

@Test
fun `merge entity should merge the types correctly `() = runTest {
val mergedEntity = ContextSourceUtils.mergeEntities(
Expand Down Expand Up @@ -192,4 +178,38 @@ class ContextSourceUtilsTests {
(mergedEntity.getOrNull()?.get(NGSILD_CREATED_AT_TERM))
)
}

@Test
fun `merge entity should merge each entity using getMergeNewValues and return the received warnings`() = runTest {
val warning1 = MiscellaneousWarning("1", inclusiveCSR)
val warning2 = MiscellaneousWarning("2", inclusiveCSR)
mockkObject(ContextSourceUtils) {
every { ContextSourceUtils.getMergeNewValues(any(), any(), any()) } returns
warning1.left() andThen warning2.left()

val (warnings, entity) = ContextSourceUtils.mergeEntities(
entityWithName,
listOf(entityWithName to inclusiveCSR, entityWithName to inclusiveCSR)
).toPair()
verify(exactly = 2) { ContextSourceUtils.getMergeNewValues(any(), any(), any()) }
assertThat(warnings).hasSize(2).contains(warning1, warning2)
assertEquals(entityWithName, entity)
}
}

@Test
fun `merge entity should call mergeAttribute or mergeTypeOrScope when keys are equal`() = runTest {
mockkObject(ContextSourceUtils) {
every { ContextSourceUtils.mergeAttribute(any(), any(), any()) } returns listOf(
nameAttribute
).right()
every { ContextSourceUtils.mergeTypeOrScope(any(), any()) } returns listOf("Beehive")
ContextSourceUtils.mergeEntities(
entityWithName,
listOf(entityWithName to auxiliaryCSR, entityWithName to inclusiveCSR)
)
verify(exactly = 2) { ContextSourceUtils.mergeAttribute(any(), any(), any()) }
verify(exactly = 2) { ContextSourceUtils.mergeTypeOrScope(any(), any()) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package com.egm.stellio.search.entity.web
import arrow.core.left
import arrow.core.right
import com.egm.stellio.search.common.config.SearchProperties
import com.egm.stellio.search.csr.CsrUtils.gimmeRawCSR
import com.egm.stellio.search.csr.model.MiscellaneousWarning
import com.egm.stellio.search.csr.model.NGSILDWarning
import com.egm.stellio.search.csr.service.ContextSourceCaller
import com.egm.stellio.search.csr.service.ContextSourceRegistrationService
import com.egm.stellio.search.entity.model.*
import com.egm.stellio.search.entity.service.EntityQueryService
Expand Down Expand Up @@ -804,6 +808,43 @@ class EntityHandlerTests {
)
}

@Test
fun `get entity by id should return the warnings send by the csr and update the csr status`() {
val csr = gimmeRawCSR()
coEvery {
entityQueryService.queryEntity("urn:ngsi-ld:BeeHive:TEST".toUri(), sub.getOrNull())
} returns ResourceNotFoundException("no entity").left()

coEvery {
contextSourceRegistrationService
.getContextSourceRegistrations(any(), any(), any())
} returns listOf(csr, csr)

mockkObject(ContextSourceCaller) {
coEvery {
ContextSourceCaller.getDistributedInformation(any(), any(), any(), any())
} returns MiscellaneousWarning(
"message\nwith\nline\nbreaks",
csr
).left() andThen
MiscellaneousWarning("message", csr).left()

coEvery { contextSourceRegistrationService.updateContextSourceStatus(any(), any()) } returns Unit
webClient.get()
.uri("/ngsi-ld/v1/entities/urn:ngsi-ld:BeeHive:TEST")
.header(HttpHeaders.LINK, AQUAC_HEADER_LINK)
.exchange()
.expectStatus().isNotFound
.expectHeader().valueEquals(
NGSILDWarning.HEADER_NAME,
"199 urn:ngsi-ld:ContextSourceRegistration:test \"message with line breaks\"",
"199 urn:ngsi-ld:ContextSourceRegistration:test \"message\""
)

coVerify(exactly = 2) { contextSourceRegistrationService.updateContextSourceStatus(any(), false) }
}
}

@Test
fun `get entities by type should not include temporal properties if query param sysAttrs is not present`() {
coEvery { entityQueryService.queryEntities(any(), any<Sub>()) } returns Pair(
Expand Down

0 comments on commit deca20c

Please sign in to comment.