Skip to content

Commit

Permalink
Add Knn Search support (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
hkulekci authored Sep 26, 2024
1 parent c57c0fe commit d61ea65
Show file tree
Hide file tree
Showing 12 changed files with 571 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:

services:
elasticsearch:
image: elasticsearch:8.0.0
image: elasticsearch:8.4.0
ports:
- 9200:9200
env:
Expand Down
222 changes: 222 additions & 0 deletions src/Knn/Knn.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
<?php

/*
* This file is part of the ONGR package.
*
* (c) NFQ Technologies UAB <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace ONGR\ElasticsearchDSL\Knn;

use ONGR\ElasticsearchDSL\BuilderInterface;
use ONGR\ElasticsearchDSL\FieldAwareTrait;

class Knn implements BuilderInterface
{
use FieldAwareTrait;

/**
* @var string
*/
private $field;

/**
* @var array
*/
private $queryVector;

/**
* @var int
*/
private $k;

/**
* @var int
*/
private $numCandidates;

/**
* @var float|null
*/
private $boost;

/**
* @var float
*/
private $similarity = null;

/**
* @var BuilderInterface
*/
private $filter = null;


/**
* TermSuggest constructor.
* @param string $field
* @param array $queryVector
* @param int $k
* @param int $numCandidates
*/
public function __construct(
string $field,
array $queryVector,
int $k,
int $numCandidates
) {
$this->setField($field);
$this->setQueryVector($queryVector);
$this->setK($k);
$this->setNumCandidates($numCandidates);
}

/**
* @return string
*/
public function getField(): string
{
return $this->field;
}

/**
* @param string $field
*/
public function setField(string $field): void
{
$this->field = $field;
}

/**
* @return array
*/
public function getQueryVector(): array
{
return $this->queryVector;
}

/**
* @param array $queryVector
*/
public function setQueryVector(array $queryVector): void
{
$this->queryVector = $queryVector;
}

/**
* @return int
*/
public function getK(): int
{
return $this->k;
}

/**
* @param int $k
*/
public function setK(int $k): void
{
$this->k = $k;
}

/**
* @return int
*/
public function getNumCandidates(): int
{
return $this->numCandidates;
}

/**
* @param int $numCandidates
*/
public function setNumCandidates(int $numCandidates): void
{
$this->numCandidates = $numCandidates;
}

/**
* @return float|null
*/
public function getSimilarity(): ?float
{
return $this->similarity;
}

/**
* @param float $similarity
*/
public function setSimilarity(float $similarity): void
{
$this->similarity = $similarity;
}

/**
* @return float|null
*/
public function getBoost(): ?float
{
return $this->boost;
}

/**
* @param float $boost
*/
public function setBoost(float $boost): void
{
$this->boost = $boost;
}

/**
* @return BuilderInterface|null
*/
public function getFilter(): ?BuilderInterface
{
return $this->filter;
}

/**
* @param BuilderInterface $filter
*/
public function setFilter(BuilderInterface $filter): void
{
$this->filter = $filter;
}

/**
* {@inheritdoc}
*/
public function getType()
{
return 'knn';
}

/**
* {@inheritdoc}
*/
public function toArray()
{
$output = [
'field' => $this->getField(),
'query_vector' => $this->getQueryVector(),
'k' => $this->getK(),
'num_candidates' => $this->getNumCandidates(),
];

if ($this->getSimilarity()) {
$output['similarity'] = $this->getSimilarity();
}

if ($this->getBoost()) {
$output['boost'] = $this->getBoost();
}

if ($this->getFilter()) {
$output['filter'] = $this->getFilter()->toArray();
}

return $output;
}
}
17 changes: 17 additions & 0 deletions src/Search.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\SearchEndpointFactory;
Expand Down Expand Up @@ -229,6 +230,22 @@ public function addQuery(BuilderInterface $query, $boolType = BoolQuery::MUST, $
return $this;
}


/**
* Adds Knn to the search.
*
* @param BuilderInterface $query
*
* @return $this
*/
public function addKnn(BuilderInterface $query)
{
$endpoint = $this->getEndpoint(KnnEndpoint::NAME);
$endpoint->add($query);

return $this;
}

/**
* Returns endpoint instance.
*
Expand Down
63 changes: 63 additions & 0 deletions src/SearchEndpoint/KnnEndpoint.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
<?php

/*
* This file is part of the ONGR package.
*
* (c) NFQ Technologies UAB <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace ONGR\ElasticsearchDSL\SearchEndpoint;

use ONGR\ElasticsearchDSL\BuilderInterface;
use ONGR\ElasticsearchDSL\Knn\Knn;
use Symfony\Component\Serializer\Normalizer\NormalizerInterface;

/**
* Search suggest dsl endpoint.
*/
class KnnEndpoint extends AbstractSearchEndpoint
{
/**
* Endpoint name
*/
const NAME = 'knn';

public function add(BuilderInterface $builder, $key = null)
{
if ($builder instanceof Knn) {
return parent::add($builder, $key);
}

throw new \LogicException('Add Knn builder instead!');
}

/**
* {@inheritdoc}
*/
public function normalize(
NormalizerInterface $normalizer,
$format = null,
array $context = []
): array|string|int|float|bool {
$knns = $this->getAll();
if (count($knns) === 1) {
/** @var Knn $knn */
$knn = array_values($knns)[0];
return $knn->toArray();
}

if (count($knns) > 1) {
$output = [];
/** @var Knn $knn */
foreach ($knns as $knn) {
$output[] = $knn->toArray();
}
return $output;
}

return [];
}
}
15 changes: 8 additions & 7 deletions src/SearchEndpoint/SearchEndpointFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class SearchEndpointFactory
* @var array Holds namespaces for endpoints.
*/
private static $endpoints = [
'query' => 'ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint',
'post_filter' => 'ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint',
'sort' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SortEndpoint',
'highlight' => 'ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint',
'aggregations' => 'ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint',
'suggest' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SuggestEndpoint',
'inner_hits' => 'ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint',
'query' => QueryEndpoint::class,
'knn' => KnnEndpoint::class,
'post_filter' => PostFilterEndpoint::class,
'sort' => SortEndpoint::class,
'highlight' => HighlightEndpoint::class,
'aggregations' => AggregationsEndpoint::class,
'suggest' => SuggestEndpoint::class,
'inner_hits' => InnerHitsEndpoint::class,
];

/**
Expand Down
4 changes: 2 additions & 2 deletions src/SearchEndpoint/SuggestEndpoint.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace ONGR\ElasticsearchDSL\SearchEndpoint;

use ONGR\ElasticsearchDSL\Suggest\TermSuggest;
use ONGR\ElasticsearchDSL\Suggest\Suggest;
use Symfony\Component\Serializer\Normalizer\NormalizerInterface;

/**
Expand All @@ -34,7 +34,7 @@ public function normalize(
): array|string|int|float|bool {
$output = [];
if (count($this->getAll()) > 0) {
/** @var TermSuggest $suggest */
/** @var Suggest $suggest */
foreach ($this->getAll() as $suggest) {
$output = array_merge($output, $suggest->toArray());
}
Expand Down
1 change: 0 additions & 1 deletion src/Suggest/Suggest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

use ONGR\ElasticsearchDSL\NamedBuilderInterface;
use ONGR\ElasticsearchDSL\ParametersTrait;
use Symfony\Component\Serializer\Exception\InvalidArgumentException;

class Suggest implements NamedBuilderInterface
{
Expand Down
Loading

0 comments on commit d61ea65

Please sign in to comment.