# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import pytest
from unittest import mock
from azure.core.paging import ItemPaged, PageIterator
from azure.core.credentials import AzureKeyCredential

from azure.search.documents._operations._patch import SearchPageIterator

from azure.search.documents.models import (
    FacetResult,
    SearchDocumentsResult,
    SearchResult,
)

from azure.search.documents import (
    IndexDocumentsBatch,
    SearchClient,
    RequestEntityTooLargeError,
    ApiVersion,
)

# Placeholder API-key credential handed to every SearchClient built in this
# module; "test_api_key" is a dummy value, not a real secret.
CREDENTIAL = AzureKeyCredential("test_api_key")

# SearchClient convenience methods paired with the IndexAction ``action_type``
# that each one is expected to put into the batch it builds. The dict is the
# single source of truth; the ordered name list is derived from its keys.
CRUD_METHOD_MAP = {
    "upload_documents": "upload",
    "delete_documents": "delete",
    "merge_documents": "merge",
    "merge_or_upload_documents": "mergeOrUpload",
}

CRUD_METHOD_NAMES = list(CRUD_METHOD_MAP)


class TestSearchClient:
    """Unit tests for ``SearchClient`` that mock the operations it delegates to.

    Each test patches a method on ``_SearchClientOperationsMixin`` (or on
    ``SearchClient`` itself) and asserts how the public client forwards its
    arguments to it.
    """

    # ----------------------------------------------------------------- helpers
    # Private helpers (not collected by pytest) shared by the default and
    # api_version-pinned variants of the same test.

    @staticmethod
    def _single_result_page(document):
        """Return a ``SearchDocumentsResult`` whose ``results`` is one ``SearchResult(document)``."""
        page = SearchDocumentsResult()
        page.results = [SearchResult(document)]
        return page

    @staticmethod
    def _check_get_document_count(client, mock_count):
        """Call ``get_document_count`` and assert it is forwarded with no arguments."""
        client.get_document_count()
        assert mock_count.called
        assert mock_count.call_args[0] == ()
        assert len(mock_count.call_args[1]) == 0

    @staticmethod
    def _check_get_document(client, mock_get):
        """Assert ``get_document`` forwards the key positionally and ``selected_fields`` by keyword."""
        client.get_document("some_key")
        assert mock_get.called
        assert mock_get.call_args[0] == ("some_key",)
        assert len(mock_get.call_args[1]) == 0

        # reset_mock() clears ``called``. ``reset()`` is not a Mock method:
        # MagicMock would record it as a call on an auto-created child and
        # leave ``called`` True, making the next ``called`` check vacuous.
        mock_get.reset_mock()

        client.get_document("some_key", selected_fields="foo")
        assert mock_get.called
        assert mock_get.call_args[0] == ("some_key",)
        assert len(mock_get.call_args[1]) == 1
        assert mock_get.call_args[1]["selected_fields"] == "foo"

    def _check_search_is_lazy(self, client, mock_search_post):
        """Assert ``search`` returns an ``ItemPaged`` that calls ``_search_post`` only once an item is read."""
        result = client.search(search_text="search text")
        assert isinstance(result, ItemPaged)
        assert result._page_iterator_class is SearchPageIterator
        mock_search_post.return_value = self._single_result_page({"key": "val"})
        assert not mock_search_post.called
        next(result)
        assert mock_search_post.called
        assert mock_search_post.call_args[0] == ()
        # TODO(review): also assert the forwarded search text once the keyword
        # it is passed under to _search_post is confirmed.

    @staticmethod
    def _check_suggest(client, mock_suggest_post):
        """Assert ``suggest`` calls ``_suggest_post`` without positional arguments."""
        client.suggest(search_text="search text", suggester_name="sg")
        assert mock_suggest_post.called
        assert mock_suggest_post.call_args[0] == ()
        # TODO(review): also assert the forwarded search text once the keyword
        # it is passed under to _suggest_post is confirmed.

    @staticmethod
    def _check_autocomplete(client, mock_autocomplete_post):
        """Assert ``autocomplete`` calls ``_autocomplete_post`` without positional arguments."""
        client.autocomplete(search_text="search text", suggester_name="sg")
        assert mock_autocomplete_post.called
        assert mock_autocomplete_post.call_args[0] == ()
        # TODO(review): also assert the forwarded search text once the keyword
        # it is passed under to _autocomplete_post is confirmed.

    # ------------------------------------------------------------------- tests

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin.get_document_count")
    def test_get_document_count(self, mock_count):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        self._check_get_document_count(client, mock_count)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin.get_document")
    def test_get_document(self, mock_get):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        self._check_get_document(client, mock_get)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._search_post")
    def test_search_query_argument(self, mock_search_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        self._check_search_is_lazy(client, mock_search_post)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._search_post")
    def test_search_enable_elevated_read(self, mock_search_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        result = client.search(
            search_text="search text",
            x_ms_enable_elevated_read=True,
            x_ms_query_source_authorization="aad:fake-user",
        )
        mock_search_post.return_value = self._single_result_page({"key": "val"})
        next(result)

        # The header-style options must reach _search_post as keyword arguments.
        assert mock_search_post.called
        assert mock_search_post.call_args[1]["x_ms_enable_elevated_read"] is True
        assert mock_search_post.call_args[1]["x_ms_query_source_authorization"] == "aad:fake-user"

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._suggest_post")
    def test_suggest_query_argument(self, mock_suggest_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        self._check_suggest(client, mock_suggest_post)

    def test_suggest_bad_argument(self):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        # Only the exception type is checked; any statement placed after the
        # raising call inside the ``with`` block would never execute.
        with pytest.raises(TypeError):
            client.suggest("bad_query")

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._search_post")
    def test_get_count_reset_continuation_token(self, mock_search_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        result = client.search(search_text="search text")
        assert isinstance(result, ItemPaged)
        assert result._page_iterator_class is SearchPageIterator
        mock_search_post.return_value = self._single_result_page({"key": "val"})
        next(result)
        # Plant a continuation token on the first page iterator, then verify
        # that get_count() leaves it falsy.
        result._first_page_iterator_instance.continuation_token = "fake token"
        result.get_count()
        assert not result._first_page_iterator_instance.continuation_token

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._autocomplete_post")
    def test_autocomplete_query_argument(self, mock_autocomplete_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        self._check_autocomplete(client, mock_autocomplete_post)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin.get_document_count")
    def test_get_document_count_v2020_06_30(self, mock_count):
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        self._check_get_document_count(client, mock_count)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._search_post")
    def test_get_facets_with_aggregations(self, mock_search_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        result = client.search(search_text="*")

        search_result = self._single_result_page({"id": "1"})

        facet_bucket = FacetResult()
        facet_bucket.count = 4
        facet_bucket.avg = 120.5
        facet_bucket.min = 75.0
        facet_bucket.max = 240.0
        facet_bucket.cardinality = 3

        search_result.facets = {"baseRate": [facet_bucket]}
        mock_search_post.return_value = search_result

        next(result)
        facets = result.get_facets()

        # Every aggregation field set on the bucket must be readable by key.
        assert facets is not None
        assert "baseRate" in facets
        assert len(facets["baseRate"]) == 1
        bucket = facets["baseRate"][0]
        assert bucket["count"] == 4
        assert bucket["avg"] == 120.5
        assert bucket["min"] == 75.0
        assert bucket["max"] == 240.0
        assert bucket["cardinality"] == 3

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin.get_document")
    def test_get_document_v2020_06_30(self, mock_get):
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        self._check_get_document(client, mock_get)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._search_post")
    def test_search_query_argument_v2020_06_30(self, mock_search_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        self._check_search_is_lazy(client, mock_search_post)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._suggest_post")
    def test_suggest_query_argument_v2020_06_30(self, mock_suggest_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        self._check_suggest(client, mock_suggest_post)

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._autocomplete_post")
    def test_autocomplete_query_argument_v2020_06_30(self, mock_autocomplete_post):
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        self._check_autocomplete(client, mock_autocomplete_post)

    def test_autocomplete_bad_argument(self):
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        # Only the exception type is checked; any statement placed after the
        # raising call inside the ``with`` block would never execute.
        with pytest.raises(TypeError):
            client.autocomplete("bad_query")

    @pytest.mark.parametrize(
        "arg", [[], [{"doc1": "doc"}], [{"doc1": "doc"}, {"doc2": "doc"}]], ids=lambda x: f"{len(x)} docs"
    )
    @pytest.mark.parametrize("method_name", CRUD_METHOD_NAMES)
    def test_add_method(self, arg, method_name):
        with mock.patch.object(SearchClient, "index_documents", return_value=None) as mock_index_documents:
            client = SearchClient("endpoint", "index name", CREDENTIAL)

            method = getattr(client, method_name)
            method(arg, extra="foo")

            # Each convenience method must build one batch whose actions all
            # carry the matching action type, and pass extra kwargs through.
            assert mock_index_documents.called
            assert len(mock_index_documents.call_args[0]) == 1
            batch = mock_index_documents.call_args[0][0]
            assert isinstance(batch, IndexDocumentsBatch)
            assert all(action.action_type == CRUD_METHOD_MAP[method_name] for action in batch.actions)
            assert mock_index_documents.call_args[1]["extra"] == "foo"

    @mock.patch("azure.search.documents._operations._operations._SearchClientOperationsMixin._index")
    def test_index_documents(self, mock_index):
        client = SearchClient("endpoint", "index name", CREDENTIAL)

        # Each add_*_actions call returns the actions it added; check both the
        # count and the action type for every kind.
        batch = IndexDocumentsBatch()
        actions = batch.add_upload_actions([{"upload1": "doc"}])
        assert len(actions) == 1
        assert all(x.action_type == "upload" for x in actions)
        actions = batch.add_delete_actions([{"delete1": "doc"}, {"delete2": "doc"}])
        assert len(actions) == 2
        assert all(x.action_type == "delete" for x in actions)
        actions = batch.add_merge_actions([{"merge1": "doc"}, {"merge2": "doc"}, {"merge3": "doc"}])
        assert len(actions) == 3
        assert all(x.action_type == "merge" for x in actions)
        actions = batch.add_merge_or_upload_actions([{"merge_or_upload1": "doc"}])
        assert len(actions) == 1
        assert all(x.action_type == "mergeOrUpload" for x in actions)

        client.index_documents(batch, extra="foo")
        assert mock_index.called
        assert mock_index.call_args[0] == ()
        assert mock_index.call_args[1]["extra"] == "foo"

    def test_request_too_large_error(self):
        with mock.patch.object(
            SearchClient,
            "_index_documents_actions",
            side_effect=RequestEntityTooLargeError("Error"),
        ):
            client = SearchClient("endpoint", "index name", CREDENTIAL)
            batch = IndexDocumentsBatch()
            batch.add_upload_actions([{"upload1": "doc"}])
            # The error raised by the patched helper must surface from index_documents.
            with pytest.raises(RequestEntityTooLargeError):
                client.index_documents(batch, extra="foo")
