# Repository page header (not part of the program):
# File: NoteNextra-origin/toolboxes/mlp_image_reconstruction.py
# Timestamp: 2025-07-06 12:40:25 -05:00
# Size: 82 lines, 2.9 KiB
# Language: Python

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
class MLPScalar(torch.nn.Module):
    """Small two-layer perceptron applied directly to raw 2-component inputs.

    The network is ``Linear(2, 128) -> ReLU -> Linear(128, 3) -> sigmoid``,
    so every output component lies in the closed interval [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Example architecture: one hidden layer of width 128.
        # fc1 is registered before fc2, which fixes both the parameter
        # initialisation order and the state_dict key order.
        self.fc1 = torch.nn.Linear(2, 128)
        self.fc2 = torch.nn.Linear(128, 3)  # Outputs RGB

    def forward(self, x):
        """Map inputs of shape ``(..., 2)`` to outputs of shape ``(..., 3)``."""
        hidden = torch.relu(self.fc1(x))
        rgb = torch.sigmoid(self.fc2(hidden))  # Squash output into [0, 1]
        return rgb
class MLPPositional(torch.nn.Module):
    """Two-layer perceptron fed with a sinusoidal encoding of 2-D coordinates.

    Each coordinate ``c`` is expanded to ``sin(c * 2**i)`` for
    ``i = 0 .. num_frequencies - 1``. The encoded features (optionally
    preceded by the raw coordinates) go through
    ``Linear -> ReLU -> Linear(128, 3) -> sigmoid``, so every output
    component lies in [0, 1].

    Args:
        num_frequencies: Number of frequency bands in the encoding.
            Must be at least 1.
        include_input: When true, the raw coordinates are concatenated in
            front of the encoded features; when false, only the encoded
            features are fed to the network.

    Raises:
        ValueError: If ``num_frequencies`` is smaller than 1.
    """

    # Number of components in each input coordinate (x, y).
    _COORD_DIM = 2

    def __init__(self, num_frequencies=10, include_input=True):
        super().__init__()
        if num_frequencies < 1:
            raise ValueError("num_frequencies must be >= 1")
        self.num_frequencies = num_frequencies
        self.include_input = include_input

        # The first layer must accept exactly what forward() builds:
        # one sine feature per (coordinate, frequency) pair, plus the raw
        # coordinates when include_input is set.
        in_features = self._COORD_DIM * num_frequencies
        if include_input:
            in_features += self._COORD_DIM
        self.fc1 = torch.nn.Linear(in_features, 128)
        self.fc2 = torch.nn.Linear(128, 3)  # Outputs RGB

    def forward(self, x):
        """Map coordinates of shape ``(..., 2)`` to outputs of shape ``(..., 3)``."""
        features = self.positional_encoding(x)
        if self.include_input:
            features = torch.cat([x, features], dim=-1)
        hidden = torch.nn.functional.relu(self.fc1(features))
        return torch.sigmoid(self.fc2(hidden))  # Normalize output to [0, 1]

    def positional_encoding(self, x):
        """Return ``sin(x * 2**i)`` for each frequency band, joined on the last axis.

        For input of shape ``(..., D)`` the result has shape
        ``(..., D * num_frequencies)``, ordered band by band.
        """
        bands = [torch.sin(x * (2 ** i)) for i in range(self.num_frequencies)]
        return torch.cat(bands, dim=-1)
if __name__ == '__main__':
    # Demo: read an image path from stdin, load the image, then run both MLPs
    # on a handful of random coordinates and show the predicted colours.

    # Load a real image. Paths are often pasted wrapped in quotes; strip
    # surrounding whitespace and quote characters instead of blindly dropping
    # the first and last character, which would corrupt an unquoted path.
    image_path = input().strip().strip('"\'')
    # The context manager closes the file handle; convert() returns an
    # independent in-memory copy.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Normalize and resize the image
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize image to desired dimensions
        transforms.ToTensor(),          # Convert to Tensor and normalize to [0,1]
    ])
    # NOTE(review): image_tensor is not consumed by the demo below; presumably
    # it is the reconstruction target for a training loop that is not written
    # yet. Loading is kept so an invalid path still fails early.
    image_tensor = transform(image)

    # Create dummy normalized coordinates (assume image coordinates normalized to [0,1])
    num_points = 10
    coords = torch.rand(num_points, 2)  # random coordinate pairs
    print("Input coordinates:")
    print(coords)

    # Test MLP with scalar input
    model_scalar = MLPScalar()
    out_scalar = model_scalar(coords)
    print("\nMLPScalar output (RGB):")
    print(out_scalar)

    # Test MLP with positional encoding
    model_positional = MLPPositional(num_frequencies=10, include_input=True)
    out_positional = model_positional(coords)
    print("\nMLPPositional output (RGB):")
    print(out_positional)

    # Convert the predictions to 8-bit RGB triples. detach() drops the
    # autograd graph before handing the data to NumPy.
    swatches = (out_positional.detach() * 255).byte().numpy()  # (num_points, 3)

    # Visualize the output: one colour swatch per coordinate.
    n_cols = 5
    n_rows = -(-num_points // n_cols)  # ceiling division
    plt.figure(figsize=(10, 2))
    for position, rgb in enumerate(swatches, start=1):
        plt.subplot(n_rows, n_cols, position)
        # A (1, 1, 3) array is rendered as a single RGB pixel; a (1, 3) array
        # would be treated as three scalar values and colormapped instead.
        plt.imshow(rgb.reshape(1, 1, 3), aspect='auto')
        plt.axis('off')
    plt.show()