#! /usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt

def ari(characters, words, sentences):
    """Return the Automated Readability Index (ARI) for the given counts.

    ARI = 4.71 * (characters / words) + 0.5 * (words / sentences) - 21.43

    The arguments may be plain numbers or NumPy arrays; with arrays the
    computation is applied element-wise via broadcasting.
    """
    chars_per_word = characters / words
    words_per_sentence = words / sentences
    return 4.71 * chars_per_word + 0.5 * words_per_sentence - 21.43

# Define realistic ranges for characters per word and words per sentence
characters_per_word_range = np.linspace(2, 15, 50)  # Characters per word (x-axis)
words_per_sentence_range = np.linspace(3, 30, 50)  # Words per sentence (y-axis)

# Display range for ARI scores. These two constants drive both the clipping
# and the colorbar label, so the plotted range and its label always agree.
ARI_MIN = 0
ARI_MAX = 20

# Create a meshgrid for the heatmap. Rows follow words per sentence (y) and
# columns follow characters per word (x), which is the layout contourf expects.
characters_grid, words_grid = np.meshgrid(characters_per_word_range, words_per_sentence_range)

# Calculate ARI for each combination in the grid. With a single sentence,
# total words equals words per sentence and total characters equals
# cpw * wps, so ari() reduces to 4.71 * cpw + 0.5 * wps - 21.43.
ari_heatmap = ari(
    characters_grid * words_grid,  # Total characters
    words_grid,  # Total words
    1,  # Assume one sentence
)

# Clip the ARI values to the display range [ARI_MIN, ARI_MAX]
ari_heatmap_clipped = np.clip(ari_heatmap, ARI_MIN, ARI_MAX)

# Plot the heatmap
plt.figure(figsize=(10, 8))
plt.contourf(
    characters_per_word_range, words_per_sentence_range, ari_heatmap_clipped,
    levels=40, cmap="plasma"  # Adjust levels for more granularity
)

cbar = plt.colorbar()
cbar.set_label(f"ARI Score (Clipped {ARI_MIN}-{ARI_MAX})", rotation=270, labelpad=20)
# Annotate the low and high ends of the colorbar (axes-relative coordinates,
# placed just below and just above the bar).
cbar.ax.text(1, -0.025, "Kindergarten", ha='left', va='center', fontsize=10, transform=cbar.ax.transAxes)
cbar.ax.text(1, 1.025, "Professor", ha='left', va='center', fontsize=10, transform=cbar.ax.transAxes)

plt.title("ARI Heatmap")
plt.xlabel("Characters per Word")
plt.ylabel("Words per Sentence")
plt.grid(alpha=0.3, linestyle="--")
plt.savefig("ari_plot.png", transparent=False)
# Release the figure once it has been written to disk.
plt.close()
