diff --git a/core/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/samModels.scala b/core/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/samModels.scala index 070c4101c5..e663433090 100644 --- a/core/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/samModels.scala +++ b/core/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/samModels.scala @@ -239,6 +239,9 @@ object WorkspaceAction { final case object Delete extends WorkspaceAction { val asString = "delete" } + final case object Compute extends WorkspaceAction { + val asString = "compute" + } val allActions = sealerate.values[WorkspaceAction] val stringToAction: Map[String, WorkspaceAction] = diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/dao/sam/SamUtils.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/dao/sam/SamUtils.scala new file mode 100644 index 0000000000..ea8578f039 --- /dev/null +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/dao/sam/SamUtils.scala @@ -0,0 +1,80 @@ +package org.broadinstitute.dsde.workbench.leonardo.dao.sam + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.OAuth2BearerToken +import cats.effect.Async +import cats.implicits.{catsSyntaxApplicativeError, toFlatMapOps} +import cats.mtl.Ask +import org.broadinstitute.dsde.workbench.leonardo.model.{ + ForbiddenError, + LeoException, + RuntimeNotFoundByWorkspaceIdException, + RuntimeNotFoundException +} +import org.broadinstitute.dsde.workbench.leonardo.{ + AppContext, + CloudContext, + RuntimeAction, + RuntimeName, + SamResourceId, + WorkspaceId +} +import org.broadinstitute.dsde.workbench.model.{UserInfo, WorkbenchEmail} + +trait SamUtils[F[_]] { + val samService: SamService[F] + + def checkRuntimeAction(userInfo: UserInfo, + cloudContext: CloudContext, + runtimeName: RuntimeName, + samResourceId: SamResourceId, + action: RuntimeAction, + userEmail: Option[WorkbenchEmail] = None + )(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] = + checkRuntimeActionInternal( + userInfo.accessToken, + userEmail.getOrElse(userInfo.userEmail), + samResourceId, + action, + RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database") + ) + + def checkRuntimeAction(userInfo: UserInfo, + workspaceId: WorkspaceId, + runtimeName: RuntimeName, + samResourceId: SamResourceId, + action: RuntimeAction + )(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] = + checkRuntimeActionInternal( + userInfo.accessToken, + userInfo.userEmail, + samResourceId, + action, + RuntimeNotFoundByWorkspaceIdException(workspaceId, runtimeName, "Not found in database") + ) + + private def checkRuntimeActionInternal(userToken: OAuth2BearerToken, + userEmail: WorkbenchEmail, + samResourceId: SamResourceId, + action: RuntimeAction, + notFoundException: LeoException + )(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] = + samService + .checkAuthorized(userToken.token, samResourceId, action) + .handleErrorWith { + // If we've already checked read access and the user doesn't have it, pretend the runtime doesn't exist to avoid leaking its existence + case e: SamException if e.statusCode == StatusCodes.Forbidden && action == RuntimeAction.GetRuntimeStatus => + F.raiseError(notFoundException) + // Check if the user can read the runtime to determine which error to raise + case e: SamException if e.statusCode == StatusCodes.Forbidden => + samService + .checkAuthorized(userToken.token, samResourceId, RuntimeAction.GetRuntimeStatus) + .attempt + .flatMap { + // The user can read the runtime, but they don't have the required action. Raise the original Forbidden action from Sam + case Right(_) => F.raiseError(ForbiddenError(userEmail)) + // The user can't read the runtime, pretend it doesn't exist to avoid leaking its existence + case Left(_) => F.raiseError(notFoundException) + } + } +} diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/AppDependenciesBuilder.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/AppDependenciesBuilder.scala index e2f908a0a0..0b9c5d6756 100644 --- a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/AppDependenciesBuilder.scala +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/AppDependenciesBuilder.scala @@ -98,7 +98,6 @@ class AppDependenciesBuilder(baselineDependenciesBuilder: BaselineDependenciesBu val azureService = new RuntimeV2ServiceInterp[IO]( baselineDependencies.runtimeServicesConfig, - baselineDependencies.authProvider, baselineDependencies.publisherQueue, baselineDependencies.dateAccessedUpdaterQueue, baselineDependencies.wsmClientProvider, diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/GcpDependenciesBuilder.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/GcpDependenciesBuilder.scala index 4228b7aed3..eacc768daf 100644 --- a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/GcpDependenciesBuilder.scala +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/GcpDependenciesBuilder.scala @@ -268,7 +268,8 @@ class GcpDependencyBuilder extends CloudDependenciesBuilder { baselineDependencies.proxyResolver, baselineDependencies.samDAO, baselineDependencies.googleTokenCache, - baselineDependencies.samResourceCache + baselineDependencies.samResourceCache, + baselineDependencies.samService ) val diskService = new DiskServiceInterp[IO]( diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/ProxyService.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/ProxyService.scala index f4aebcb7bd..65573e74be 100644 --- a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/ProxyService.scala +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/ProxyService.scala @@ -23,6 +23,7 @@ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId._ import org.broadinstitute.dsde.workbench.leonardo.config.ProxyConfig import org.broadinstitute.dsde.workbench.leonardo.dao.HostStatus._ import org.broadinstitute.dsde.workbench.leonardo.dao.google.GoogleOAuth2Service +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamService, SamUtils} import org.broadinstitute.dsde.workbench.leonardo.dao.{HostStatus, JupyterDAO, Proxy, SamDAO, TerminalName} import org.broadinstitute.dsde.workbench.leonardo.db.{appQuery, clusterQuery, DbReference, KubernetesServiceDbQueries} import org.broadinstitute.dsde.workbench.leonardo.dns.{KubernetesDnsCache, ProxyResolver, RuntimeDnsCache} @@ -91,14 +92,16 @@ class ProxyService( proxyResolver: ProxyResolver[IO], samDAO: SamDAO[IO], googleTokenCache: Cache[IO, String, (UserInfo, Instant)], - samResourceCache: Cache[IO, SamResourceCacheKey, (Option[String], Option[AppAccessScope])] + samResourceCache: Cache[IO, SamResourceCacheKey, (Option[String], Option[AppAccessScope])], + val samService: SamService[IO] )(implicit val system: ActorSystem, executionContext: ExecutionContext, dbRef: DbReference[IO], loggerIO: StructuredLogger[IO], metrics: OpenTelemetryMetrics[IO] -) extends LazyLogging { +) extends LazyLogging + with SamUtils[IO] { val httpsConnectionContext = ConnectionContext.httpsClient(sslContext) val clientConnectionSettings = ClientConnectionSettings(system).withTransport(ClientTransport.withCustomResolver(proxyResolver.resolveAkka)) @@ -267,38 +270,9 @@ class ProxyService( for { ctx <- ev.ask[AppContext] - hasWorkspacePermission <- workspaceId match { - case Some(wid) => - authProvider - .isUserWorkspaceReader( - WorkspaceResourceSamResourceId(wid), - userInfo - ) - case None => IO.pure(true) - } - - _ <- IO.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - samResource <- getCachedRuntimeSamResource(RuntimeCacheKey(cloudContext, runtimeName)) - // Note both these Sam actions are cached so it should be okay to call hasPermission twice - hasViewPermission <- authProvider.hasPermission[RuntimeSamResourceId, RuntimeAction]( - samResource, - RuntimeAction.GetRuntimeStatus, - userInfo - ) - _ <- - if (!hasViewPermission) { - IO.raiseError(RuntimeNotFoundException(cloudContext, runtimeName, ctx.traceId.asString)) - } else IO.unit - hasConnectPermission <- authProvider.hasPermission[RuntimeSamResourceId, RuntimeAction]( - samResource, - RuntimeAction.ConnectToRuntime, - userInfo - ) - _ <- - if (!hasConnectPermission) { - IO.raiseError(ForbiddenError(userInfo.userEmail)) - } else IO.unit + _ <- checkRuntimeAction(userInfo, cloudContext, runtimeName, samResource, RuntimeAction.ConnectToRuntime) + hostStatus <- getRuntimeTargetHost(cloudContext, runtimeName) _ <- hostStatus match { case HostReady(_, _, _) => diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterp.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterp.scala index b0bad734d6..5aa49e7b7f 100644 --- a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterp.scala +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterp.scala @@ -29,12 +29,11 @@ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{ } import org.broadinstitute.dsde.workbench.leonardo.config._ import org.broadinstitute.dsde.workbench.leonardo.dao.DockerDAO -import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService, SamUtils} import org.broadinstitute.dsde.workbench.leonardo.db._ import org.broadinstitute.dsde.workbench.leonardo.http.service.DiskServiceInterp.getDiskSamPolicyMap import org.broadinstitute.dsde.workbench.leonardo.model.SamResourceAction.{ projectSamResourceAction, - runtimeSamResourceAction, workspaceSamResourceAction } import org.broadinstitute.dsde.workbench.leonardo.http.service.RuntimeServiceInterp._ @@ -67,14 +66,15 @@ class RuntimeServiceInterp[F[_]: Parallel]( googleStorageService: Option[GoogleStorageService[F]], googleComputeService: Option[GoogleComputeService[F]], publisherQueue: Queue[F, LeoPubsubMessage], - samService: SamService[F] + val samService: SamService[F] )(implicit F: Async[F], log: StructuredLogger[F], dbReference: DbReference[F], ec: ExecutionContext, metrics: OpenTelemetryMetrics[F] -) extends RuntimeService[F] { +) extends RuntimeService[F] + with SamUtils[F] { override def createRuntime( userInfo: UserInfo, @@ -90,13 +90,17 @@ class RuntimeServiceInterp[F[_]: Parallel]( ) // Resolve the user email in Sam from the user token. This translates a pet token to the owner email. userEmail <- samService.getUserEmail(userInfo.accessToken.token) - hasPermission <- authProvider.hasPermission[ProjectSamResourceId, ProjectAction]( - ProjectSamResourceId(googleProject), - ProjectAction.CreateRuntime, - userInfo - ) + // Check if the user has launch_notebook_cluster on the google-project resource. + _ <- samService + .checkAuthorized( + userInfo.accessToken.token, + ProjectSamResourceId(googleProject), + ProjectAction.CreateRuntime + ) + .adaptError { + case e: SamException if e.statusCode == StatusCodes.Forbidden => ForbiddenError(userEmail) + } _ <- context.span.traverse(s => F.delay(s.addAnnotation("Done Sam call for cluster permission"))) - _ <- F.raiseUnless(hasPermission)(ForbiddenError(userEmail)) // Grab the pet service account for the user petSA <- samService.getPetServiceAccount(userInfo.accessToken.token, googleProject) _ <- context.span.traverse(s => F.delay(s.addAnnotation("Done Sam call for getPetServiceAccount"))) @@ -234,31 +238,9 @@ class RuntimeServiceInterp[F[_]: Parallel]( as: Ask[F, AppContext] ): F[GetRuntimeResponse] = for { - ctx <- as.ask - // throw 403 if no project-level permission - hasProjectPermission <- authProvider.isUserProjectReader( - cloudContext, - userInfo - ) - _ <- F.raiseWhen(!hasProjectPermission)(ForbiddenError(userInfo.userEmail, Some(ctx.traceId))) - // throws 404 if not existent resp <- RuntimeServiceDbQueries.getRuntime(cloudContext, runtimeName).transaction - - // throw 404 if no GetClusterStatus permission - hasPermission <- authProvider.hasPermissionWithProjectFallback[RuntimeSamResourceId, RuntimeAction]( - resp.samResource, - RuntimeAction.GetRuntimeStatus, - ProjectAction.GetRuntimeStatus, - userInfo, - GoogleProject(cloudContext.asString) - ) - _ <- - if (hasPermission) F.unit - else - F.raiseError[Unit]( - RuntimeNotFoundException(cloudContext, runtimeName, "permission denied") - ) + _ <- checkRuntimeAction(userInfo, cloudContext, runtimeName, resp.samResource, RuntimeAction.GetRuntimeStatus) } yield resp override def listRuntimes(userInfo: UserInfo, cloudContext: Option[CloudContext], params: Map[String, String])( @@ -292,51 +274,13 @@ class RuntimeServiceInterp[F[_]: Parallel]( // Resolve the user email in Sam from the user token. This translates a pet token to the owner email. userEmail <- samService.getUserEmail(req.userInfo.accessToken.token) - - // throw 403 if no project-level permission - hasProjectPermission <- authProvider.isUserProjectReader( - cloudContext, - req.userInfo - ) - _ <- F.raiseWhen(!hasProjectPermission)(ForbiddenError(userEmail, Some(ctx.traceId))) - - // throw 404 if not existent - runtimeOpt <- clusterQuery.getActiveClusterByNameMinimal(cloudContext, req.runtimeName).transaction - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("DB | Done getActiveClusterByNameMinimal"))) - runtime <- runtimeOpt.fold( - F.raiseError[Runtime](RuntimeNotFoundException(cloudContext, req.runtimeName, "no record in database")) - )(F.pure) - // throw 404 if no GetClusterStatus permission - // Note: the general pattern is to 404 (e.g. pretend the runtime doesn't exist) if the caller doesn't have - // GetClusterStatus permission. We return 403 if the user can view the runtime but can't perform some other action. - listOfPermissions <- authProvider.getActionsWithProjectFallback( - runtime.samResource, - req.googleProject, - req.userInfo + runtime <- getRuntimeWithRequiredAction(req.userInfo, + cloudContext, + req.runtimeName, + RuntimeAction.DeleteRuntime, + userEmail.some ) - hasStatusPermission = listOfPermissions._1.toSet.contains(RuntimeAction.GetRuntimeStatus) || - listOfPermissions._2.contains(ProjectAction.GetRuntimeStatus) - - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Sam | Done get list of allowed actions"))) - - _ <- - if (hasStatusPermission) F.unit - else - log.info(ctx.loggingCtx)(s"${userEmail.value} has no permission to get runtime status") >> F - .raiseError[Unit]( - RuntimeNotFoundException( - cloudContext, - req.runtimeName, - "no active runtime record in database", - Some(ctx.traceId) - ) - ) - - // throw 403 if no DeleteCluster permission - hasDeletePermission = listOfPermissions._1.toSet.contains(RuntimeAction.DeleteRuntime) || - listOfPermissions._2.contains(ProjectAction.DeleteRuntime) - _ <- if (hasDeletePermission) F.unit else F.raiseError[Unit](ForbiddenError(userEmail)) // throw 409 if the cluster is not deletable _ <- if (runtime.status.isDeletable) F.unit @@ -437,22 +381,12 @@ class RuntimeServiceInterp[F[_]: Parallel]( ): F[Unit] = for { ctx <- as.ask - - listOfPermissions <- authProvider.getActionsWithProjectFallback(runtime.samResource, cloudContext.value, userInfo) - // throw 404 if no GetRuntime permission - hasStatusPermission = listOfPermissions._1.toSet.contains(RuntimeAction.GetRuntimeStatus) || - listOfPermissions._2.contains(ProjectAction.GetRuntimeStatus) - _ <- - if (hasStatusPermission) F.unit - else - F.raiseError[Unit]( - RuntimeNotFoundException(cloudContext, runtime.clusterName, "Permission Denied", Some(ctx.traceId)) - ) - - // throw 403 if no DeleteApp permission - hasDeletePermission = listOfPermissions._1.toSet.contains(RuntimeAction.DeleteRuntime) || - listOfPermissions._2.contains(ProjectAction.DeleteRuntime) - _ <- if (hasDeletePermission) F.unit else F.raiseError[Unit](ForbiddenError(userInfo.userEmail)) + _ <- checkRuntimeAction(userInfo, + cloudContext, + runtime.clusterName, + runtime.samResource, + RuntimeAction.DeleteRuntime + ) // Mark the resource as deleted in Leo's DB _ <- dbReference.inTransaction(clusterQuery.completeDeletion(runtime.id, ctx.now)) @@ -464,7 +398,6 @@ class RuntimeServiceInterp[F[_]: Parallel]( as: Ask[F, AppContext] ): F[Unit] = for { - ctx <- as.ask runtimes <- listRuntimes(userInfo, Some(cloudContext), Map.empty) _ <- runtimes.traverse(runtime => deleteRuntimeRecords(userInfo, cloudContext, runtime)) } yield () @@ -474,56 +407,7 @@ class RuntimeServiceInterp[F[_]: Parallel]( ): F[Unit] = for { ctx <- as.ask - // throw 403 if no project-level permission - hasProjectPermission <- authProvider.isUserProjectReader( - cloudContext, - userInfo - ) - _ <- F.raiseWhen(!hasProjectPermission)(ForbiddenError(userInfo.userEmail, Some(ctx.traceId))) - - googleProject <- F.fromOption( - LeoLenses.cloudContextToGoogleProject.get(cloudContext), - AzureUnimplementedException("Azure runtime is not supported yet") - ) - // throw 404 if not existent - runtimeOpt <- clusterQuery - .getActiveClusterByNameMinimal(cloudContext, runtimeName)(scala.concurrent.ExecutionContext.global) - .transaction - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Finish query for active runtime"))) - runtime <- runtimeOpt.fold( - F.raiseError[Runtime]( - RuntimeNotFoundException(cloudContext, runtimeName, "no active runtime found in database") - ) - )(F.pure) - // throw 404 if no GetClusterStatus permission - // Note: the general pattern is to 404 (e.g. pretend the runtime doesn't exist) if the caller doesn't have - // GetClusterStatus permission. We return 403 if the user can view the runtime but can't perform some other action. - - listOfPermissions <- authProvider.getActionsWithProjectFallback(runtime.samResource, googleProject, userInfo) - - hasStatusPermission = listOfPermissions._1.toSet.contains(RuntimeAction.GetRuntimeStatus) || - listOfPermissions._2.contains(ProjectAction.GetRuntimeStatus) - - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Sam | Done get list of allowed actions"))) - - _ <- - if (hasStatusPermission) F.unit - else - F.raiseError[Unit]( - RuntimeNotFoundException( - cloudContext, - runtimeName, - "GetRuntimeStatus permission is required for stopRuntime" - ) - ) - - // throw 403 if no StopStartCluster permission - hasStopPermission = listOfPermissions._1.toSet.contains(RuntimeAction.StopStartRuntime) || - listOfPermissions._2.contains(ProjectAction.StopStartRuntime) - - _ <- if (hasStopPermission) F.unit else F.raiseError[Unit](ForbiddenError(userInfo.userEmail)) - // throw 409 if the cluster is not stoppable - + runtime <- getRuntimeWithRequiredAction(userInfo, cloudContext, runtimeName, RuntimeAction.StopStartRuntime) _ <- if (runtime.status.isStopping) F.unit else if (runtime.status.isStoppable) { @@ -544,45 +428,7 @@ class RuntimeServiceInterp[F[_]: Parallel]( // TODO: take cloudContext directly instead of googleProject once we start supporting patching an Azure VM cloudContext = CloudContext.Gcp(googleProject) - // throw 403 if no project-level permission - hasProjectPermission <- authProvider.isUserProjectReader( - cloudContext, - userInfo - ) - _ <- F.raiseWhen(!hasProjectPermission)(ForbiddenError(userInfo.userEmail, Some(ctx.traceId))) - - // throw 404 if not existent - runtimeOpt <- clusterQuery - .getActiveClusterByNameMinimal(cloudContext, runtimeName)(scala.concurrent.ExecutionContext.global) - .transaction - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done query for active runtime"))) - runtime <- runtimeOpt.fold( - F.raiseError[Runtime](RuntimeNotFoundException(cloudContext, runtimeName, "no record in database")) - )(F.pure) - // throw 404 if no GetClusterStatus permission - // Note: the general pattern is to 404 (e.g. pretend the runtime doesn't exist) if the caller doesn't have - // GetClusterStatus permission. We return 403 if the user can view the runtime but can't perform some other action. - - listOfPermissions <- authProvider.getActionsWithProjectFallback(runtime.samResource, googleProject, userInfo) - - hasStatusPermission = listOfPermissions._1.toSet.contains(RuntimeAction.GetRuntimeStatus) || - listOfPermissions._2.contains(ProjectAction.GetRuntimeStatus) - - _ <- - if (hasStatusPermission) F.unit - else - F.raiseError[Unit]( - RuntimeNotFoundException(cloudContext, runtimeName, "GetRuntimeStatus permission is required") - ) - - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Sam | Done get list of allowed actions"))) - - hasStartPermission = listOfPermissions._1.toSet.contains(RuntimeAction.StopStartRuntime) || - listOfPermissions._2.toSet.contains(ProjectAction.StopStartRuntime) - - // throw 403 if no StopStartCluster permission - _ <- if (hasStartPermission) F.unit else F.raiseError[Unit](ForbiddenError(userInfo.userEmail)) - + runtime <- getRuntimeWithRequiredAction(userInfo, cloudContext, runtimeName, RuntimeAction.StopStartRuntime) // throw 409 if the cluster is not startable _ <- if (runtime.status.isStartable) F.unit @@ -605,46 +451,18 @@ class RuntimeServiceInterp[F[_]: Parallel]( // TODO: take cloudContext directly instead of googleProject once we start supporting patching an Azure VM cloudContext = CloudContext.Gcp(googleProject) - // throw 403 if no project-level permission - hasProjectPermission <- authProvider.isUserProjectReader( - cloudContext, - userInfo - ) - _ <- F.raiseWhen(!hasProjectPermission)(ForbiddenError(userInfo.userEmail, Some(ctx.traceId))) - // throw 404 if not existent runtimeOpt <- clusterQuery.getActiveClusterRecordByName(cloudContext, runtimeName).transaction runtime <- runtimeOpt.fold( F.raiseError[ClusterRecord](RuntimeNotFoundException(cloudContext, runtimeName, "no record in database")) )(F.pure) - // throw 404 if no GetClusterStatus permission - // Note: the general pattern is to 404 (e.g. pretend the runtime doesn't exist) if the caller doesn't have - // GetClusterStatus permission. We return 403 if the user can view the runtime but can't perform some other action. - - listOfPermissions <- authProvider.getActionsWithProjectFallback( - RuntimeSamResourceId(runtime.internalId), - googleProject, - userInfo - ) - - hasStatusPermission = listOfPermissions._1.toSet.contains(RuntimeAction.GetRuntimeStatus) || - listOfPermissions._2.contains(ProjectAction.GetRuntimeStatus) - - _ <- - if (hasStatusPermission) F.unit - else - F.raiseError[Unit]( - RuntimeNotFoundException( - cloudContext, - runtimeName, - "GetRuntimeStatus permission is required for update runtime" - ) - ) - - // throw 403 if no ModifyCluster permission - hasModifyPermission = listOfPermissions._1.toSet.contains(RuntimeAction.ModifyRuntime) - _ <- if (hasModifyPermission) F.unit else F.raiseError[Unit](ForbiddenError(userInfo.userEmail)) + _ <- checkRuntimeAction(userInfo, + cloudContext, + runtimeName, + RuntimeSamResourceId(runtime.internalId), + RuntimeAction.ModifyRuntime + ) // throw 409 if the cluster is not updatable _ <- if (runtime.status.isUpdatable) F.unit @@ -989,6 +807,22 @@ class RuntimeServiceInterp[F[_]: Parallel]( else Async[F].pure((mt, true)) } } yield targetMachineType + + private def getRuntimeWithRequiredAction( + userInfo: UserInfo, + cloudContext: CloudContext, + runtimeName: RuntimeName, + action: RuntimeAction, + userEmail: Option[WorkbenchEmail] = None + )(implicit as: Ask[F, AppContext]): F[Runtime] = + for { + runtimeOpt <- clusterQuery.getActiveClusterByNameMinimal(cloudContext, runtimeName).transaction + runtime <- runtimeOpt.fold( + F.raiseError[Runtime](RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database")) + )(F.pure) + + _ <- checkRuntimeAction(userInfo, cloudContext, runtimeName, runtime.samResource, action, userEmail) + } yield runtime } object RuntimeServiceInterp { diff --git a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterp.scala b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterp.scala index d7ba4f0007..ff0dc77519 100644 --- a/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterp.scala +++ b/http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterp.scala @@ -11,23 +11,17 @@ import cats.syntax.all._ import org.broadinstitute.dsde.workbench.google2.{DiskName, MachineTypeName, ZoneName} import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{ PersistentDiskSamResourceId, - ProjectSamResourceId, RuntimeSamResourceId, WorkspaceResourceSamResourceId, WsmResourceSamResourceId } import org.broadinstitute.dsde.workbench.leonardo.config.PersistentDiskConfig import org.broadinstitute.dsde.workbench.leonardo.dao._ -import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService, SamUtils} import org.broadinstitute.dsde.workbench.leonardo.db._ import org.broadinstitute.dsde.workbench.leonardo.http.service.DiskServiceInterp.getDiskSamPolicyMap import org.broadinstitute.dsde.workbench.leonardo.http.service.RuntimeServiceInterp.getRuntimeSamPolicyMap import org.broadinstitute.dsde.workbench.leonardo.model.SamResource.RuntimeSamResource -// do not remove: `projectSamResourceAction`, `runtimeSamResourceAction`, `workspaceSamResourceAction`, `wsmResourceSamResourceAction`; `AppSamResourceAction` they are implicit -import org.broadinstitute.dsde.workbench.leonardo.model.SamResourceAction.{ - workspaceSamResourceAction, - wsmResourceSamResourceAction -} import org.broadinstitute.dsde.workbench.leonardo.model._ import org.broadinstitute.dsde.workbench.leonardo.monitor.LeoPubsubMessage.{ CreateAzureRuntimeMessage, @@ -45,13 +39,13 @@ import scala.concurrent.ExecutionContext class RuntimeV2ServiceInterp[F[_]: Parallel]( config: RuntimeServiceConfig, - authProvider: LeoAuthProvider[F], publisherQueue: Queue[F, LeoPubsubMessage], dateAccessUpdaterQueue: Queue[F, UpdateDateAccessedMessage], wsmClientProvider: WsmApiClientProvider[F], - samService: SamService[F] + val samService: SamService[F] )(implicit F: Async[F], dbReference: DbReference[F], ec: ExecutionContext, log: StructuredLogger[F]) - extends RuntimeV2Service[F] { + extends RuntimeV2Service[F] + with SamUtils[F] { override def createRuntime( userInfo: UserInfo, @@ -63,6 +57,16 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( for { ctx <- as.ask + _ <- samService + .checkAuthorized(userInfo.accessToken.token, + WorkspaceResourceSamResourceId(workspaceId), + WorkspaceAction.Compute + ) + .adaptError { + case e: SamException if e.statusCode == StatusCodes.Forbidden => ForbiddenError(userInfo.userEmail) + } + _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for azure runtime permission"))) + workspaceDescOpt <- wsmClientProvider.getWorkspace(userInfo.accessToken.token, workspaceId) workspaceDesc <- F.fromOption(workspaceDescOpt, WorkspaceNotFoundException(workspaceId, ctx.traceId)) @@ -73,20 +77,9 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( case (None, None) => F.raiseError[CloudContext](CloudContextNotFoundException(workspaceId, ctx.traceId)) } - samResource = WorkspaceResourceSamResourceId(workspaceId) - // Resolve the user email in Sam from the user token. This translates a pet token to the owner email. userEmail <- samService.getUserEmail(userInfo.accessToken.token) - hasPermission <- authProvider.hasPermission[WorkspaceResourceSamResourceId, WorkspaceAction]( - samResource, - WorkspaceAction.CreateControlledUserResource, - userInfo - ) - - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for azure runtime permission"))) - _ <- F.raiseUnless(hasPermission)(ForbiddenError(userEmail)) - // enforcing one runtime per workspace/user at a time samResources <- samService.listResources(userInfo.accessToken.token, RuntimeSamResource.resourceType) runtimes <- RuntimeServiceDbQueries @@ -244,31 +237,9 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( for { ctx <- as.ask - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - WorkspaceResourceSamResourceId(workspaceId), - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - runtime <- RuntimeServiceDbQueries.getRuntimeByWorkspaceId(workspaceId, runtimeName).transaction - - hasPermission <- - if (runtime.auditInfo.creator == userInfo.userEmail) - F.pure(true) - else - checkSamPermission( - WsmResourceSamResourceId(WsmControlledResourceId(UUID.fromString(runtime.samResource.resourceId))), - userInfo, - WsmResourceAction.Read - ).map(_._1) - + _ <- checkRuntimeAction(userInfo, workspaceId, runtimeName, runtime.samResource, RuntimeAction.GetRuntimeStatus) _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for get azure runtime permission"))) - _ <- F - .raiseError[Unit]( - RuntimeNotFoundException(runtime.cloudContext, runtimeName, "permission denied", Some(ctx.traceId)) - ) - .whenA(!hasPermission) - } yield runtime override def updateRuntime( @@ -288,29 +259,7 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( for { ctx <- as.ask - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - WorkspaceResourceSamResourceId(workspaceId), - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - - runtime <- RuntimeServiceDbQueries.getActiveRuntimeRecord(workspaceId, runtimeName).transaction - - hasPermission <- - if (runtime.auditInfo.creator == userInfo.userEmail) F.pure(true) - else { - // users who have workspace level delete privileges should be able to delete all resources in the workspace - authProvider - .hasPermission[WorkspaceResourceSamResourceId, WorkspaceAction](WorkspaceResourceSamResourceId(workspaceId), - WorkspaceAction.Delete, - userInfo - ) - } - - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for delete azure runtime permission"))) - _ <- F - .raiseError[Unit](RuntimeNotFoundException(runtime.cloudContext, runtimeName, "permission denied")) - .whenA(!hasPermission) + runtime <- getClusterRecordWithRequiredAction(userInfo, workspaceId, runtimeName, RuntimeAction.DeleteRuntime) diskIdOpt <- RuntimeConfigQueries.getDiskId(runtime.runtimeConfigId).transaction diskId <- diskIdOpt match { @@ -384,15 +333,6 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( as: Ask[F, AppContext] ): F[Unit] = for { - ctx <- as.ask - - workspaceSamId = WorkspaceResourceSamResourceId(workspaceId) - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - workspaceSamId, - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - samResources <- samService.listResources(userInfo.accessToken.token, RuntimeSamResource.resourceType) runtimes <- RuntimeServiceDbQueries .listRuntimes( @@ -425,28 +365,15 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( ): F[Unit] = for { ctx <- as.ask - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - WorkspaceResourceSamResourceId(workspaceId), - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - runtime <- RuntimeServiceDbQueries.getRuntimeByWorkspaceId(workspaceId, runtimeName).transaction - hasResourcePermission <- checkSamPermission( - WsmResourceSamResourceId(WsmControlledResourceId(UUID.fromString(runtime.samResource.resourceId))), + _ <- checkRuntimeAction( userInfo, - WsmResourceAction.Write - ).map(_._1) - - _ <- ctx.span.traverse(s => - F.delay(s.addAnnotation("Done auth call for update date accessed runtime permission")) + workspaceId, + runtimeName, + WsmResourceSamResourceId(WsmControlledResourceId(UUID.fromString(runtime.samResource.resourceId))), + RuntimeAction.ModifyRuntime ) - _ <- F - .raiseError[Unit]( - RuntimeNotFoundException(runtime.cloudContext, runtimeName, "permission denied", Some(ctx.traceId)) - ) - .whenA(!hasResourcePermission) _ <- dateAccessUpdaterQueue.offer( UpdateDateAccessedMessage(UpdateTarget.Runtime(runtimeName), runtime.cloudContext, ctx.now) @@ -458,25 +385,7 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( as: Ask[F, AppContext] ): F[Unit] = for { ctx <- as.ask - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - WorkspaceResourceSamResourceId(workspaceId), - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - - runtime <- RuntimeServiceDbQueries.getActiveRuntimeRecord(workspaceId, runtimeName).transaction - - hasResourcePermission <- checkPermission( - runtime.auditInfo.creator, - userInfo, - WsmResourceSamResourceId(WsmControlledResourceId(UUID.fromString(runtime.internalId))) - ) - - _ <- F - .raiseError[Unit]( - RuntimeNotFoundException(runtime.cloudContext, runtimeName, "permission denied", Some(ctx.traceId)) - ) - .whenA(!hasResourcePermission) + runtime <- getClusterRecordWithRequiredAction(userInfo, workspaceId, runtimeName, RuntimeAction.StopStartRuntime) _ <- if (runtime.status.isStartable) F.unit else @@ -490,25 +399,7 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( ): F[Unit] = for { ctx <- as.ask - hasWorkspacePermission <- authProvider.isUserWorkspaceReader( - WorkspaceResourceSamResourceId(workspaceId), - userInfo - ) - _ <- F.raiseUnless(hasWorkspacePermission)(ForbiddenError(userInfo.userEmail)) - - runtime <- RuntimeServiceDbQueries.getActiveRuntimeRecord(workspaceId, runtimeName).transaction - - hasResourcePermission <- checkPermission( - runtime.auditInfo.creator, - userInfo, - WsmResourceSamResourceId(WsmControlledResourceId(UUID.fromString(runtime.internalId))) - ) - - _ <- F - .raiseError[Unit]( - RuntimeNotFoundException(runtime.cloudContext, runtimeName, "permission denied", Some(ctx.traceId)) - ) - .whenA(!hasResourcePermission) + runtime <- getClusterRecordWithRequiredAction(userInfo, workspaceId, runtimeName, RuntimeAction.StopStartRuntime) _ <- if (runtime.status.isStoppable) F.unit else @@ -595,32 +486,16 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( ) } - private def checkPermission( - creator: WorkbenchEmail, - userInfo: UserInfo, - wsmResourceSamResourceId: WsmResourceSamResourceId - )(implicit ev: Ask[F, AppContext]) = if (creator == userInfo.userEmail) F.pure(true) - else { - for { - ctx <- ev.ask - res <- checkSamPermission(wsmResourceSamResourceId, userInfo, WsmResourceAction.Read).map(_._1) - _ <- ctx.span.traverse(s => F.delay(s.addAnnotation("Done auth call for azure runtime permission check"))) - } yield res - } - - private def checkSamPermission( - wsmResourceSamResourceId: WsmResourceSamResourceId, + private def getClusterRecordWithRequiredAction( userInfo: UserInfo, - wsmResourceAction: WsmResourceAction - )(implicit ctx: Ask[F, AppContext]): F[(Boolean, WsmControlledResourceId)] = + workspaceId: WorkspaceId, + runtimeName: RuntimeName, + action: RuntimeAction + )(implicit as: Ask[F, AppContext]): F[ClusterRecord] = for { - // TODO: generalize for google - res <- authProvider.hasPermission( - wsmResourceSamResourceId, - wsmResourceAction, - userInfo - ) - } yield (res, wsmResourceSamResourceId.controlledResourceId) + runtime <- RuntimeServiceDbQueries.getActiveRuntimeRecord(workspaceId, runtimeName).transaction + _ <- checkRuntimeAction(userInfo, workspaceId, runtimeName, RuntimeSamResourceId(runtime.internalId), action) + } yield runtime private def errorHandler(runtimeId: Long, ctx: AppContext): Throwable => F[Unit] = e => @@ -689,14 +564,6 @@ class RuntimeV2ServiceInterp[F[_]: Parallel]( } -final case class AuthorizedIds( - val ownerGoogleProjectIds: Set[ProjectSamResourceId], - val ownerWorkspaceIds: Set[WorkspaceResourceSamResourceId], - val readerGoogleProjectIds: Set[ProjectSamResourceId], - val readerRuntimeIds: Set[SamResourceId], - val readerWorkspaceIds: Set[WorkspaceResourceSamResourceId] -) - final case class WorkspaceNotFoundException(workspaceId: WorkspaceId, traceId: TraceId) extends LeoException( s"WorkspaceId not found in workspace manager for workspace ${workspaceId}", diff --git a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/api/TestLeoRoutes.scala b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/api/TestLeoRoutes.scala index 6820ee2b83..8031d97abc 100644 --- a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/api/TestLeoRoutes.scala +++ b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/api/TestLeoRoutes.scala @@ -130,7 +130,6 @@ trait TestLeoRoutes { val runtimev2Service = new RuntimeV2ServiceInterp[IO]( serviceConfig, - allowListAuthProvider, QueueFactory.makePublisherQueue(), QueueFactory.makeDateAccessedQueue(), wsmClientProvider, diff --git a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/MockProxyService.scala b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/MockProxyService.scala index 1c0aa84962..fadb8fcec8 100644 --- a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/MockProxyService.scala +++ b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/MockProxyService.scala @@ -11,6 +11,7 @@ import org.broadinstitute.dsde.workbench.leonardo.config.ProxyConfig import org.broadinstitute.dsde.workbench.leonardo.dao.HostStatus.HostReady import org.broadinstitute.dsde.workbench.leonardo.dao.google.GoogleOAuth2Service import org.broadinstitute.dsde.workbench.leonardo.dao._ +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService import org.broadinstitute.dsde.workbench.leonardo.db.DbReference import org.broadinstitute.dsde.workbench.leonardo.dns.{KubernetesDnsCache, RuntimeDnsCache} import org.broadinstitute.dsde.workbench.leonardo.model._ @@ -33,7 +34,8 @@ class MockProxyService( samResourceCache: Cache[IO, SamResourceCacheKey, (Option[String], Option[AppAccessScope])], googleOauth2Service: GoogleOAuth2Service[IO], samDAO: Option[SamDAO[IO]] = None, - queue: Option[Queue[IO, UpdateDateAccessedMessage]] = None + queue: Option[Queue[IO, UpdateDateAccessedMessage]] = None, + samService: SamService[IO] = MockSamService )(implicit system: ActorSystem, executionContext: ExecutionContext, @@ -52,7 +54,8 @@ class MockProxyService( LocalProxyResolver, samDAO.getOrElse(new MockSamDAO()), googleTokenCache, - samResourceCache + samResourceCache, + samService ) { override def getRuntimeTargetHost(cloudContext: CloudContext, clusterName: RuntimeName): IO[HostStatus] = diff --git a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterpSpec.scala b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterpSpec.scala index a2f83948f8..eec9e50bf4 100644 --- a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterpSpec.scala +++ b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeServiceInterpSpec.scala @@ -37,7 +37,7 @@ import org.broadinstitute.dsde.workbench.leonardo.TestUtils.{appContext, default import org.broadinstitute.dsde.workbench.leonardo.auth.AllowlistAuthProvider import org.broadinstitute.dsde.workbench.leonardo.config.Config import org.broadinstitute.dsde.workbench.leonardo.dao.MockDockerDAO -import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService} import org.broadinstitute.dsde.workbench.leonardo.db._ import org.broadinstitute.dsde.workbench.leonardo.http.service.RuntimeServiceInterp.{ calculateAutopauseThreshold, @@ -213,6 +213,16 @@ class RuntimeServiceInterpTest with MockitoSugar { it should "fail if user doesn't have project level permission" in { + val samService = mock[SamService[IO]] + val runtimeService = makeRuntimeService(samService = samService) + when(samService.getUserEmail(isEq(unauthorizedUserInfo.accessToken.token))(any())) + .thenReturn(IO.pure(unauthorizedUserInfo.userEmail)) + when( + samService.checkAuthorized(isEq(unauthorizedUserInfo.accessToken.token), + isEq(ProjectSamResourceId(cloudContextGcp.value)), + isEq(ProjectAction.CreateRuntime) + )(any()) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) val res = for { r <- runtimeService .createRuntime( @@ -841,15 +851,23 @@ class RuntimeServiceInterpTest exc shouldBe a[RuntimeNotFoundException] } - it should "fail to get a runtime when users don't have access to the project" in isolatedDbTest { - val exc = runtimeService - .getRuntime(unauthorizedUserInfo, cloudContextGcp, RuntimeName("cluster")) - .attempt - .unsafeRunSync()(cats.effect.unsafe.IORuntime.global) - .swap - .toOption - .get - exc shouldBe a[ForbiddenError] + it should "throw RuntimeNotFoundException when users don't have permission on the runtime" in isolatedDbTest { + val samService = mock[SamService[IO]] + val runtimeService = makeRuntimeService(samService = samService) + val samResource = RuntimeSamResourceId(UUID.randomUUID.toString) + when( + samService.checkAuthorized(isEq(unauthorizedUserInfo.accessToken.token), + isEq(samResource), + isEq(RuntimeAction.GetRuntimeStatus) + )(any()) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + val res = for { + testRuntime <- IO(makeCluster(1).copy(samResource = samResource).save()) + getResponse <- runtimeService + .getRuntime(unauthorizedUserInfo, testRuntime.cloudContext, testRuntime.runtimeName) + .attempt + } yield getResponse.swap.toOption.get.isInstanceOf[RuntimeNotFoundException] shouldBe true + res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } it should "list runtimes" in isolatedDbTest { @@ -1184,8 +1202,16 @@ class RuntimeServiceInterpTest res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } - it should "fail to delete a runtime if user loses project access" in isolatedDbTest { - val runtimeService = makeRuntimeService(authProvider = allowListAuthProvider2) + it should "fail to delete a runtime if the user doesn't have permission" in isolatedDbTest { + val samService = mock[SamService[IO]] + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) + when(samService.getUserEmail(isEq(userInfo.accessToken.token))(any())).thenReturn(IO.pure(userInfo.userEmail)) + + val runtimeService = makeRuntimeService(samService = samService) val res = for { context <- appContext.ask[AppContext] pd <- makePersistentDisk().save() @@ -1206,14 +1232,52 @@ class RuntimeServiceInterpTest DeleteRuntimeRequest(userInfo, GoogleProject(cloudContextGcp.asString), testRuntime.runtimeName, false) ) .attempt - } yield r shouldBe Left(ForbiddenError(userInfo.userEmail, Some(context.traceId))) + } yield r.swap.toOption.get shouldBe a[ForbiddenError] + + res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) + } + + it should "fail to delete a runtime and not reveal its existence when user has no access to it" in isolatedDbTest { + val samService = mock[SamService[IO]] + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when(samService.getUserEmail(isEq(userInfo.accessToken.token))(any())).thenReturn(IO.pure(userInfo.userEmail)) + + val runtimeService = makeRuntimeService(samService = samService) + val res = for { + context <- appContext.ask[AppContext] + pd <- makePersistentDisk().save() + testRuntime <- IO( + makeCluster(1).saveWithRuntimeConfig( + RuntimeConfig + .GceWithPdConfig( + MachineTypeName("n1-standard-4"), + Some(pd.id), + bootDiskSize = DiskSize(50), + zone = ZoneName("us-central1-a"), + None + ) + ) + ) + r <- runtimeService + .deleteRuntime( + DeleteRuntimeRequest(userInfo, GoogleProject(cloudContextGcp.asString), testRuntime.runtimeName, false) + ) + .attempt + } yield r.swap.toOption.get shouldBe a[RuntimeNotFoundException] res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } it should "delete runtime records, update all status appropriately, and not queue messages" in isolatedDbTest { val publisherQueue = QueueFactory.makePublisherQueue() - val runtimeService = makeRuntimeService(authProvider = allowListAuthProvider, publisherQueue = publisherQueue) + val samService = mock[SamService[IO]] + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.unit) + val runtimeService = makeRuntimeService(publisherQueue = publisherQueue, samService = samService) val res = for { pd <- makePersistentDisk().save() testRuntime <- IO( @@ -1260,9 +1324,11 @@ class RuntimeServiceInterpTest } it should "fail to delete runtime records if user loses project access" in isolatedDbTest { - val runtimeService = makeRuntimeService(authProvider = allowListAuthProvider2) + val samService = mock[SamService[IO]] + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), any())(any())) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + val runtimeService = makeRuntimeService(publisherQueue = publisherQueue, samService = samService) val res = for { - context <- appContext.ask[AppContext] pd <- makePersistentDisk().save() testRuntime <- IO( makeCluster(1).saveWithRuntimeConfig( @@ -1294,9 +1360,7 @@ class RuntimeServiceInterpTest r <- runtimeService .deleteRuntimeRecords(userInfo, cloudContextGcp, listRuntimeResponse2) .attempt - } yield r shouldBe Left( - RuntimeNotFoundException(cloudContextGcp, testRuntime.runtimeName, "Permission Denied", Some(context.traceId)) - ) + } yield r.swap.toOption.get shouldBe a[RuntimeNotFoundException] res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } @@ -1304,31 +1368,15 @@ class RuntimeServiceInterpTest it should "deleteAllRuntimeRecords, update all status appropriately, and not queue messages" in isolatedDbTest { val runtimeIds = Vector(RuntimeSamResourceId(UUID.randomUUID.toString), RuntimeSamResourceId(UUID.randomUUID.toString)) - val mockAuthProvider = mockAuthorize( - userInfo, - readerRuntimeSamIds = Set(runtimeIds(0), runtimeIds(1)), - readerProjectSamIds = Set(ProjectSamResourceId(project)) - ) - when(mockAuthProvider.isUserProjectReader(any, isEq(userInfo))(any)).thenReturn(IO.pure(true)) - when( - mockAuthProvider.getActionsWithProjectFallback[RuntimeSamResourceId, RuntimeAction](any, any, isEq(userInfo))(any, - any - ) - ) - .thenReturn( - IO.pure( - (List(RuntimeAction.GetRuntimeStatus, RuntimeAction.DeleteRuntime), - List(ProjectAction.GetRuntimeStatus, ProjectAction.DeleteRuntime) - ) - ) - ) val samService = mock[SamService[IO]] when(samService.listResources(isEq(userInfo.accessToken.token), isEq(RuntimeSamResource.resourceType))(any())) .thenReturn(IO.pure(runtimeIds.map(_.resourceId).toList)) + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.unit) when(samService.deleteResource(any(), any())(any())).thenReturn(IO.unit) val publisherQueue = QueueFactory.makePublisherQueue() val service = - makeRuntimeService(authProvider = mockAuthProvider, publisherQueue = publisherQueue, samService = samService) + makeRuntimeService(publisherQueue = publisherQueue, samService = samService) val res = for { pd1 <- makePersistentDisk().save() @@ -1378,32 +1426,16 @@ class RuntimeServiceInterpTest it should "deleteAll runtimes" in isolatedDbTest { val runtimeIds = Vector(RuntimeSamResourceId(UUID.randomUUID.toString), RuntimeSamResourceId(UUID.randomUUID.toString)) - val mockAuthProvider = mockAuthorize( - userInfo, - readerRuntimeSamIds = Set(runtimeIds(0), runtimeIds(1)), - readerProjectSamIds = Set(ProjectSamResourceId(project)) - ) - when(mockAuthProvider.isUserProjectReader(any, isEq(userInfo))(any)).thenReturn(IO.pure(true)) - when( - mockAuthProvider.getActionsWithProjectFallback[RuntimeSamResourceId, RuntimeAction](any, any, isEq(userInfo))(any, - any - ) - ) - .thenReturn( - IO.pure( - (List(RuntimeAction.GetRuntimeStatus, RuntimeAction.DeleteRuntime), - List(ProjectAction.GetRuntimeStatus, ProjectAction.DeleteRuntime) - ) - ) - ) + val samService = mock[SamService[IO]] when(samService.listResources(isEq(userInfo.accessToken.token), isEq(RuntimeSamResource.resourceType))(any())) .thenReturn(IO.pure(runtimeIds.map(_.resourceId).toList)) when(samService.getUserEmail(isEq(userInfo.accessToken.token))(any())).thenReturn(IO.pure(userInfo.userEmail)) + when(samService.checkAuthorized(any(), any(), any())(any())).thenReturn(IO.unit) when(samService.deleteResource(any(), any())(any())).thenReturn(IO.unit) val publisherQueue = QueueFactory.makePublisherQueue() val service = - makeRuntimeService(authProvider = mockAuthProvider, publisherQueue = publisherQueue, samService = samService) + makeRuntimeService(publisherQueue = publisherQueue, samService = samService) val res = for { pd1 <- makePersistentDisk().save() @@ -1581,6 +1613,40 @@ class RuntimeServiceInterpTest res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } + it should "fail to stop a runtime if the user doesn't have permission" in isolatedDbTest { + val samService = mock[SamService[IO]] + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.StopStartRuntime))(any()) + ) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) + + val runtimeService = makeRuntimeService(samService = samService) + val res = for { + context <- appContext.ask[AppContext] + pd <- makePersistentDisk().save() + testRuntime <- IO( + makeCluster(1).saveWithRuntimeConfig( + RuntimeConfig + .GceWithPdConfig( + MachineTypeName("n1-standard-4"), + Some(pd.id), + bootDiskSize = DiskSize(50), + zone = ZoneName("us-central1-a"), + None + ) + ) + ) + r <- runtimeService + .stopRuntime(userInfo, testRuntime.cloudContext, testRuntime.runtimeName) + .attempt + } yield r.swap.toOption.get shouldBe a[ForbiddenError] + + res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) + } + it should "start a runtime" in isolatedDbTest { val res = for { publisherQueue <- Queue.bounded[IO, LeoPubsubMessage](10) @@ -1605,6 +1671,29 @@ class RuntimeServiceInterpTest res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } + it should "fail to start a runtime if user doesn't have permission" in isolatedDbTest { + val samService = mock[SamService[IO]] + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.StopStartRuntime))(any()) + ) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) + + val runtimeService = makeRuntimeService(samService = samService) + val res = for { + samResource <- IO(RuntimeSamResourceId(UUID.randomUUID.toString)) + testRuntime <- IO(makeCluster(1).copy(samResource = samResource, status = RuntimeStatus.Stopped).save()) + + r <- runtimeService + .startRuntime(userInfo, GoogleProject(testRuntime.cloudContext.asString), testRuntime.runtimeName) + .attempt + } yield r.swap.toOption.get shouldBe a[ForbiddenError] + + res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) + } + it should "update autopause" in isolatedDbTest { val res = for { // remove some existing items in the queue just to be safe diff --git a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterpSpec.scala b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterpSpec.scala index d8b71282fd..5cc5b50244 100644 --- a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterpSpec.scala +++ b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/http/service/RuntimeV2ServiceInterpSpec.scala @@ -2,41 +2,26 @@ package org.broadinstitute.dsde.workbench.leonardo package http package service +import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.headers.OAuth2BearerToken import cats.effect.IO import cats.effect.std.Queue import cats.mtl.Ask import com.azure.resourcemanager.compute.models.VirtualMachineSizeTypes -import io.circe.Decoder import org.broadinstitute.dsde.workbench.azure._ import org.broadinstitute.dsde.workbench.google2.DiskName import org.broadinstitute.dsde.workbench.leonardo.CommonTestData._ -import org.broadinstitute.dsde.workbench.leonardo.JsonCodec.{ - projectSamResourceDecoder, - runtimeSamResourceDecoder, - workspaceSamResourceIdDecoder, - wsmResourceSamResourceIdDecoder -} import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{ - ProjectSamResourceId, RuntimeSamResourceId, WorkspaceResourceSamResourceId, WsmResourceSamResourceId } -import org.broadinstitute.dsde.workbench.leonardo.TestUtils.{appContext, defaultMockitoAnswer} -import org.broadinstitute.dsde.workbench.leonardo.auth.AllowlistAuthProvider +import org.broadinstitute.dsde.workbench.leonardo.TestUtils.appContext import org.broadinstitute.dsde.workbench.leonardo.config.Config import org.broadinstitute.dsde.workbench.leonardo.dao._ -import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService} import org.broadinstitute.dsde.workbench.leonardo.db._ import org.broadinstitute.dsde.workbench.leonardo.model.SamResource.RuntimeSamResource -import org.broadinstitute.dsde.workbench.leonardo.model.SamResourceAction.{ - projectSamResourceAction, - runtimeSamResourceAction, - workspaceSamResourceAction, - wsmResourceSamResourceAction, - AppSamResourceAction -} import org.broadinstitute.dsde.workbench.leonardo.model._ import org.broadinstitute.dsde.workbench.leonardo.monitor.LeoPubsubMessage.{ CreateAzureRuntimeMessage, @@ -48,10 +33,9 @@ import org.broadinstitute.dsde.workbench.leonardo.monitor.{LeoPubsubMessage, Upd import org.broadinstitute.dsde.workbench.leonardo.util.QueueFactory import org.broadinstitute.dsde.workbench.model.google.GoogleProject import org.broadinstitute.dsde.workbench.model.{TraceId, UserInfo, WorkbenchEmail, WorkbenchUserId} -import org.mockito.ArgumentMatchers.{any, argThat, eq => isEq} +import org.mockito.ArgumentMatchers.{any, eq => isEq} import org.mockito.Mockito.when import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatestplus.mockito.MockitoSugar import org.typelevel.log4cats.StructuredLogger @@ -74,12 +58,11 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with // used when we care about queue state def makeInterp( queue: Queue[IO, LeoPubsubMessage] = QueueFactory.makePublisherQueue(), - authProvider: AllowlistAuthProvider = allowListAuthProvider, dateAccessedQueue: Queue[IO, UpdateDateAccessedMessage] = QueueFactory.makeDateAccessedQueue(), wsmClientProvider: WsmApiClientProvider[IO] = wsmClientProvider, samService: SamService[IO] = MockSamService ) = - new RuntimeV2ServiceInterp[IO](serviceConfig, authProvider, queue, dateAccessedQueue, wsmClientProvider, samService) + new RuntimeV2ServiceInterp[IO](serviceConfig, queue, dateAccessedQueue, wsmClientProvider, samService) // need to set previous runtime to deleted status before creating next to avoid exception def setRuntimeDeleted(workspaceId: WorkspaceId, name: RuntimeName): IO[Long] = @@ -94,7 +77,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with .transaction } yield runtime.id - def mockSamForCreateRuntimeTest(userInfo: UserInfo): SamService[IO] = { + def mockSamForCreateRuntime(userInfo: UserInfo): SamService[IO] = { val samService = mock[SamService[IO]] when(samService.checkAuthorized(any(), any(), any())(any())).thenReturn(IO.unit) when(samService.getUserEmail(userInfo.accessToken.token)).thenReturn(IO.pure(userInfo.userEmail)) @@ -104,174 +87,12 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with samService } - /** - * Generate a mocked AuthProvider which will permit action on the given resource IDs by the given user. - * TODO: cover actions beside `checkUserEnabled` and `listResourceIds` - * @param userInfo - * @param readerRuntimeSamIds - * @param readerWorkspaceSamIds - * @param readerProjectSamIds - * @param ownerWorkspaceSamIds - * @param ownerProjectSamIds - * @return - */ - def mockAuthorize( - userInfo: UserInfo, - readerRuntimeSamIds: Set[RuntimeSamResourceId] = Set.empty, - readerWsmSamIds: Set[WsmResourceSamResourceId] = Set.empty, - readerWorkspaceSamIds: Set[WorkspaceResourceSamResourceId] = Set.empty, - readerProjectSamIds: Set[ProjectSamResourceId] = Set.empty, - ownerWorkspaceSamIds: Set[WorkspaceResourceSamResourceId] = Set.empty, - ownerProjectSamIds: Set[ProjectSamResourceId] = Set.empty - ): AllowlistAuthProvider = { - val mockAuthProvider: AllowlistAuthProvider = mock[AllowlistAuthProvider](defaultMockitoAnswer[IO]) - - when(mockAuthProvider.checkUserEnabled(isEq(userInfo))(any)).thenReturn(IO.unit) - when( - mockAuthProvider.listResourceIds[RuntimeSamResourceId](isEq(true), isEq(userInfo))( - any(runtimeSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[RuntimeSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(readerRuntimeSamIds)) - when( - mockAuthProvider.listResourceIds[WsmResourceSamResourceId](isEq(false), isEq(userInfo))( - any(wsmResourceSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[WsmResourceSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(readerWsmSamIds)) - when( - mockAuthProvider.listResourceIds[WorkspaceResourceSamResourceId](isEq(false), isEq(userInfo))( - any(workspaceSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[WorkspaceResourceSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(readerWorkspaceSamIds)) - when( - mockAuthProvider.listResourceIds[ProjectSamResourceId](isEq(false), isEq(userInfo))( - any(projectSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[ProjectSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ) - .thenReturn(IO.pure(readerProjectSamIds)) - when( - mockAuthProvider.listResourceIds[WorkspaceResourceSamResourceId](isEq(true), isEq(userInfo))( - any(workspaceSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[WorkspaceResourceSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ) - .thenReturn(IO.pure(ownerWorkspaceSamIds)) - when( - mockAuthProvider.listResourceIds[ProjectSamResourceId](isEq(true), isEq(userInfo))( - any(projectSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[ProjectSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ) - .thenReturn(IO.pure(ownerProjectSamIds)) - - mockAuthProvider - } - - /** - * Generate a mocked AuthProvider which will permit action on the given resource IDs by the given user, - * when the list request is restricted to one workspace. Expects isUserWorkspace* instead of listResourceIds. - * TODO: cover actions beside `checkUserEnabled` and `listResourceIds` - * - * @param userInfo - * @param readerRuntimeSamIds - * @param readerWorkspaceSamIds - * @param readerProjectSamIds - * @param ownerWorkspaceSamIds - * @param ownerProjectSamIds - * @return - */ - def mockAuthorizeForOneWorkspace( - userInfo: UserInfo, - readerRuntimeSamIds: Set[RuntimeSamResourceId] = Set.empty, - readerWsmSamIds: Set[WsmResourceSamResourceId] = Set.empty, - readerWorkspaceSamIds: Set[WorkspaceResourceSamResourceId] = Set.empty, - readerProjectSamIds: Set[ProjectSamResourceId] = Set.empty, - ownerWorkspaceSamIds: Set[WorkspaceResourceSamResourceId] = Set.empty, - ownerProjectSamIds: Set[ProjectSamResourceId] = Set.empty - ): AllowlistAuthProvider = { - val mockAuthProvider: AllowlistAuthProvider = mock[AllowlistAuthProvider](defaultMockitoAnswer[IO]) - - when(mockAuthProvider.checkUserEnabled(isEq(userInfo))(any)).thenReturn(IO.unit) - when( - mockAuthProvider.listResourceIds[RuntimeSamResourceId](isEq(true), isEq(userInfo))( - any(runtimeSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[RuntimeSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(readerRuntimeSamIds)) - when( - mockAuthProvider.listResourceIds[WsmResourceSamResourceId](isEq(false), isEq(userInfo))( - any(wsmResourceSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[WsmResourceSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(readerWsmSamIds)) - when( - mockAuthProvider.isUserWorkspaceReader(any, isEq(userInfo))( - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(false)) - when( - mockAuthProvider.isUserWorkspaceReader(argThat(readerWorkspaceSamIds.contains(_)), isEq(userInfo))( - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(true)) - when( - mockAuthProvider.listResourceIds[ProjectSamResourceId](isEq(false), isEq(userInfo))( - any(projectSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[ProjectSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ) - .thenReturn(IO.pure(readerProjectSamIds)) - when( - mockAuthProvider.isUserWorkspaceOwner(any, isEq(userInfo))( - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(false)) - when( - mockAuthProvider.isUserWorkspaceOwner(argThat(ownerWorkspaceSamIds.contains(_)), isEq(userInfo))( - any(Ask[IO, TraceId].getClass) - ) - ).thenReturn(IO.pure(true)) - when( - mockAuthProvider.listResourceIds[ProjectSamResourceId](isEq(true), isEq(userInfo))( - any(projectSamResourceAction.getClass), - any(AppSamResourceAction.getClass), - any(Decoder[ProjectSamResourceId].getClass), - any(Ask[IO, TraceId].getClass) - ) - ) - .thenReturn(IO.pure(ownerProjectSamIds)) - - mockAuthProvider - } - def mockUserInfo(email: String = userEmail.toString()): UserInfo = UserInfo(OAuth2BearerToken(""), WorkbenchUserId(s"userId-${email}"), WorkbenchEmail(email), 0) val runtimeV2Service = new RuntimeV2ServiceInterp[IO]( serviceConfig, - allowListAuthProvider, QueueFactory.makePublisherQueue(), QueueFactory.makeDateAccessedQueue(), wsmClientProvider, @@ -281,7 +102,6 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val runtimeV2Service2 = new RuntimeV2ServiceInterp[IO]( serviceConfig, - allowListAuthProvider2, QueueFactory.makePublisherQueue(), QueueFactory.makeDateAccessedQueue(), wsmClientProvider, @@ -376,6 +196,15 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mock[SamService[IO]] + when( + samService.checkAuthorized(unauthorizedUserInfo.accessToken.token, + WorkspaceResourceSamResourceId(workspaceId), + WorkspaceAction.Compute + ) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + val runtimeV2Service = makeInterp(samService = samService) + val thrown = the[ForbiddenError] thrownBy { runtimeV2Service .createRuntime(unauthorizedUserInfo, runtimeName, workspaceId, false, defaultCreateAzureRuntimeReq) @@ -514,7 +343,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with } it should "fail to create a runtime with existing disk if disk is attached to non-deleted runtime" in isolatedDbTest { - val samService = mockSamForCreateRuntimeTest(userInfo) + val samService = mockSamForCreateRuntime(userInfo) val runtimeV2Service = makeInterp(samService = samService) val res = for { _ <- runtimeV2Service @@ -542,7 +371,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with } it should "fail to create a runtime if one exists in the workspace" in isolatedDbTest { - val samService = mockSamForCreateRuntimeTest(userInfo) + val samService = mockSamForCreateRuntime(userInfo) val runtimeV2Service = makeInterp(samService = samService) runtimeV2Service @@ -678,9 +507,13 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with ) // this email is allowlisted val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mockSamForCreateRuntime(userInfo) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue) + val azureService = makeInterp(publisherQueue, samService = samService) val res = for { _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with @@ -710,152 +543,32 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } - it should "get a runtime when caller is creator" in isolatedDbTest { - - val userInfo = UserInfo( - OAuth2BearerToken(""), - WorkbenchUserId("userId"), - WorkbenchEmail("user1@example.com"), - 0 - ) // this email is allowlisted - - implicit val mockSamResourceAction: SamResourceAction[WsmResourceSamResourceId, WsmResourceAction] = - mock[SamResourceAction[WsmResourceSamResourceId, WsmResourceAction]] - - // test: user does not have access permission for this resource (but they are the creator) - val mockAuthProvider = mock[AllowlistAuthProvider](defaultMockitoAnswer[IO]) - // User passes isUserWorkspaceReader - when(mockAuthProvider.isUserWorkspaceReader(any, any)(any)).thenReturn(IO.pure(true)) - // Calls to a method on a mock which is not stubbed explicitly will return null; - // the user cannot pass mockAuthProvider.hasPermission unless we stub it - - val runtimeName = RuntimeName("clusterName1") - val workspaceId = WorkspaceId(UUID.randomUUID()) - - val publisherQueue = QueueFactory.makePublisherQueue() - - val setupAzureService = makeInterp(publisherQueue) - val testAzureService = makeInterp(publisherQueue, mockAuthProvider) - - val res = for { - _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with - - _ <- setupAzureService - .createRuntime( - userInfo, - runtimeName, - workspaceId, - false, - defaultCreateAzureRuntimeReq - ) - azureCloudContext <- wsmClientProvider.getWorkspace("token", workspaceId).map(_.get.azureContext) - clusterOpt <- clusterQuery - .getActiveClusterByNameMinimal(CloudContext.Azure(azureCloudContext.get), runtimeName)( - scala.concurrent.ExecutionContext.global - ) - .transaction - cluster = clusterOpt.get - getResponse <- testAzureService.getRuntime(userInfo, runtimeName, workspaceId) - } yield { - getResponse.clusterName shouldBe runtimeName - getResponse.auditInfo.creator shouldBe userInfo.userEmail - - } - - res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) - } - - it should "fail to get a runtime when caller has workspace permission, lacks resource permissions, and is not creator" in isolatedDbTest { - val userInfoNoncreator = - UserInfo(OAuth2BearerToken(""), WorkbenchUserId("anotherUserId"), WorkbenchEmail("another_user@example.com"), 0) - - implicit val mockSamResourceAction: SamResourceAction[WsmResourceSamResourceId, WsmResourceAction] = - mock[SamResourceAction[WsmResourceSamResourceId, WsmResourceAction]] - - // test: user does not have access permission for this resource (and they are not the creator) - val mockAuthProvider = mock[AllowlistAuthProvider](defaultMockitoAnswer[IO]) - // User passes isUserWorkspaceReader - when(mockAuthProvider.isUserWorkspaceReader(any(), any())(any())).thenReturn(IO.pure(true)) - when(mockAuthProvider.hasPermission(any(), any(), any())(any(), any())).thenReturn(IO.pure(false)) - - val runtimeName = RuntimeName("clusterName1") - val workspaceId = WorkspaceId(UUID.randomUUID()) - - val publisherQueue = QueueFactory.makePublisherQueue() - - val setupAzureService = makeInterp(publisherQueue) - val testAzureService = makeInterp(publisherQueue, mockAuthProvider) - - val res = for { - _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with - - _ <- setupAzureService - .createRuntime( - userInfo, - runtimeName, - workspaceId, - false, - defaultCreateAzureRuntimeReq - ) - azureCloudContext <- wsmClientProvider.getWorkspace("token", workspaceId).map(_.get.azureContext) - clusterOpt <- clusterQuery - .getActiveClusterByNameMinimal(CloudContext.Azure(azureCloudContext.get), runtimeName)( - scala.concurrent.ExecutionContext.global - ) - .transaction - cluster = clusterOpt.get - _ <- testAzureService.getRuntime(userInfoNoncreator, runtimeName, workspaceId) - } yield () - - the[RuntimeNotFoundException] thrownBy { - res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) - } + it should "fail to get a non-existent runtime" in isolatedDbTest { + runtimeV2Service + .getRuntime(userInfo, RuntimeName("non-existent"), workspaceId) + .attempt + .unsafeRunSync()(cats.effect.unsafe.IORuntime.global) + .swap + .toOption + .get shouldBe a[RuntimeNotFoundByWorkspaceIdException] } - it should "fail to get a runtime when caller has no permission" in isolatedDbTest { + it should "fail to get a runtime when caller lacks permission" in isolatedDbTest { val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mockSamForCreateRuntime(userInfo) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue) - - val res = for { - _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with - - _ <- azureService - .createRuntime( - userInfo, - runtimeName, - workspaceId, - false, - defaultCreateAzureRuntimeReq - ) - azureCloudContext <- wsmClientProvider.getWorkspace("token", workspaceId).map(_.get.azureContext) - clusterOpt <- clusterQuery - .getActiveClusterByNameMinimal(CloudContext.Azure(azureCloudContext.get), runtimeName)( - scala.concurrent.ExecutionContext.global - ) - .transaction - _ <- azureService.getRuntime(unauthorizedUserInfo, runtimeName, workspaceId) - } yield () - the[ForbiddenError] thrownBy { - res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) - } - } - - it should "fail to get a runtime if the creator loses access to workspace" in isolatedDbTest { - val runtimeName = RuntimeName("clusterName1") - val workspaceId = WorkspaceId(UUID.randomUUID()) - - val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue, allowListAuthProvider) - val azureService2 = makeInterp(publisherQueue, allowListAuthProvider2) + val testAzureService = makeInterp(publisherQueue, samService = samService) val res = for { _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with - _ <- azureService + _ <- testAzureService .createRuntime( userInfo, runtimeName, @@ -863,16 +576,10 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with false, defaultCreateAzureRuntimeReq ) - azureCloudContext <- wsmClientProvider.getWorkspace("token", workspaceId).map(_.get.azureContext) - clusterOpt <- clusterQuery - .getActiveClusterByNameMinimal(CloudContext.Azure(azureCloudContext.get), runtimeName)( - scala.concurrent.ExecutionContext.global - ) - .transaction - _ <- azureService2.getRuntime(userInfo, runtimeName, workspaceId) + _ <- testAzureService.getRuntime(userInfo, runtimeName, workspaceId) } yield () - the[ForbiddenError] thrownBy { + the[RuntimeNotFoundByWorkspaceIdException] thrownBy { res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } } @@ -905,6 +612,15 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with // User is runtime creator, but does not have access to the workspace val userInfo = UserInfo(OAuth2BearerToken(""), WorkbenchUserId("user"), WorkbenchEmail("email"), 0) val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mock[SamService[IO]] + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.StopStartRuntime))(any()) + ) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) + val interp = makeInterp(samService = samService) val res = for { runtime <- IO( @@ -916,7 +632,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with ) .save() ) - r <- runtimeV2Service + r <- interp .startRuntime(userInfo, runtime.runtimeName, runtime.workspaceId.get) .attempt } yield { @@ -990,6 +706,15 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with it should "fail to stop a runtime if permission denied" in isolatedDbTest { val userInfo = UserInfo(OAuth2BearerToken(""), WorkbenchUserId("user"), WorkbenchEmail("email"), 0) val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mock[SamService[IO]] + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.StopStartRuntime))(any()) + ) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) + val interp = makeInterp(samService = samService) val res = for { runtime <- IO( @@ -1001,7 +726,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with ) .save() ) - r <- runtimeV2Service + r <- interp .stopRuntime(userInfo, runtime.runtimeName, runtime.workspaceId.get) .attempt } yield { @@ -1442,12 +1167,18 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } - it should "fail to delete a runtime when caller has no permission" in isolatedDbTest { + it should "fail to delete a runtime when caller is missing delete permission" in isolatedDbTest { val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mockSamForCreateRuntime(userInfo) + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.unit) val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue) + val azureService = makeInterp(publisherQueue, samService = samService) val res = for { _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with @@ -1470,7 +1201,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with cluster = clusterOpt.get now <- IO.realTimeInstant _ <- clusterQuery.updateClusterStatus(cluster.id, RuntimeStatus.Running, now).transaction - _ <- azureService.deleteRuntime(unauthorizedUserInfo, runtimeName, workspaceId, true) + _ <- azureService.deleteRuntime(userInfo, runtimeName, workspaceId, true) } yield () the[ForbiddenError] thrownBy { @@ -1478,13 +1209,18 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with } } - it should "fail to delete a runtime when creator has lost workspace permission" in isolatedDbTest { + it should "fail to delete a runtime and not reveal its existence when user has no access to it" in isolatedDbTest { val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mockSamForCreateRuntime(userInfo) + when(samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.DeleteRuntime))(any())) + .thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.GetRuntimeStatus))(any()) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue, allowListAuthProvider) - val azureService2 = makeInterp(publisherQueue, allowListAuthProvider2) + val azureService = makeInterp(publisherQueue, samService = samService) val res = for { _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with @@ -1507,10 +1243,10 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with cluster = clusterOpt.get now <- IO.realTimeInstant _ <- clusterQuery.updateClusterStatus(cluster.id, RuntimeStatus.Running, now).transaction - _ <- azureService2.deleteRuntime(userInfo, runtimeName, workspaceId, true) + _ <- azureService.deleteRuntime(userInfo, runtimeName, workspaceId, true) } yield () - the[ForbiddenError] thrownBy { + the[RuntimeNotFoundByWorkspaceIdException] thrownBy { res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } } @@ -1520,7 +1256,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val runtimeName_2 = RuntimeName("clusterName2") val runtimeName_3 = RuntimeName("clusterName3") val workspaceId = WorkspaceId(UUID.randomUUID()) - val samService = mockSamForCreateRuntimeTest(userInfo) + val samService = mockSamForCreateRuntime(userInfo) when(samService.getUserEmail(userInfo2.accessToken.token)).thenReturn(IO.pure(userInfo2.userEmail)) when(samService.getUserEmail(userInfo3.accessToken.token)).thenReturn(IO.pure(userInfo3.userEmail)) @@ -1658,7 +1394,7 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val runtimeName_2 = RuntimeName("clusterName2") val workspaceId = WorkspaceId(UUID.randomUUID()) - val samService = mockSamForCreateRuntimeTest(userInfo) + val samService = mockSamForCreateRuntime(userInfo) when(samService.getUserEmail(userInfo2.accessToken.token)).thenReturn(IO.pure(userInfo2.userEmail)) val publisherQueue = QueueFactory.makePublisherQueue() @@ -1721,18 +1457,10 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val projectIdGcp = cloudContextGcp.asString val workspaceIdAzure = UUID.randomUUID.toString - val mockAuthProvider = mockAuthorize( - userInfo, - Set(RuntimeSamResourceId(runtimeId1), RuntimeSamResourceId(runtimeId2)), - Set.empty, - Set(WorkspaceResourceSamResourceId(WorkspaceId(UUID.fromString(workspaceIdAzure)))), - Set(ProjectSamResourceId(GoogleProject(projectIdGcp))) - ) - val samService = mock[SamService[IO]] when(samService.listResources(isEq(userInfo.accessToken.token), isEq(RuntimeSamResource.resourceType))(any())) .thenReturn(IO.pure(List(runtimeId1, runtimeId2))) - val testService = makeInterp(authProvider = mockAuthProvider, samService = samService) + val testService = makeInterp(samService = samService) val res = for { samResource1 <- IO(RuntimeSamResourceId(runtimeId1)) @@ -2127,59 +1855,36 @@ class RuntimeV2ServiceInterpSpec extends AnyFlatSpec with LeonardoTestSuite with val runtimeName = RuntimeName("clusterName1") val workspaceId = WorkspaceId(UUID.randomUUID()) + val samService = mockSamForCreateRuntime(userInfo) + when( + samService.checkAuthorized(isEq(userInfo.accessToken.token), any(), isEq(RuntimeAction.ModifyRuntime))( + any() + ) + ).thenReturn(IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId("")))) val publisherQueue = QueueFactory.makePublisherQueue() val dateAccessedQueue = QueueFactory.makeDateAccessedQueue() - val azureService = makeInterp(publisherQueue, dateAccessedQueue = dateAccessedQueue) + val azureService = makeInterp(publisherQueue, dateAccessedQueue = dateAccessedQueue, samService = samService) val res = for { _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with _ <- azureService .createRuntime( - unauthorizedUserInfo, // this email is not allowlisted + userInfo, runtimeName, workspaceId, false, defaultCreateAzureRuntimeReq ) - _ <- azureService.updateDateAccessed(unauthorizedUserInfo, workspaceId, runtimeName) + _ <- azureService.updateDateAccessed(userInfo, workspaceId, runtimeName) } yield () val thrown = the[ForbiddenError] thrownBy { res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) } - thrown shouldBe ForbiddenError(unauthorizedEmail) + thrown shouldBe ForbiddenError(userInfo.userEmail) } - - it should "not update date accessed when user has lost access to workspace" in isolatedDbTest { - val runtimeName = RuntimeName("clusterName1") - val workspaceId = WorkspaceId(UUID.randomUUID()) - - val publisherQueue = QueueFactory.makePublisherQueue() - val azureService = makeInterp(publisherQueue) - val azureService2 = makeInterp(publisherQueue, allowListAuthProvider2) - - val res = for { - _ <- publisherQueue.tryTake // just to make sure there's no messages in the queue to start with - - _ <- azureService - .createRuntime( - userInfo, - runtimeName, - workspaceId, - false, - defaultCreateAzureRuntimeReq - ) - azureCloudContext <- wsmClientProvider.getWorkspace("token", workspaceId).map(_.get.azureContext) - _ <- azureService2.updateDateAccessed(userInfo, workspaceId, runtimeName) - } yield () - - the[ForbiddenError] thrownBy { - res.unsafeRunSync()(cats.effect.unsafe.IORuntime.global) - } - } - } object TestContext extends Enumeration { diff --git a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/mocks.scala b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/mocks.scala index 3a101894c7..240176ccdc 100644 --- a/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/mocks.scala +++ b/http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/mocks.scala @@ -1,5 +1,6 @@ package org.broadinstitute.dsde.workbench.leonardo +import akka.http.scaladsl.model.StatusCodes import cats.data.NonEmptyList import cats.effect.IO import cats.mtl.Ask @@ -21,7 +22,7 @@ import org.broadinstitute.dsde.workbench.google2.mock.BaseFakeGoogleStorage import org.broadinstitute.dsde.workbench.google2.{GKEModels, GcsBlobName, GetMetadataResponse, KubernetesModels, PvName} import org.broadinstitute.dsde.workbench.leonardo.CommonTestData._ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{AppSamResourceId, WorkspaceResourceSamResourceId} -import org.broadinstitute.dsde.workbench.leonardo.dao.sam.SamService +import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService} import org.broadinstitute.dsde.workbench.leonardo.model.{LeoAuthProvider, SamResource, SamResourceAction} import org.broadinstitute.dsde.workbench.leonardo.util._ import org.broadinstitute.dsde.workbench.model.google.{GcsBucketName, GoogleProject} @@ -312,7 +313,9 @@ class BaseMockSamService extends SamService[IO] { action: org.broadinstitute.dsde.workbench.leonardo.SamResourceAction )(implicit ev: Ask[IO, AppContext] - ): IO[Unit] = IO.unit + ): IO[Unit] = if (bearerToken.equals(unauthorizedUserInfo.accessToken.token)) + IO.raiseError(SamException.create("no access", StatusCodes.Forbidden.intValue, TraceId(""))) + else IO.unit override def listResources(bearerToken: String, samResourceType: SamResourceType)(implicit ev: Ask[IO, AppContext]