import json
from thefuzz import fuzz
from thefuzz import process

def load_json(file_path):
    """Open ``file_path`` as UTF-8 text and return the parsed JSON value."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def normalize_key(key):
    """Return ``key`` lower-cased, trimmed of surrounding whitespace, with underscores removed.

    Trimming happens before the underscores are dropped, so whitespace that
    sat between an outer underscore and the text is kept.
    """
    trimmed = key.lower().strip()
    # Splitting on "_" and re-joining with nothing removes every underscore.
    return "".join(trimmed.split("_"))

def extract_articles_with_titles(data):
    """Collect article identifiers and enriched titles from loosely structured JSON.

    ``data`` may be:

    * a list of article dicts,
    * a dict holding the articles under one of the keys "pubmed articles",
      "articles", "data" or "results" (compared after ``normalize_key``), or
    * a single article dict (used when none of those keys is present).

    Entries that are not dicts are ignored, and a container value that is
    not a list yields no articles instead of raising.

    For every article (field names compared after ``normalize_key``):

    * The identifier is the first non-empty value among ``doi``,
      ``pubmed_link``/``pubmedlink`` and ``id``, converted to ``str`` so
      numeric ids work. If none is usable, a hash of the article's string
      form is used instead.
    * The title is the lower-cased ``title`` field, or
      ``"untitled_<identifier>"`` when it is missing or empty.
    * The enriched title is the title followed by the first 100 characters
      of the abstract/summary and the authors, when present.

    Empty values and JSON ``null`` are treated as missing, so they never
    turn into the literal text "none" and never hide a later candidate key.

    Returns:
        A tuple ``(articles, title_to_id)`` where ``articles`` is the set of
        lower-cased identifiers and ``title_to_id`` maps each enriched title
        to its lower-cased identifier. Articles that produce the same
        enriched title overwrite each other in the mapping (last one wins).
    """
    articles = set()
    title_to_id = {}  # Map enriched titles to identifiers

    # Candidate field names, normalized once so they can be looked up
    # directly in the normalized article dict.
    id_keys = [normalize_key(k) for k in ("doi", "pubmed_link", "pubmedlink", "id")]
    title_keys = [normalize_key(k) for k in ("title",)]
    abstract_keys = [normalize_key(k) for k in ("abstract", "summary")]
    author_keys = [normalize_key(k) for k in ("authors", "author")]
    container_keys = [normalize_key(k) for k in ("pubmed articles", "articles", "data", "results")]

    def first_value(record, keys):
        """Return the first truthy value stored under any of ``keys``, else None."""
        for key in keys:
            value = record.get(key)
            if isinstance(value, str):
                value = value.strip()
            if value:
                return value
        return None

    def process_article(article):
        normalized_article = {normalize_key(k): v for k, v in article.items()}

        raw_id = first_value(normalized_article, id_keys)
        if raw_id is None:
            # Fallback id. hash() of a str is salted per interpreter process
            # (unless PYTHONHASHSEED is fixed), so this id is only stable
            # within a single run.
            identifier = str(hash(str(article)))
        else:
            # str() so integer ids (common for "id") do not break the
            # string operations below.
            identifier = str(raw_id).strip()

        raw_title = first_value(normalized_article, title_keys)
        title = str(raw_title).lower().strip() if raw_title is not None else ""
        if not title:
            title = "untitled_" + identifier  # Fallback title

        raw_abstract = first_value(normalized_article, abstract_keys)
        if raw_abstract is not None:
            # Only a prefix of the abstract is used for enrichment.
            abstract = str(raw_abstract).lower().strip()[:100]
        else:
            abstract = ""

        raw_authors = first_value(normalized_article, author_keys)
        if isinstance(raw_authors, list):
            authors = " ".join(str(a).lower().strip() for a in raw_authors if a)
        elif raw_authors is not None:
            authors = str(raw_authors).lower().strip()
        else:
            authors = ""

        # Title is always non-empty here; abstract/authors are appended only
        # when they carry text.
        enriched_title = " ".join(part for part in (title, abstract, authors) if part)

        article_id = identifier.lower()
        articles.add(article_id)
        title_to_id[enriched_title] = article_id

    # Handle different input structures
    if isinstance(data, dict):
        # Normalized key -> first original key with that normalization.
        by_normalized = {}
        for original_key in data:
            by_normalized.setdefault(normalize_key(original_key), original_key)

        for candidate in container_keys:
            if candidate in by_normalized:
                articles_data = data[by_normalized[candidate]]
                break
        else:
            articles_data = [data]  # Treat dict as single article
    elif isinstance(data, list):
        articles_data = data
    else:
        articles_data = []

    # A container key may hold null, a number, etc.; treat that as "no
    # articles" rather than failing while iterating.
    if not isinstance(articles_data, list):
        articles_data = []

    for article in articles_data:
        if isinstance(article, dict):
            process_article(article)

    return articles, title_to_id

def find_similar_titles(test_titles, benchmark_titles, threshold=70):
    """Match each test title to its best-scoring benchmark title.

    Every test title is compared against the keys of ``benchmark_titles``
    with three scorers (token sort, token set, partial ratio). The single
    highest-scoring result across the three is kept when its score is at
    least ``threshold``.

    Returns:
        A tuple ``(similar_pairs, similarity_scores)``: the first maps a
        matched test title to the value stored in ``benchmark_titles`` for
        the winning benchmark title, the second maps it to the winning score.
    """
    matched_ids = {}
    matched_scores = {}

    for query in test_titles:
        hits = []
        for scorer_name in ("token_sort_ratio", "token_set_ratio", "partial_ratio"):
            hit = process.extractOne(
                query,
                benchmark_titles.keys(),
                scorer=getattr(fuzz, scorer_name),
            )
            if hit:
                hits.append(hit)

        if not hits:
            continue

        # max() returns the first of equally scored results, so earlier
        # scorers win ties.
        best = max(hits, key=lambda pair: pair[1])
        if best[1] < threshold:
            continue

        matched_ids[query] = benchmark_titles[best[0]]
        matched_scores[query] = best[1]

    return matched_ids, matched_scores

def calculate_metrics(test_file, benchmark_file, similarity_threshold=70, print_near_matches=False):
    """Score the articles in ``test_file`` against those in ``benchmark_file``.

    Articles are first matched by identifier. Test articles without an
    identifier match are then fuzzy-matched by enriched title against the
    benchmark titles. Articles that already matched by identifier are
    excluded from the fuzzy step so they are counted once, not twice.

    Args:
        test_file: Path to the JSON file with the retrieved articles.
        benchmark_file: Path to the JSON file with the relevant articles.
        similarity_threshold: Minimum fuzzy score for a title match.
        print_near_matches: When true, also collect test titles whose best
            score falls in ``[similarity_threshold - 20, similarity_threshold)``.

    Returns:
        A dict with the keys "True Positives" (retrieved articles matched by
        identifier or title), "False Positives", "False Negatives",
        "Precision", "Recall", "F1 Score", "Similar Articles Found",
        "Near Matches (Below Threshold)" and "Similarity Report".
        Precision is measured over the retrieved identifiers and recall over
        the benchmark identifiers, so several test articles matching the
        same benchmark article do not inflate recall.
    """
    test_data = load_json(test_file)
    benchmark_data = load_json(benchmark_file)

    # Extract articles and titles
    retrieved, test_title_to_id = extract_articles_with_titles(test_data)
    relevant, benchmark_title_to_id = extract_articles_with_titles(benchmark_data)

    # Exact matches based on identifiers
    exact_matches = retrieved & relevant

    # Only test articles that did not match by identifier take part in the
    # fuzzy title matching; otherwise each one would also match its own
    # title and be counted twice.
    candidate_titles = [
        title for title, article_id in test_title_to_id.items()
        if article_id not in exact_matches
    ]
    if candidate_titles and benchmark_title_to_id:
        similar_titles, similarity_scores = find_similar_titles(
            candidate_titles, benchmark_title_to_id, similarity_threshold
        )
    else:
        similar_titles, similarity_scores = {}, {}

    # Identifiers (not titles) are counted, since the retrieved/relevant
    # sets are sets of identifiers.
    similar_test_ids = {test_title_to_id[title] for title in similar_titles}
    matched_benchmark_ids = exact_matches | set(similar_titles.values())

    # Metrics calculation
    true_positives = len(exact_matches) + len(similar_test_ids)
    false_positives = len(retrieved) - true_positives
    false_negatives = len(relevant) - len(matched_benchmark_ids)

    precision = true_positives / len(retrieved) if retrieved else 0.0
    recall = len(matched_benchmark_ids) / len(relevant) if relevant else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    # Near matches for debugging
    near_matches = []
    if print_near_matches and benchmark_title_to_id:
        benchmark_titles = list(benchmark_title_to_id)
        lower_bound = similarity_threshold - 20
        # Each scorer is paired with its name so the reported algorithm
        # always belongs to the winning result.
        scorers = (
            ("token_sort_ratio", fuzz.token_sort_ratio),
            ("token_set_ratio", fuzz.token_set_ratio),
            ("partial_ratio", fuzz.partial_ratio),
        )
        for test_title in candidate_titles:
            if test_title in similar_titles:
                continue
            best_match = None
            best_algorithm = None
            for algorithm, scorer in scorers:
                result = process.extractOne(test_title, benchmark_titles, scorer=scorer)
                # Strict ">" keeps the earliest scorer on ties.
                if result and (best_match is None or result[1] > best_match[1]):
                    best_match, best_algorithm = result, algorithm
            if best_match and lower_bound <= best_match[1] < similarity_threshold:
                near_matches.append({
                    "Test Title": test_title,
                    "Benchmark Title": best_match[0],
                    "Score": best_match[1],
                    "Algorithm": best_algorithm
                })

    # Benchmark id -> first enriched title carrying that id, built once
    # instead of scanning the whole mapping for every match.
    benchmark_title_by_id = {}
    for title, article_id in benchmark_title_to_id.items():
        benchmark_title_by_id.setdefault(article_id, title)

    # Similarity report
    similarity_report = {
        # Sorted so the report is reproducible between runs.
        "Exact Matches": sorted(exact_matches),
        "Similar Title Matches": [
            {
                "Test Title": test_title,
                "Benchmark Title": benchmark_title_by_id[benchmark_id],
                "Benchmark ID": benchmark_id,
                "Similarity Score": similarity_scores[test_title]
            }
            for test_title, benchmark_id in similar_titles.items()
        ],
        "Unmatched Test Articles": [t for t in candidate_titles if t not in similar_titles],
        "Unmatched Benchmark Articles": [
            title for title, article_id in benchmark_title_to_id.items()
            if article_id not in matched_benchmark_ids
        ]
    }

    return {
        "True Positives": true_positives,
        "False Positives": false_positives,
        "False Negatives": false_negatives,
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1_score,
        "Similar Articles Found": len(similar_titles),
        "Near Matches (Below Threshold)": near_matches,
        "Similarity Report": similarity_report
    }

if __name__ == "__main__":
    import argparse

    def _print_capped(label, entries, cap=5):
        """Print a counted heading, the first ``cap`` entries and an overflow note."""
        print(f"  {label} ({len(entries)}):")
        for entry in entries[:cap]:
            print(f"    - {entry}")
        if len(entries) > cap:
            print(f"    ... and {len(entries) - cap} more")

    # Command-line interface: two input paths, the fuzzy threshold and an
    # optional switch for the near-match listing.
    cli = argparse.ArgumentParser(description='Evaluate article matching against a benchmark.')
    cli.add_argument('--test_file', default='test.json', help='Path to test file')
    cli.add_argument('--benchmark_file', default='benchmark.json', help='Path to benchmark file')
    cli.add_argument('--threshold', type=int, default=70, help='Similarity threshold (0-100)')
    cli.add_argument('--near_matches', action='store_true', help='Show detailed scores for near-matches')
    options = cli.parse_args()

    metrics = calculate_metrics(
        options.test_file,
        options.benchmark_file,
        options.threshold,
        options.near_matches,
    )

    print("Evaluation Metrics:")
    for name, result in metrics.items():
        if name == "Similarity Report":
            print(f"\n{name}:")
            _print_capped("Exact Matches", result['Exact Matches'])

            pairs = result['Similar Title Matches']
            print(f"  Similar Title Matches ({len(pairs)}):")
            for pair in pairs:
                # One call emits the four detail lines plus the blank separator.
                print(
                    f"    Test Title: {pair['Test Title']}",
                    f"    Benchmark Title: {pair['Benchmark Title']}",
                    f"    Benchmark ID: {pair['Benchmark ID']}",
                    f"    Similarity Score: {pair['Similarity Score']}%",
                    "",
                    sep="\n",
                )

            _print_capped("Unmatched Test Articles", result['Unmatched Test Articles'])
            _print_capped("Unmatched Benchmark Articles", result['Unmatched Benchmark Articles'])
            continue

        # An empty near-match list falls through to the generic printing below.
        if name == "Near Matches (Below Threshold)" and result:
            print(f"\n{name} ({len(result)}):")
            for near in result:
                print(
                    f"  Test Title: {near['Test Title']}",
                    f"  Benchmark Title: {near['Benchmark Title']}",
                    f"  Score: {near['Score']}%",
                    f"  Best Algorithm: {near['Algorithm']}",
                    "",
                    sep="\n",
                )
            continue

        if isinstance(result, float):
            print(f"{name}: {result:.4f}")
        else:
            print(f"{name}: {result}")