import numpy as np
from PIL import Image
import json
from pathlib import Path
import time

class QuickGlyphAnalyzer:
    """Estimate a "convergence" symmetry score for a glyph image.

    Foreground pixels are treated as a 2-D profile, revolved around the
    vertical axis (``cylindrical_transform``), and scored by
    ``calculate_symmetry``. ``test_convergence`` repeats this on random
    subsamples of 1..``percent_range`` percent of the pixels and averages the
    scores of the last three percentages.

    Args:
        rotation_points: Number of evenly spaced angles in [0, 2*pi) each
            point is swept through.
        samples: Random subsamples drawn per sampling percentage.
        percent_range: Highest sampling percentage tested (1..percent_range).
    """

    def __init__(self, rotation_points=36, samples=5, percent_range=10):
        self.rotation_points = rotation_points
        self.samples = samples
        self.percent_range = percent_range

    @staticmethod
    def _array_to_points(arr, threshold=5):
        """Return centred coordinates of pixels whose any channel exceeds ``threshold``.

        ``arr`` is an (h, w, channels) array. Points come back in row-major
        order as ``{'x': ..., 'y': ...}`` dicts, with the origin at the image
        centre and y increasing upwards.
        """
        h, w = arr.shape[:2]
        # np.nonzero walks the mask in C (row-major) order, i.e. the same
        # order as a y-outer / x-inner pixel scan.
        ys, xs = np.nonzero((arr > threshold).any(axis=2))
        centred_x = (xs - w / 2).tolist()
        centred_y = (-(ys - h / 2)).tolist()
        return [{'x': x, 'y': y} for x, y in zip(centred_x, centred_y)]

    def load_glyph(self, png_path):
        """Load ``png_path`` and return its foreground pixels as centred points.

        A pixel counts as foreground when any RGB channel is greater than 5.
        Returns an empty list for an all-dark image.
        """
        # The context manager releases the file handle Image.open keeps open.
        with Image.open(png_path) as img:
            arr = np.array(img.convert('RGB'))
        return self._array_to_points(arr)

    def cylindrical_transform(self, points):
        """Revolve 2-D points around the vertical axis (vectorized).

        Each point's |x| becomes a radius swept through ``rotation_points``
        angles in [0, 2*pi); its y is kept as the height.

        Returns:
            Array of shape (len(points) * rotation_points, 3) holding (x, y, z)
            rows, angle-major (every point at the first angle, then the next).
        """
        radii = np.abs([p['x'] for p in points])
        heights = np.array([p['y'] for p in points])

        angles = np.linspace(0, 2*np.pi, self.rotation_points, endpoint=False)
        radii_grid, angles_grid = np.meshgrid(radii, angles)
        heights_grid = np.tile(heights, (self.rotation_points, 1))

        x_3d = radii_grid * np.cos(angles_grid)
        z_3d = radii_grid * np.sin(angles_grid)
        y_3d = heights_grid

        return np.column_stack([x_3d.ravel(), y_3d.ravel(), z_3d.ravel()])

    def calculate_symmetry(self, positions):
        """Score the angular balance of an (N, 3) point cloud about the y axis.

        Returns the mean of three terms:

        * quadrant: 1 - (max - min) / total over the four x/z sign quadrants;
        * reflective: min / max of the counts with x >= 0 vs. x < 0;
        * rotational: 1 - std / mean of a 4-bin histogram of atan2(z, x).

        The rotational term is not clipped, so it can go negative and the
        result is not confined to [0, 1]. Returns NaN for empty ``positions``.

        NOTE(review): the sign tests (x >= 0, z >= 0) decide on-axis samples
        by floating-point residue. ``np.cos(np.pi / 2)`` is about 6e-17, not
        0. Zero-radius points all land in the x >= 0, z >= 0 quadrant. When
        ``rotation_points`` is a multiple of 4, samples fall exactly on the
        axes, which can skew the quadrant/reflective terms for those settings.
        ``np.histogram`` is also given no ``range``, so its bins span the
        data's own min..max rather than a fixed (-pi, pi].
        """
        if len(positions) == 0:
            # Every ratio below would be 0/0.
            return float('nan')

        x = positions[:, 0]
        z = positions[:, 2]

        # Quadrant symmetry
        quad_counts = [
            ((x >= 0) & (z >= 0)).sum(),
            ((x < 0) & (z >= 0)).sum(),
            ((x < 0) & (z < 0)).sum(),
            ((x >= 0) & (z < 0)).sum()
        ]
        quadrant_sym = 1 - (max(quad_counts) - min(quad_counts)) / sum(quad_counts)

        # Reflective symmetry across the x = 0 plane
        x_pos = (x >= 0).sum()
        x_neg = (x < 0).sum()
        reflective_sym = min(x_pos, x_neg) / max(x_pos, x_neg)

        # Coarse rotational symmetry: evenness of a 4-bin angle histogram.
        angles = np.arctan2(z, x)
        angle_bins = np.histogram(angles, bins=4)[0]
        rotational_sym = 1 - np.std(angle_bins) / (np.mean(angle_bins) + 1e-6)

        return (quadrant_sym + rotational_sym + reflective_sym) / 3

    def test_convergence(self, points):
        """Average symmetry over random subsamples of increasing size.

        For each percentage p in 1..percent_range, draws ``samples`` subsets
        of max(1, int(len(points) * p / 100)) points without replacement. If
        that size covers all points, every point is used. Each subset is
        scored and the scores for p are averaged. Returns the mean of the last
        three per-percentage averages (fewer if percent_range < 3).

        Each draw uses its own ``RandomState`` seeded with
        12345 + p * 1000 + sample index. Results are therefore reproducible,
        and NumPy's global random state is left untouched.
        """
        n_points = len(points)
        per_percent = []

        for percent in range(1, self.percent_range + 1):
            sample_size = max(1, int(n_points * (percent / 100.0)))
            scores = []

            for sample_idx in range(self.samples):
                if sample_size >= n_points:
                    sampled = points
                else:
                    rng = np.random.RandomState(12345 + percent * 1000 + sample_idx)
                    indices = rng.choice(n_points, sample_size, replace=False)
                    sampled = [points[i] for i in indices]

                positions = self.cylindrical_transform(sampled)
                scores.append(self.calculate_symmetry(positions))

            per_percent.append(np.mean(scores))

        return np.mean(per_percent[-3:])

    def analyze_glyph(self, png_path):
        """Load ``png_path`` and return its convergence score (see ``test_convergence``)."""
        points = self.load_glyph(png_path)
        return self.test_convergence(points)

# ====== MAIN TEST LOOP ======
def run_range_test(image_path, rotation_range, samples=5, percent_range=8):
    """Score one glyph image for each ``rotation_points`` value in ``rotation_range``.

    The image is loaded once, because ``rotation_points`` does not affect
    loading, and its points are reused for every setting.

    Args:
        image_path: Path to the glyph image.
        rotation_range: Iterable of ``rotation_points`` values to test.
        samples: Random subsamples per sampling percentage.
        percent_range: Sampling percentages 1..percent_range are tested.

    Returns:
        dict mapping each rotation value to
        ``{'convergence': float, 'time_sec': float}``. ``time_sec`` covers the
        convergence test for that value; image loading is excluded.
    """
    rotation_values = list(rotation_range)
    results = {}
    if not rotation_values:
        return results

    points = QuickGlyphAnalyzer(
        samples=samples, percent_range=percent_range
    ).load_glyph(image_path)

    for r in rotation_values:
        # flush so the progress prefix shows before the slow computation.
        print(f"Testing rotation_points = {r}...", end="", flush=True)
        start_time = time.perf_counter()

        analyzer = QuickGlyphAnalyzer(
            rotation_points=r,
            samples=samples,
            percent_range=percent_range,
        )
        convergence = analyzer.test_convergence(points)
        elapsed = time.perf_counter() - start_time

        results[r] = {
            'convergence': float(convergence),
            'time_sec': float(elapsed),
        }

        print(f" {convergence:.6f} ({elapsed:.1f}s)")

    return results

# ====== USAGE ======
if __name__ == "__main__":
    # Glyph snapshot to analyse.
    GLYPH_PATH = "folder/test2/glyph_snapshot_2026-01-24T18-08-34-373Z.png"

    # rotation_points values to sweep; keep exactly one assignment active.
    # Option 1: quick verification of the known families.
    test_range = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 32, 34, 36]
    # Option 2: extend Families C & D.
    # test_range = [30, 34, 38, 42]
    # Option 3: every even value 2..42 (full modulo-8 map).
    # test_range = list(range(2, 43, 2))

    glyph_name = Path(GLYPH_PATH).name
    rule = "=" * 50
    print(f"Testing {len(test_range)} rotation_points on {glyph_name}")
    print(rule)

    results = run_range_test(GLYPH_PATH, test_range)

    # Persist the raw numbers; json writes the integer keys as strings.
    with open("quick_range_test.json", "w") as out_file:
        json.dump(results, out_file, indent=2)

    print("\n" + rule)
    print("Results saved to quick_range_test.json")

    # Summarise each setting alongside its residue modulo 8.
    print("\nConvergence values:")
    for rot in sorted(results):
        score = results[rot]['convergence']
        print(f"  r={rot:2d} (mod {rot % 8}) → {score:.6f}")