def largest_square_side_dp(grid):
    """Return the side length of the largest square made only of 1s in *grid*.

    Uses the classic dynamic-programming recurrence: the largest square whose
    bottom-right corner is a filled cell is one larger than the smallest of
    the squares ending at its upper, upper-left and left neighbours.

    Args:
        grid: Sequence of rows, each a sequence of ints. Only cells equal to
            1 count as filled. Rows are expected to share one length; if they
            do not, only the columns present in every row are examined.

    Returns:
        The side length of the largest all-ones square, or 0 when the grid
        is empty or contains no 1.
    """
    width = min((len(row) for row in grid), default=0)
    best = 0
    # Only the previous DP row is ever read, so keep two rows instead of a
    # full table. Starting from zeros makes the first grid row fall out of
    # the general recurrence (min(...) == 0, so every filled cell gets 1).
    prev_row = [0] * width
    for row in grid:
        cur_row = [0] * width
        for j in range(width):
            if row[j] != 1:
                continue
            if j == 0:
                # No left / upper-left neighbour: the square cannot grow.
                cur_row[j] = 1
            else:
                cur_row[j] = min(prev_row[j], prev_row[j - 1], cur_row[j - 1]) + 1
            if cur_row[j] > best:
                best = cur_row[j]
        prev_row = cur_row
    return best


if __name__ == "__main__":
    # Input: a line "N M", then N lines of M digits (0 or 1) with no
    # separators. Output: the area of the largest all-ones square.
    N, M = map(int, input().split())
    # strip() tolerates a trailing "\r" or spaces; [:M] keeps exactly the M
    # columns the header announces.
    matrix = [[int(ch) for ch in input().strip()[:M]] for _ in range(N)]
    print(largest_square_side_dp(matrix) ** 2)
def square(i, j, length, grid=None):
    """Return True if the square window at (i, j) contains no 0 cell.

    Args:
        i: Row index of the window's top-left cell.
        j: Column index of the window's top-left cell.
        length: Side length of the window; a length of 0 is an empty window
            and therefore always returns True.
        grid: Grid to inspect as a sequence of rows of ints. When omitted,
            the module-level ``matrix`` read from stdin is used.

    Returns:
        True when every cell in the window is non-zero, False otherwise.

    Raises:
        IndexError: If the window extends past the grid; callers are
            expected to bounds-check before calling.
    """
    if grid is None:
        grid = matrix  # module-level grid populated by the script section
    return all(
        grid[r][c] != 0
        for r in range(i, i + length)
        for c in range(j, j + length)
    )


def largest_square_side_bruteforce(grid):
    """Return the side length of the largest all-ones square in *grid*.

    Tries every side length from the largest that could fit down to 1 and
    returns the first one for which some window contains no 0.

    Args:
        grid: Sequence of rows, each a sequence of ints (0 or 1). If row
            lengths differ, only the columns present in every row are used.

    Returns:
        The largest side length found, or 0 when the grid is empty or no
        cell is filled.
    """
    rows = len(grid)
    cols = min((len(row) for row in grid), default=0)
    # A square cannot be larger than the shorter grid dimension, and the
    # bounded range guarantees termination even for a 0-row or 0-column grid.
    for side in range(min(rows, cols), 0, -1):
        for i in range(rows - side + 1):
            for j in range(cols - side + 1):
                if square(i, j, side, grid):
                    return side
    return 0


if __name__ == "__main__":
    # Input: a line "N M", then N lines of M digits (0 or 1) with no
    # separators. Output: the area of the largest all-ones square.
    N, M = map(int, input().split())
    # strip() tolerates a trailing "\r" or spaces; [:M] keeps exactly the M
    # columns the header announces.
    matrix = [[int(ch) for ch in input().strip()[:M]] for _ in range(N)]
    print(largest_square_side_bruteforce(matrix) ** 2)