import numpy as np
from PIL import Image
import json
from pathlib import Path
import time

class QuickGlyphAnalyzer:
    """Score how symmetric a glyph's point cloud is after sweeping it around an axis.

    Lit pixels of a PNG become 2-D points (x = horizontal offset from the
    image centre, y = vertical offset with up positive). Each point is swept
    around the vertical axis at ``rotation_points`` evenly spaced angles with
    radius |x|. The resulting 3-D cloud is scored by ``calculate_symmetry``.
    That score is averaged over seeded random subsamples to give a
    "convergence" value.

    Args:
        rotation_points: Number of angles in [0, 2*pi) each point is placed at.
        samples: Random subsamples drawn per sampling percentage.
        percent_range: Subsamples use 1%, 2%, ..., ``percent_range``% of the points.
    """

    def __init__(self, rotation_points=36, samples=5, percent_range=8):
        self.rotation_points = rotation_points
        self.samples = samples
        self.percent_range = percent_range

    def load_glyph(self, png_path):
        """Return the lit pixels of a PNG as centred 2-D points.

        A pixel counts as lit when any RGB channel exceeds 5. Near-black noise
        is therefore treated as background.

        Returns:
            list of ``{'x': float, 'y': float}`` dicts. Coordinates are
            relative to the image centre with y pointing up. The list is in
            row-major (top-to-bottom, left-to-right) pixel order.
        """
        # Context manager closes the file handle Image.open keeps open.
        with Image.open(png_path) as img:
            arr = np.asarray(img.convert('RGB'))
        h, w, _ = arr.shape

        # Vectorised mask instead of a per-pixel Python loop. np.nonzero returns
        # indices in row-major order, i.e. the same y-then-x scan order.
        ys, xs = np.nonzero((arr > 5).any(axis=2))
        # .tolist() yields Python ints, so the coordinates are plain floats.
        return [
            {'x': x - w / 2, 'y': -(y - h / 2)}
            for y, x in zip(ys.tolist(), xs.tolist())
        ]

    def cylindrical_transform(self, points):
        """Sweep each 2-D point around the vertical (y) axis.

        Each point's radius is |x| and its height is y. It is placed at
        ``self.rotation_points`` evenly spaced angles in [0, 2*pi).

        Returns:
            ndarray of shape (rotation_points * len(points), 3) with columns
            (x, y, z). Rows are grouped by angle: every point at angle 0 first,
            then the next angle, and so on.

        Raises:
            ValueError: if ``rotation_points`` < 1, which would give an empty cloud.
        """
        if self.rotation_points < 1:
            raise ValueError(
                f"rotation_points must be >= 1, got {self.rotation_points}"
            )
        radii = np.abs([p['x'] for p in points])
        heights = np.array([p['y'] for p in points])

        angles = np.linspace(0, 2*np.pi, self.rotation_points, endpoint=False)
        # Both grids have shape (rotation_points, len(points)).
        radii_grid, angles_grid = np.meshgrid(radii, angles)
        heights_grid = np.tile(heights, (self.rotation_points, 1))

        x_3d = radii_grid * np.cos(angles_grid)
        z_3d = radii_grid * np.sin(angles_grid)
        y_3d = heights_grid

        return np.column_stack([x_3d.ravel(), y_3d.ravel(), z_3d.ravel()])

    def calculate_symmetry(self, positions):
        """Score the balance of a 3-D point cloud in the x/z plane.

        Returns the mean of three sub-scores. Each sub-score is 1.0 when the
        cloud is perfectly balanced:

        * quadrant: ``1 - (max - min) / total`` over the four x/z sign quadrants;
        * reflective: smaller / larger of the ``x >= 0`` and ``x < 0`` counts;
        * rotational: ``1 - std / mean`` of a 4-bin histogram of the polar
          angle ``atan2(z, x)``. This one may be negative for very unbalanced
          clouds.

        Args:
            positions: array whose columns 0 and 2 are x and z, as returned
                by ``cylindrical_transform``.

        Raises:
            ValueError: if ``positions`` is empty, because every ratio would be 0/0.
        """
        if len(positions) == 0:
            raise ValueError("cannot score symmetry of an empty point cloud")
        x = positions[:, 0]
        z = positions[:, 2]

        # Quadrant symmetry
        quad_counts = [
            ((x >= 0) & (z >= 0)).sum(),
            ((x < 0) & (z >= 0)).sum(),
            ((x < 0) & (z < 0)).sum(),
            ((x >= 0) & (z < 0)).sum()
        ]
        quadrant_sym = 1 - (max(quad_counts) - min(quad_counts)) / sum(quad_counts)

        # Reflective symmetry
        x_pos = (x >= 0).sum()
        x_neg = (x < 0).sum()
        reflective_sym = min(x_pos, x_neg) / max(x_pos, x_neg)

        # Rotational symmetry.
        # NOTE: with an integer ``bins`` np.histogram spans the observed
        # [min, max] of the angles, not a fixed [-pi, pi]. These bins are
        # therefore not fixed quadrants. The metric is kept as-is because
        # changing it would alter every convergence value reported so far.
        angles = np.arctan2(z, x)
        angle_bins = np.histogram(angles, bins=4)[0]
        rotational_sym = 1 - np.std(angle_bins) / (np.mean(angle_bins) + 1e-6)

        return (quadrant_sym + rotational_sym + reflective_sym) / 3

    def test_convergence(self, points):
        """Average the symmetry score over progressively larger random subsamples.

        For each percentage p in 1..percent_range, the method draws
        ``samples`` subsamples without replacement. Each subsample has
        ``max(1, int(len(points) * p / 100))`` points. It is transformed and
        scored, and the scores for that p are averaged.

        Returns:
            The mean of the per-percentage averages for the last (up to)
            three percentages.

        Raises:
            ValueError: if ``points`` is empty, or if ``samples`` or
                ``percent_range`` is < 1. Those cases previously produced NaN.
        """
        if len(points) == 0:
            raise ValueError("no points to analyse (glyph has no lit pixels)")
        if self.samples < 1 or self.percent_range < 1:
            raise ValueError(
                f"samples and percent_range must be >= 1, got "
                f"{self.samples} and {self.percent_range}"
            )

        symmetries = []

        for percent in range(1, self.percent_range + 1):
            percent_sym = []
            sample_size = max(1, int(len(points) * (percent / 100.0)))

            for sample_idx in range(self.samples):
                if sample_size >= len(points):
                    sampled = points
                else:
                    # A private generator with the same legacy seed gives the
                    # same draws as np.random.seed() + np.random.choice(),
                    # without resetting the process-wide NumPy RNG.
                    rng = np.random.RandomState(12345 + percent * 1000 + sample_idx)
                    indices = rng.choice(len(points), sample_size, replace=False)
                    sampled = [points[i] for i in indices]

                positions = self.cylindrical_transform(sampled)
                percent_sym.append(self.calculate_symmetry(positions))

            symmetries.append(np.mean(percent_sym))

        # The largest sampling levels are treated as the "converged" regime.
        return np.mean(symmetries[-3:])

    def analyze_glyph(self, png_path):
        """Load ``png_path`` and return its convergence score.

        See ``test_convergence`` for how the score is computed and which
        ``ValueError`` cases it raises.
        """
        points = self.load_glyph(png_path)
        return self.test_convergence(points)

def run_full_range_test(image_path, start_r=1, end_r=40, samples=5, percent_range=8):
    """Test ALL rotation_points from start_r to end_r.

    The glyph is decoded once and its point list is reused for every r.
    Pixel extraction does not depend on ``rotation_points``, so re-reading
    the image for each r only wasted time. As a result, ``time_sec`` measures
    only the convergence computation.

    Args:
        image_path: Path to the glyph PNG.
        start_r: First rotation_points value (inclusive).
        end_r: Last rotation_points value (inclusive).
        samples: Subsamples per percentage level, forwarded to the analyzer.
        percent_range: Highest sampling percentage, forwarded to the analyzer.

    Returns:
        dict mapping each r to ``{'convergence': float, 'time_sec': float}``.
        If the load or the run for that r raised, the entry is
        ``{'error': str}`` instead.
    """
    results = {}
    total_tests = end_r - start_r + 1
    completed = 0

    print(f"Testing rotation_points {start_r} to {end_r}...")
    print("=" * 60)

    # Load once. If loading fails, the same error is recorded for every r,
    # as it was when each r reloaded the image.
    points = None
    load_error = None
    try:
        points = QuickGlyphAnalyzer().load_glyph(image_path)
    except Exception as e:
        load_error = e

    for r in range(start_r, end_r + 1):
        print(f"Testing r={r:2d}...", end="", flush=True)
        completed += 1

        if load_error is not None:
            print(f" ✗ Error: {load_error} [{completed}/{total_tests}]")
            results[r] = {'error': str(load_error)}
            continue

        start_time = time.perf_counter()
        try:
            analyzer = QuickGlyphAnalyzer(
                rotation_points=r,
                samples=samples,
                percent_range=percent_range,
            )
            convergence = analyzer.test_convergence(points)
        except Exception as e:
            # Sweep boundary: record the failure for this r and keep going.
            print(f" ✗ Error: {e} [{completed}/{total_tests}]")
            results[r] = {'error': str(e)}
            continue
        elapsed = time.perf_counter() - start_time

        results[r] = {
            'convergence': float(convergence),
            'time_sec': float(elapsed)
        }
        print(f" ✓ {convergence:.6f} ({elapsed:.1f}s) [{completed}/{total_tests}]")

    return results

def analyze_patterns(results, tol=1e-4):
    """Print a residue-class summary and check convergence values against closed forms.

    Args:
        results: mapping r -> dict, in the shape returned by
            ``run_full_range_test``. Entries without a ``'convergence'`` key
            (failed runs) are ignored.
        tol: absolute tolerance for a formula match. The default equals the
            original hard-coded 0.0001.

    Formulas checked:
        * r divisible by 4: ``(12k - 1) / (12k)`` with ``k = r / 8``;
        * r ≡ 2 (mod 4):     ``(r - 1) / r``.
    Odd r have no formula; they appear only in the grouping summary.
    Non-positive r are skipped in verification because r = 0 makes both
    formulas divide by zero.
    """
    print("\n" + "=" * 60)
    print("PATTERN ANALYSIS")
    print("=" * 60)

    valid_results = {r: v for r, v in results.items() if 'convergence' in v}

    # Group by residue mod 4
    mod_groups = {mod: [] for mod in range(4)}
    for r, data in valid_results.items():
        mod_groups[r % 4].append((r, data['convergence']))

    print("\nGrouped by r mod 4:")
    for mod in range(4):
        group = mod_groups[mod]
        if not group:
            continue
        values = [v for _, v in group]
        print(f"  mod {mod}: {len(group)} values, range {min(values):.4f}-{max(values):.4f}")

    print("\nFormula Verification:")
    for r, data in valid_results.items():
        if r <= 0:
            continue  # formulas are undefined for r = 0 (division by r)
        actual = data['convergence']

        if r % 4 == 0:
            k = r / 8
            predicted = (12*k - 1) / (12*k)
            if abs(actual - predicted) < tol:
                print(f"  r={r:2d}: ✓ (12k-1)/(12k) with k={k:.1f}")
            else:
                print(f"  r={r:2d}: ✗ predicted {predicted:.6f}, got {actual:.6f}")

        elif r % 4 == 2:
            predicted = (r - 1) / r
            if abs(actual - predicted) < tol:
                print(f"  r={r:2d}: ✓ (r-1)/r = {predicted:.6f}")
            else:
                print(f"  r={r:2d}: ✗ predicted {predicted:.6f}, got {actual:.6f}")

# ====== MAIN ======
if __name__ == "__main__":
    GLYPH_PATH = "folder/test2/glyph_snapshot_2026-01-24T18-08-34-373Z.png"
    # Inclusive range of rotation_points values to sweep.
    FIRST_R, LAST_R = 1, 40
    # Rough cost per r, used only for the ETA message.
    SECONDS_PER_R = 0.7

    print(f"Analyzing: {Path(GLYPH_PATH).name}")
    print(f"Expected time: ~{(LAST_R - FIRST_R + 1) * SECONDS_PER_R:.0f} seconds")

    results = run_full_range_test(GLYPH_PATH, FIRST_R, LAST_R)

    # json.dump writes the integer r keys as strings.
    output_file = f"full_range_{FIRST_R}_to_{LAST_R}.json"
    with open(output_file, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"\nResults saved to {output_file}")

    analyze_patterns(results)

    rule = "=" * 60
    print("\n" + rule)
    print("SUMMARY TABLE (r, convergence, mod 4)")
    print(rule)

    # Keys are unique, so sorting (r, entry) pairs never compares the dicts.
    for r, entry in sorted(results.items()):
        if 'convergence' not in entry:
            continue
        print(f"  r={r:2d} (mod {r % 4}): {entry['convergence']:.6f}")