Skip to content

Commit

Permalink
[pysrc2cpg] Handle self Property Type Recovery (#2249)
Browse files Browse the repository at this point in the history
* [pysrc2cpg] Handle `self` Property Type Recovery

- Load `self` field assignments into the global table
- Approximate field references by name (TODO: match objects with their type decl)
- Using dummy <indexAccess> for dict access types (TODO: model dict memory during analysis)
- Clearing global table at the end of analysis
  • Loading branch information
DavidBakerEffendi authored Feb 6, 2023
1 parent cf045d4 commit f8fdeb2
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import scala.util.Try
class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {

override def computationalUnit: Traversal[File] = cpg.file

override def generateRecoveryForCompilationUnitTask(
unit: File,
builder: DiffGraphBuilder,
Expand Down Expand Up @@ -140,7 +141,7 @@ class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey
* the graph builder
*/
class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey])
extends RecoverForXCompilationUnit[File](cpg, cu, builder, globalTable) {
extends RecoverForXCompilationUnit[File](cu, builder) {

/** Adds built-in functions to expect.
*/
Expand Down Expand Up @@ -215,7 +216,7 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
val fieldTypes = symbolTable
.get(CallAlias(i.name))
.flatMap(recModule => globalTable.get(FieldVar(recModule, f.canonicalName)))
symbolTable.append(assigned, fieldTypes)
if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes)
case List(assigned: Identifier, i: Identifier, f: FieldIdentifier)
if symbolTable
.contains(LocalVar(i.name)) =>
Expand Down Expand Up @@ -245,12 +246,46 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
val callAlias = CallAlias(s"${i.name}.${f.canonicalName}")
val importedTypes = symbolTable.get(callAlias)
setIdentifierFromFunctionType(assigned, callAlias.identifier, callAlias.identifier, importedTypes)
case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) =>
// TODO: This is really tricky to find without proper object tracking, so we match name only
val fieldTypes = globalTable.view.filter(_._1.identifier.equals(f.canonicalName)).flatMap(_._2).toSet
if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes)
case _ =>
}
// Field load from call
case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && symbolTable.contains(c) =>
(fl.astChildren.l, c.astChildren.l) match {
case (List(self: Identifier, fieldIdentifier: FieldIdentifier), args: List[_]) =>
symbolTable.append(fieldIdentifier, symbolTable.get(c))
globalTable.append(fieldVarName(fieldIdentifier), symbolTable.get(c))
case _ =>
}
// Field load from index access
case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && c.name.equals(Operators.indexAccess) =>
(fl.astChildren.l, c.astChildren.l) match {
case (List(self: Identifier, fieldIdentifier: FieldIdentifier), ::(rhsFAccess: Call, _))
if rhsFAccess.name.equals(Operators.fieldAccess) =>
val rhsField = rhsFAccess.fieldAccess.fieldIdentifier.head
// TODO: Check if a type for the RHS index access is recovered
val types = symbolTable.get(rhsField).map(t => s"$t.<indexAccess>")
symbolTable.append(fieldIdentifier, types)
globalTable.append(fieldVarName(fieldIdentifier), types)
case _ =>
}
case _ =>
}
}

/** Builds the global-table key under which type information for a field is stored.
  *
  * The owner part of the key depends on whether any AST sibling of the field identifier has code containing the text
  * "self": if so, the full name of the type declaration of the enclosing method is used, otherwise the full name of
  * the first method found in the enclosing file.
  *
  * @param f
  *   the field identifier of a field access.
  * @return
  *   a [[FieldVar]] pairing the owner's full name with the field's canonical name.
  */
private def fieldVarName(f: FieldIdentifier): FieldVar = {
  // Plain substring check on the sibling code; any sibling whose code merely contains "self" qualifies.
  val accessedViaSelf = f.astSiblings.map(_.code).exists(_.contains("self"))
  // NOTE(review): both lookups end in `.head` with no visible guard for an empty result - confirm one always exists.
  val ownerFullName =
    if (accessedViaSelf) f.method.typeDecl.fullName.head // intended to match the <meta> type decl
    else f.file.method.fullName.head                     // will typically match the <module>
  FieldVar(ownerFullName, f.canonicalName)
}

private def setIdentifierFromFunctionType(
i: Identifier,
callName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
| 'age': self.age,
| 'address': self.address
| }
|""".stripMargin)
|""".stripMargin).cpg

"resolve 'db' identifier types from import information" in {
val List(clientAssignment, clientElseWhere) = cpg.identifier("db").take(2).l
Expand Down Expand Up @@ -146,7 +146,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
|from foo import abs
|
|x = abs(-1)
|""".stripMargin)
|""".stripMargin).cpg

"resolve 'print' and 'max' calls" in {
val Some(printCall) = cpg.call("print").headOption
Expand Down Expand Up @@ -186,7 +186,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
|foo.db.deleteTable()
|""".stripMargin,
"bar.py"
)
).cpg

"resolve 'x' and 'y' locally under foo.py" in {
val Some(x) = cpg.file.name(".*foo.*").ast.isIdentifier.name("x").headOption
Expand Down Expand Up @@ -268,7 +268,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
|db = SQLAlchemy()
|""".stripMargin,
"app.py"
)
).cpg

"be determined as a variable reference and have its type recovered correctly" in {
cpg.identifier("db").map(_.typeFullName).toSet shouldBe Set("flask_sqlalchemy.py:<module>.SQLAlchemy")
Expand Down Expand Up @@ -300,7 +300,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
|import logging
|log = logging.getLogger(__name__)
|log.error("foo")
|""".stripMargin)
|""".stripMargin).cpg

"provide a dummy type" in {
val Some(log) = cpg.identifier("log").headOption
Expand All @@ -318,7 +318,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
|import urllib.request
|
|req = urllib.request.Request(url=apiUrl, data=dataBytes, method='POST')
|""".stripMargin)
|""".stripMargin).cpg

"reasonably determine the constructor type" in {
val Some(tmp0) = cpg.identifier("tmp0").headOption
Expand All @@ -328,4 +328,50 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
}
}

"a method call inherited from a super class should be recovered" should {
lazy val cpg = code(
"""from pymongo import MongoClient
|from django.conf import settings
|
|
|class MongoConnection(object):
| def __init__(self):
| DATABASES = settings.DATABASES
| self.client = MongoClient(
| host=[DATABASES['MONGO']['HOST']],
| username=DATABASES['MONGO']['USERNAME'],
| password=DATABASES['MONGO']['PASSWORD'],
| authSource=DATABASES['MONGO']['AUTH_DATABASE']
| )
| self.db = self.client[DATABASES['MONGO']['DATABASE']]
| self.collection = None
|
| def get_collection(self, name):
| self.collection = self.db[name]
|""".stripMargin,
"MongoConnection.py"
).moreCode(
"""
|from MongoConnection import MongoConnection
|
|class InstallationsDAO(MongoConnection):
| def __init__(self):
| super(InstallationsDAO, self).__init__()
| self.get_collection("installations")
|
| def getCustomerId(self, installationId):
| res = self.collection.find_one({'_id': installationId})
| if res is None:
| return None
| return dict(res).get("customerId", None)
|""".stripMargin,
"InstallationDao.py"
).cpg

"recover a potential type for `self.collection` using the assignment at `get_collection` as a type hint" in {
val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption
selfFindFound.methodFullName shouldBe "pymongo.py:<module>.MongoClient.<init>.<indexAccess>.<indexAccess>.find_one"
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ object SBKey {
protected val logger: Logger = LoggerFactory.getLogger(getClass)
def fromNodeToLocalKey(node: AstNode): LocalKey = {
node match {
case n: Identifier => LocalVar(n.name)
case n: Local => LocalVar(n.name)
case n: Call => CallAlias(n.name)
case n: Method => CallAlias(n.name)
case n: MethodRef => CallAlias(n.code)
case n: Identifier => LocalVar(n.name)
case n: Local => LocalVar(n.name)
case n: Call => CallAlias(n.name)
case n: Method => CallAlias(n.name)
case n: MethodRef => CallAlias(n.code)
case n: FieldIdentifier => LocalVar(n.canonicalName)
case _ =>
throw new RuntimeException(s"Local node of type ${node.label} is not supported in the type recovery pass.")
}
Expand Down Expand Up @@ -112,7 +113,7 @@ class SymbolTable[K <: SBKey](fromNode: AstNode => K) {
def append(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] =
append(fromNode(node), typeFullNames)

private def append(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = {
def append(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = {
table.get(sbKey) match {
case Some(ts) => table.put(sbKey, ts ++ typeFullNames)
case None => table.put(sbKey, typeFullNames)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.joern.x2cpg.passes.frontend

import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.{Method, _}
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames}
import io.shiftleft.passes.CpgPass
import io.shiftleft.semanticcpg.language._
Expand Down Expand Up @@ -49,11 +49,14 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations:
*/
protected val globalTable = new SymbolTable[GlobalKey](SBKey.fromNodeToGlobalKey)

override def run(builder: DiffGraphBuilder): Unit =
override def run(builder: DiffGraphBuilder): Unit = try {
for (_ <- 0 to iterations)
computationalUnit
.map(unit => generateRecoveryForCompilationUnitTask(unit, builder, globalTable).fork())
.foreach(_.get())
} finally {
globalTable.clear()
}

/** @return
* the computational units as per how the language is compiled. e.g. file.
Expand Down Expand Up @@ -150,10 +153,8 @@ abstract class SetXProcedureDefTask(node: CfgNode) extends RecursiveTask[Unit] {
* the [[AstNode]] type used to represent a computational unit of the language.
*/
abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
cpg: Cpg,
cu: ComputationalUnit,
builder: DiffGraphBuilder,
globalTable: SymbolTable[GlobalKey]
builder: DiffGraphBuilder
) extends RecursiveTask[Unit] {

/** Stores type information for local structures that live within this compilation unit, e.g. local variables.
Expand Down

0 comments on commit f8fdeb2

Please sign in to comment.