def largest_square_side(grid):
    """Return the side length of the largest square made only of 1s.

    Uses the classic dynamic-programming recurrence: the largest square
    whose bottom-right corner is a given 1-cell is one larger than the
    smallest of the squares ending at its upper, left and upper-left
    neighbours.  Only the previous row of results is kept.

    Args:
        grid: Sequence of rows, each a sequence of ints where 1 marks a
            filled cell.  Rows of differing length are truncated to the
            shortest one.

    Returns:
        The side length as an int; 0 when the grid is empty, has
        zero-width rows, or contains no 1 at all.
    """
    width = min((len(row) for row in grid), default=0)
    best = 0
    # Index 0 is a permanent zero "border" column, so the cell in grid
    # column j lives at index j + 1 and needs no boundary checks.
    prev = [0] * (width + 1)
    for row in grid:
        curr = [0] * (width + 1)
        for j in range(width):
            if row[j] == 1:
                # prev[j] = upper-left, prev[j + 1] = upper, curr[j] = left.
                curr[j + 1] = min(prev[j], prev[j + 1], curr[j]) + 1
        best = max(best, max(curr))
        prev = curr
    return best


if __name__ == "__main__":
    # Input: "N M" on the first line, then N lines of M digits (no separators).
    N, M = map(int, input().split())
    # strip() tolerates stray whitespace around a row; [:M] keeps only the
    # M columns the header announced.
    matrix = [list(map(int, input().strip()[:M])) for _ in range(N)]
    # The answer is the area of the square, hence side squared.
    print(largest_square_side(matrix) ** 2)
def square(i, j, length, grid=None):
    """Return True if the square with top-left corner (i, j) has no 0 cell.

    Args:
        i: Row index of the square's top-left corner.
        j: Column index of the square's top-left corner.
        length: Side length of the square.  A length of 0 covers no
            cells and therefore yields True.
        grid: Rows to inspect.  When omitted, the module-level ``matrix``
            is used, which keeps the original three-argument call working.

    Returns:
        True when every covered cell is non-zero, otherwise False.

    Note:
        No bounds checking is done here; the caller must ensure the
        square lies inside the grid.
    """
    cells = matrix if grid is None else grid
    return all(
        cells[r][c] != 0
        for r in range(i, i + length)
        for c in range(j, j + length)
    )


def largest_square_side_bruteforce(grid):
    """Return the side length of the largest all-ones square by exhaustive search.

    Candidate side lengths are tried from the largest that can fit down
    to 1, and every top-left position is checked with ``square``; the
    first hit is the answer because sizes are visited in descending order.

    Args:
        grid: Sequence of rows of ints (0 or 1).  Rows of differing
            length are treated as if truncated to the shortest one.

    Returns:
        The side length as an int; 0 when the grid is empty, has
        zero-width rows, or contains no filled cell.
    """
    rows = len(grid)
    cols = min((len(row) for row in grid), default=0)
    # A square can never be larger than the shorter grid dimension.
    for side in range(min(rows, cols), 0, -1):
        for top in range(rows - side + 1):
            for left in range(cols - side + 1):
                if square(top, left, side, grid):
                    return side
    # Falling through also covers empty grids, which have no candidates.
    return 0


if __name__ == "__main__":
    # Input: "N M" on the first line, then N lines of M digits (no separators).
    N, M = map(int, input().split())
    # strip() tolerates stray whitespace around a row; [:M] keeps only the
    # M columns the header announced.
    matrix = [list(map(int, input().strip()[:M])) for _ in range(N)]
    # The answer is the area of the square, hence side squared.
    print(largest_square_side_bruteforce(matrix) ** 2)