출처 : 백준, https://www.acmicpc.net/problem/14888
14888번: 연산자 끼워넣기
첫째 줄에 수의 개수 N(2 ≤ N ≤ 11)가 주어진다. 둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 100) 셋째 줄에는 합이 N-1인 4개의 정수가 주어지는데, 차례대로 덧셈(+)의 개수, 뺄셈(-)의 개수,
www.acmicpc.net
첫 번째 풀이 : 시간초과
from itertools import permutations
def solution(num_arr, char_arr):
per_char = permutations(char_arr, len(char_arr))
max_val = -1000000001
min_val = 1000000001
for per in per_char:
eq = num_arr[0]
for index in range(len(per)):
eq = eval(eq + per[index] + num_arr[index+1])
if eq >= 0:
eq = str(eq)
else:
eq = str(eq + 1)
if min_val > int(eq):
min_val = int(eq)
if max_val < int(eq):
max_val = int(eq)
print(max_val, min_val, sep='\n')
if __name__ == '__main__':
sym = ['+', '-', '*', '//']
n = int(input())
num_arr = input().split()
temp = list(map(int, input().split()))
char_arr = []
for index in range(4):
char_arr += [sym[index]] * temp[index]
solution(num_arr, char_arr)
permutation을 다 구해서 하나씩 가져오고
이 permutation을 구성하는 기호와 숫자를 하나씩 가져와서 항의 2개인 식을 만들어 eval로 계산하는 것을 반복해
min, max 값을 구해 출력하는 알고리즘이다.
이 때까지는 중복이 나올 것 같긴 했는데 그보다 permutations 함수가 시간 초과의 원인인 줄 알았다.
두 번째 풀이 : 시간 초과
import sys
max_val = -1000000001
min_val = 1000000001
def solution(num_arr, char_arr, depth = 0):
global max_val
global min_val
if depth == len(char_arr):
eq = num_arr[0]
for index in range(len(char_arr)):
eq = eval(eq + char_arr[index] + num_arr[index+1])
if eq < 0:
eq += 1
eq = str(eq)
eq = int(eq)
if min_val > eq:
min_val = eq
if max_val < eq:
max_val = eq
return
for index in range(depth, len(char_arr)):
char_arr[depth], char_arr[index] = char_arr[index], char_arr[depth]
solution(num_arr, char_arr, depth + 1)
char_arr[depth], char_arr[index] = char_arr[index], char_arr[depth]
if __name__ == '__main__':
sym = ['+', '-', '*', '//']
n = int(sys.stdin.readline().rstrip())
num_arr = sys.stdin.readline().rstrip().split()
temp = list(map(int, sys.stdin.readline().rstrip().split()))
char_arr = []
for index in range(4):
char_arr += [sym[index]] * temp[index]
solution(num_arr, char_arr)
print(max_val, min_val, sep='\n')
근데 아니더라. permutations 함수를 사용하는 것이 아니라 swap하는 방식으로 풀어보았는데
이 방식도 시간초과가 나오는걸 보고 그제야 아 중복을 제거해야하는구나 하는 생각이 들어
set 자료형을 사용하기로 생각했다.
세 번째 풀이 : 틀림
from itertools import permutations
def solution(num_arr, char_arr):
per_char = set(permutations(char_arr, len(char_arr)))
max_val = -1000000001
min_val = 1000000001
for per in per_char:
eq = num_arr[0]
for index in range(len(per)):
eq = eval(eq + per[index] + num_arr[index+1])
if eq < 0 and per[index] == '//':
eq += 1
eq = str(eq)
if min_val > int(eq):
min_val = int(eq)
if max_val < int(eq):
max_val = int(eq)
print(max_val, min_val, sep='\n')
if __name__ == '__main__':
sym = ['+', '-', '*', '//']
n = int(input())
num_arr = input().split()
temp = list(map(int, input().split()))
char_arr = []
for index in range(4):
char_arr += [sym[index]] * temp[index]
solution(num_arr, char_arr)
반례
3
1 2 1
0 1 0 1
set을 통해 중복제거를 해주니 드디어 시간초과가 아니라 '틀렸습니다'가 나왔다.어디가 틀렸는지 감을 못 잡고 있다가 반례를 보고 음수 나누기에서 나누어 떨어질 경우에 에러가 난다는 것을 알게 되어나누어 떨어지지 않을 때로 수정해야겠다고 생각했다.
네 번째 풀이 : 틀림
from itertools import permutations
def solution(num_arr, char_arr):
per_char = set(permutations(char_arr, len(char_arr)))
max_val = -1000000001
min_val = 1000000001
for per in per_char:
eq = num_arr[0]
for index in range(len(per)):
eq = eval(eq + per[index] + num_arr[index+1])
if eq < 0 and per[index] == '/' and eq != int(eq):
eq += 1
eq = str(int(eq))
eq = int(eq)
if min_val > eq:
min_val = eq
if max_val < eq:
max_val = eq
print(max_val, min_val, sep='\n')
if __name__ == '__main__':
sym = ['+', '-', '*', '/']
n = int(input())
num_arr = input().split()
temp = list(map(int, input().split()))
char_arr = []
for index in range(4):
char_arr += [sym[index]] * temp[index]
solution(num_arr, char_arr)
반례
2
-3 2
0 0 0 1
그런데도 틀려서 나눗셈을 처리하는 과정에서 문제가 생긴 것 같아 유심히 살펴보니
+1을 하는 처리가 잘못되었더라......왜 저렇게 처리했는지 이해가 안간다.
다섯 번째 풀이 : 성공
from itertools import permutations
def solution(num_arr, char_arr):
per_arr = set(permutations(char_arr, len(char_arr)))
max_val = -1000000001
min_val = 1000000001
for per in per_arr:
eq = num_arr[0]
for index in range(len(per)):
eq = eval(eq + per[index] + num_arr[index+1])
eq = str(int(eq))
eq = int(eq)
if min_val > eq:
min_val = eq
if max_val < eq:
max_val = eq
print(max_val, min_val, sep='\n')
if __name__ == '__main__':
sym = ['+', '-', '*', '/']
n = int(input())
num_arr = input().split()
temp = list(map(int, input().split()))
char_arr = []
for index in range(4):
char_arr += [sym[index]] * temp[index]
solution(num_arr, char_arr)
결국 +1이 아니라 int 함수를 통해 뒤의 소수점을 제거하여 음수 나눗셈에서도 정상값이 나오도록 바꾸었다.
로직은 다음과 같다.
1. 숫자들을 정수가 아니라 문자로 받아 저장한다. (eval 함수를 사용할 식 문자열을 만들기 위해)
2. permutations 함수를 통해 모든 기호의 조합을 만든 후 이를 집합으로 만들어 중복을 제거한다.
3. 숫자 리스트에서 첫 숫자를 식에 넣고 다음부터 기호 하나와 숫자 하나를 가져와서 이항식을 만든다.
4. eval 함수로 이를 계산하고 소수점이 있다면 int로 날려버린 후 모든 기호에 대해 이를 반복한다.
5. 최댓값, 최솟값을 갱신하고 다음 순열을 가져와 반복
시간 복잡도
-
다른 사람의 풀이를 보면서 알게 된 점
-
고찰
먼저 알고리즘 문제에서 시간 초과는 알고리즘적 문제보다 일단 중복을 먼저 제거해 보아야한다.
저번에 게임맵 최단거리 문제에서도 Queue내의 중복을 제거해야 시간 내로 풀 수 있는 문제였으니
알고리즘 문제에서 시간 초과가 나올 경우 알고리즘적 요소보다는 중복된 요소가 시간을 많이 잡아 먹는 경우가 많은 듯 하다.
알고리즘적 요소는 중복을 제거한 다음에 생각해보면 될 것 같다.
둘째로, permtations 함수를 쓰면 시간초과에 걸릴 줄 알아서
permutations 함수를 쓰지 않으려고 N과 M(1)까지 풀면서 다른 방법을 생각해봤는데
다시 한 번 깨닫지만, 가장 간단하고 무식한 방법을 생각하면 그걸 실천해 보아야한다.
컴퓨터는 생각보다 빠르기 때문이다.
'알고리즘 > Python' 카테고리의 다른 글
[python] 백준 14719 - 빗물 (0) | 2022.03.19 |
---|---|
[python] 백준 2504 - 괄호의 값 (0) | 2022.03.19 |
[python] 백준 15649 - N과 M(1) (0) | 2022.03.18 |
[python] 백준 6588 - 골드바흐의 추측 (0) | 2022.03.18 |
[python] 백준 10610 - 30 (0) | 2022.03.17 |