Skip to content

Commit

Permalink
Add SQL string literal escape method (#1017)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvorisek authored Jul 6, 2022
1 parent 836b28c commit f89264b
Show file tree
Hide file tree
Showing 33 changed files with 432 additions and 812 deletions.
15 changes: 0 additions & 15 deletions docs/persistence/sql/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -299,21 +299,6 @@ parts of the query. You must not call them in normal circumstances.
$query->consume('first_name'); // `first_name`
$query->consume($other_query); // will merge parameters and return string

.. php:method:: escape($value)
Creates new expression where $value appears escaped. Use this method as a
conventional means of specifying arguments when you think they might have
a nasty back-ticks or commas in the field names. I generally **discourage**
you from using this method. Example use would be::

$query->field('foo, bar'); // escapes and adds 2 fields to the query
$query->field($query->escape('foo, bar')); // adds field `foo, bar` to the query
$query->field(['foo, bar']); // adds single field `foo, bar`

$query->order('foo desc'); // escapes and add `foo` desc to the query
$query->field($query->escape('foo desc')); // adds field `foo desc` to the query
$query->field(['foo desc']); // adds `foo` desc anyway

.. php:method:: escapeIdentifier($sql_code)
Always surrounds `$sql code` with back-ticks.
Expand Down
3 changes: 2 additions & 1 deletion phpunit.xml.dist
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
</listeners>
<coverage>
<include>
<directory suffix=".php">src</directory>
<directory>src</directory>
<directory>tests</directory>
</include>
<report>
<php outputFile="coverage/phpunit.cov" />
Expand Down
17 changes: 6 additions & 11 deletions src/Persistence/Sql.php
Original file line number Diff line number Diff line change
Expand Up @@ -168,28 +168,23 @@ protected function initPersistence(Model $model): void

/**
* Creates new Expression object from expression string.
*
* @param mixed $expr
*/
public function expr(Model $model, $expr, array $args = []): Expression
public function expr(Model $model, string $template, array $arguments = []): Expression
{
if (!is_string($expr)) {
return $this->getConnection()->expr($expr, $args);
}
preg_replace_callback(
'~\[\w*\]|\{\w*\}~',
function ($matches) use (&$args, $model) {
function ($matches) use ($model, &$arguments) {
$identifier = substr($matches[0], 1, -1);
if ($identifier && !isset($args[$identifier])) {
$args[$identifier] = $model->getField($identifier);
if ($identifier !== '' && !isset($arguments[$identifier])) {
$arguments[$identifier] = $model->getField($identifier);
}

return $matches[0];
},
$expr
$template
);

return $this->getConnection()->expr($expr, $args);
return $this->getConnection()->expr($template, $arguments);
}

/**
Expand Down
3 changes: 1 addition & 2 deletions src/Persistence/Sql/Connection.php
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,8 @@ public function dsql($properties = []): Query
* Returns Expression object with connection already set.
*
* @param string|array $properties
* @param array $arguments
*/
public function expr($properties = [], $arguments = null): Expression
public function expr($properties = [], array $arguments = []): Expression
{
$c = $this->expression_class;
$e = new $c($properties, $arguments);
Expand Down
147 changes: 79 additions & 68 deletions src/Persistence/Sql/Expression.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Platforms\OraclePlatform;
use Doctrine\DBAL\Platforms\PostgreSQLPlatform;
use Doctrine\DBAL\Platforms\SqlitePlatform;
use Doctrine\DBAL\Platforms\SQLServerPlatform;
use Doctrine\DBAL\Result as DbalResult;

Expand Down Expand Up @@ -74,39 +75,18 @@ class Expression implements Expressionable, \ArrayAccess
* If $properties is passed as string, then it's treated as template.
*
* @param string|array $properties
* @param array $arguments
*/
public function __construct($properties = [], $arguments = null)
public function __construct($properties = [], array $arguments = [])
{
// save template
if (is_string($properties)) {
$properties = ['template' => $properties];
} elseif (!is_array($properties)) {
throw (new Exception('Incorrect use of Expression constructor'))
->addMoreInfo('properties', $properties)
->addMoreInfo('arguments', $arguments);
}

// supports passing template as property value without key 'template'
if (isset($properties[0])) {
$properties['template'] = $properties[0];
unset($properties[0]);
}

// save arguments
if ($arguments !== null) {
if (!is_array($arguments)) {
throw (new Exception('Expression arguments must be an array'))
->addMoreInfo('properties', $properties)
->addMoreInfo('arguments', $arguments);
}
$this->args['custom'] = $arguments;
}

// deal with remaining properties
foreach ($properties as $key => $val) {
$this->{$key} = $val;
}

$this->args['custom'] = $arguments;
}

/**
Expand Down Expand Up @@ -170,28 +150,10 @@ public function offsetUnset($offset): void
* new expression to the same connection as the parent.
*
* @param string|array $properties
* @param array $arguments
*
* @return Expression
*/
public function expr($properties = [], $arguments = null)
public function expr($properties = [], array $arguments = []): self
{
if ($this->connection !== null) {
// TODO condition above always satisfied when connection is set - adjust tests,
// so connection is always set and remove the code below
return $this->connection->expr($properties, $arguments);
}

// make a smart guess :) when connection is not set
if ($this instanceof Query) {
$e = new self($properties, $arguments);
} else {
$e = new static($properties, $arguments);
}

$e->identifierEscapeChar = $this->identifierEscapeChar;

return $e;
return $this->connection->expr($properties, $arguments);
}

/**
Expand Down Expand Up @@ -281,21 +243,6 @@ protected function consume($expr, string $escapeMode = self::ESCAPE_PARAM)
return $sql;
}

/**
* Creates new expression where $value appears escaped. Use this
* method as a conventional means of specifying arguments when you
* think they might have a nasty back-ticks or commas in the field
* names.
*
* @param string $value
*
* @return Expression
*/
public function escape($value)
{
return $this->expr('{}', [$value]);
}

/**
* Converts value into parameter and returns reference. Use only during
* query rendering. Consider using `consume()` instead, which will
Expand All @@ -315,7 +262,72 @@ protected function escapeParam($value): string
}

/**
* Escapes argument by adding backticks around it.
* This method should be used only when string value cannot be bound.
*/
protected function escapeStringLiteral(string $value): string
{
$platform = $this->connection->getDatabasePlatform();
if ($platform instanceof PostgreSQLPlatform || $platform instanceof SQLServerPlatform) {
$dummyPersistence = new Persistence\Sql($this->connection);
if (\Closure::bind(fn () => $dummyPersistence->binaryTypeValueIsEncoded($value), null, Persistence\Sql::class)()) {
$value = \Closure::bind(fn () => $dummyPersistence->binaryTypeValueDecode($value), null, Persistence\Sql::class)();

if ($platform instanceof PostgreSQLPlatform) {
return 'decode(\'' . bin2hex($value) . '\', \'hex\')';
}

return 'CONVERT(VARBINARY(MAX), \'' . bin2hex($value) . '\', 2)';
}
}

$parts = [];
foreach (explode("\0", $value) as $i => $v) {
if ($i > 0) {
if ($platform instanceof PostgreSQLPlatform) {
// will raise SQL error, PostgreSQL does not support \0 character
$parts[] = 'convert_from(decode(\'00\', \'hex\'), \'UTF8\')';
} elseif ($platform instanceof SQLServerPlatform) {
$parts[] = 'NCHAR(0)';
} elseif ($platform instanceof OraclePlatform) {
$parts[] = 'CHR(0)';
} else {
$parts[] = 'x\'00\'';
}
}

if ($v !== '') {
$parts[] = '\'' . str_replace('\'', '\'\'', $v) . '\'';
}
}
if ($parts === []) {
$parts = ['\'\''];
}

$buildConcatSqlFx = function (array $parts) use (&$buildConcatSqlFx, $platform): string {
if (count($parts) > 1) {
$partsLeft = array_slice($parts, 0, intdiv(count($parts), 2));
$partsRight = array_slice($parts, count($partsLeft));

$sqlLeft = $buildConcatSqlFx($partsLeft);
if ($platform instanceof SQLServerPlatform && count($partsLeft) === 1) {
$sqlLeft = 'CAST(' . $sqlLeft . ' AS NVARCHAR(MAX))';
}

return ($platform instanceof SqlitePlatform ? '(' : 'CONCAT(')
. $sqlLeft
. ($platform instanceof SqlitePlatform ? ' || ' : ', ')
. $buildConcatSqlFx($partsRight)
. ')';
}

return reset($parts);
};

return $buildConcatSqlFx($parts);
}

/**
* Escapes identifier from argument.
* This will allow you to use reserved SQL words as table or field
* names such as "table" as well as other characters that SQL
* permits in the identifiers (e.g. spaces or equation signs).
Expand Down Expand Up @@ -459,7 +471,7 @@ public function getDebugQuery(): string
} elseif (is_float($val)) {
$replacement = self::castFloatToString($val);
} elseif (is_string($val)) {
$replacement = '\'' . addslashes($val) . '\'';
$replacement = '\'' . str_replace('\'', '\'\'', $val) . '\'';
} else {
continue;
}
Expand Down Expand Up @@ -585,8 +597,7 @@ public function execute(object $connection = null, bool $fromExecuteStatement =
} elseif (is_string($val)) {
$type = ParameterType::STRING;

if ($platform instanceof PostgreSQLPlatform
|| $platform instanceof SQLServerPlatform) {
if ($platform instanceof PostgreSQLPlatform || $platform instanceof SQLServerPlatform) {
$dummyPersistence = new Persistence\Sql($this->connection);
if (\Closure::bind(fn () => $dummyPersistence->binaryTypeValueIsEncoded($val), null, Persistence\Sql::class)()) {
$val = \Closure::bind(fn () => $dummyPersistence->binaryTypeValueDecode($val), null, Persistence\Sql::class)();
Expand Down Expand Up @@ -689,11 +700,11 @@ private function castGetValue($v): ?string
}

// for PostgreSQL/Oracle CLOB/BLOB datatypes and PDO driver
if (is_resource($v) && get_resource_type($v) === 'stream' && (
$this->connection->getDatabasePlatform() instanceof PostgreSQLPlatform
|| $this->connection->getDatabasePlatform() instanceof OraclePlatform
)) {
$v = stream_get_contents($v);
if (is_resource($v) && get_resource_type($v) === 'stream') {
$platform = $this->connection->getDatabasePlatform();
if ($platform instanceof PostgreSQLPlatform || $platform instanceof OraclePlatform) {
$v = stream_get_contents($v);
}
}

return $v; // throw a type error if not null nor string
Expand Down
12 changes: 8 additions & 4 deletions src/Persistence/Sql/Mssql/ExpressionTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@ trait ExpressionTrait
{
protected function escapeIdentifier(string $value): string
{
return preg_replace('~\]([^\[\]\'"(){}]*?\])~s', '[$1', parent::escapeIdentifier($value));
$res = parent::escapeIdentifier($value);

return $this->identifierEscapeChar === ']' && str_starts_with($res, ']') && str_ends_with($res, ']')
? '[' . substr($res, 1)
: $res;
}

public function render(): array
{
[$sql, $params] = parent::render();

// convert all SQL strings to NVARCHAR, eg 'text' to N'text'
$sql = preg_replace_callback('~N?(\'(?:\'\'|\\\\\'|[^\'])*+\')~s', function ($matches) {
return 'N' . $matches[1];
// convert all string literals to NVARCHAR, eg. 'text' to N'text'
$sql = preg_replace_callback('~N?\'(?:\'\'|\\\\\'|[^\'])*+\'~s', function ($matches) {
return (substr($matches[0], 0, 1) === 'N' ? '' : 'N') . $matches[0];
}, $sql);

return [$sql, $params];
Expand Down
2 changes: 1 addition & 1 deletion src/Persistence/Sql/Mssql/Query.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public function _render_limit(): ?string

public function groupConcat($field, string $delimiter = ',')
{
return $this->expr('string_agg({}, \'' . $delimiter . '\')', [$field]);
return $this->expr('string_agg({}, ' . $this->escapeStringLiteral($delimiter) . ')', [$field]);
}

public function exists()
Expand Down
5 changes: 5 additions & 0 deletions src/Persistence/Sql/Mysql/ExpressionTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

trait ExpressionTrait
{
protected function escapeStringLiteral(string $value): string
{
return str_replace('\\', '\\\\', parent::escapeStringLiteral($value));
}

protected function hasNativeNamedParamSupport(): bool
{
$dbalConnection = $this->connection->getConnection();
Expand Down
2 changes: 1 addition & 1 deletion src/Persistence/Sql/Mysql/Query.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ class Query extends BaseQuery

public function groupConcat($field, string $delimiter = ',')
{
return $this->expr('group_concat({} separator \'' . str_replace('\'', '\'\'', $delimiter) . '\')', [$field]);
return $this->expr('group_concat({} separator ' . $this->escapeStringLiteral($delimiter) . ')', [$field]);
}
}
Loading

0 comments on commit f89264b

Please sign in to comment.